From 5efe8a202010fc9b6660be14ee6f2ce6e8d3df11 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Wed, 11 Feb 2026 11:41:13 -0600 Subject: [PATCH 1/3] feat(auth): add OAuth and token-based login for OpenAI and Anthropic Add `picoclaw auth` CLI command supporting: - OpenAI OAuth2 (PKCE + browser callback or device code flow) - Anthropic paste-token flow - Token storage at ~/.picoclaw/auth.json with 0600 permissions - Auto-refresh for expired OAuth tokens in provider Closes #18 Co-Authored-By: Claude Opus 4.6 --- cmd/picoclaw/main.go | 237 ++++++++++++++++++++++ pkg/auth/oauth.go | 358 +++++++++++++++++++++++++++++++++ pkg/auth/oauth_test.go | 199 ++++++++++++++++++ pkg/auth/pkce.go | 29 +++ pkg/auth/pkce_test.go | 51 +++++ pkg/auth/store.go | 112 +++++++++++ pkg/auth/store_test.go | 189 +++++++++++++++++ pkg/auth/token.go | 43 ++++ pkg/config/config.go | 5 +- pkg/providers/http_provider.go | 82 +++++++- 10 files changed, 1295 insertions(+), 10 deletions(-) create mode 100644 pkg/auth/oauth.go create mode 100644 pkg/auth/oauth_test.go create mode 100644 pkg/auth/pkce.go create mode 100644 pkg/auth/pkce_test.go create mode 100644 pkg/auth/store.go create mode 100644 pkg/auth/store_test.go create mode 100644 pkg/auth/token.go diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index c14ec58..e1128fe 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -19,6 +19,7 @@ import ( "github.com/chzyer/readline" "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -85,6 +86,8 @@ func main() { gatewayCmd() case "status": statusCmd() + case "auth": + authCmd() case "cron": cronCmd() case "skills": @@ -152,6 +155,7 @@ func printHelp() { fmt.Println("Commands:") fmt.Println(" onboard Initialize picoclaw configuration and workspace") fmt.Println(" agent Interact with the agent directly") + fmt.Println(" auth Manage authentication (login, logout, status)") fmt.Println(" gateway Start picoclaw gateway") fmt.Println(" status Show picoclaw status") fmt.Println(" cron Manage scheduled tasks") @@ -682,6 +686,239 @@ func statusCmd() { } else { fmt.Println("vLLM/Local: not set") } + + store, _ := auth.LoadStore() + if store != nil && len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } +} + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println("Supported providers: openai, anthropic") + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + default: + fmt.Printf("Unsupported provider: %s\n", provider) + fmt.Println("Supported providers: openai, anthropic") + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "oauth" + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } } } diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..94a79a6 --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,358 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os/exec" + "runtime" + "strings" + "time" +) + +type OAuthProviderConfig struct { + Issuer string + ClientID string + Scopes string + Port int +} + +func OpenAIOAuthConfig() OAuthProviderConfig { + return OAuthProviderConfig{ + Issuer: "https://auth.openai.com", + ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + Port: 1455, + } +} + +func generateState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { + pkce, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("generating PKCE: %w", err) + } + + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port) + + authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI) + + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + resultCh <- callbackResult{err: fmt.Errorf("state mismatch")} + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + errMsg := r.URL.Query().Get("error") + resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)} + http.Error(w, "No authorization code received", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, "

Authentication successful!

You can close this window.

") + resultCh <- callbackResult{code: code} + }) + + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) + if err != nil { + return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) + } + + server := &http.Server{Handler: mux} + go server.Serve(listener) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + + if err := openBrowser(authURL); err != nil { + fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) + } + + fmt.Println("Waiting for authentication in browser...") + + select { + case result := <-resultCh: + if result.err != nil { + return nil, result.err + } + return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("authentication timed out after 5 minutes") + } +} + +type callbackResult struct { + code string + err error +} + +func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "client_id": cfg.ClientID, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/usercode", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device code request failed: %s", string(body)) + } + + var deviceResp struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval int `json:"interval"` + } + if err := json.Unmarshal(body, &deviceResp); err != nil { + return nil, fmt.Errorf("parsing device code response: %w", err) + } + + if deviceResp.Interval < 1 { + deviceResp.Interval = 5 + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, deviceResp.UserCode) + + deadline := time.After(15 * time.Minute) + ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + return nil, fmt.Errorf("device code authentication timed out after 15 minutes") + case <-ticker.C: + cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) + if err != nil { + continue + } + if cred != nil { + return cred, nil + } + } + } +} + +func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "device_auth_id": deviceAuthID, + "user_code": userCode, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/token", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pending") + } + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AuthorizationCode string `json:"authorization_code"` + CodeChallenge string `json:"code_challenge"` + CodeVerifier string `json:"code_verifier"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + redirectURI := cfg.Issuer + "/deviceauth/callback" + return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) +} + +func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { + if cred.RefreshToken == "" { + return nil, fmt.Errorf("no refresh token available") + } + + data := url.Values{ + "client_id": {cfg.ClientID}, + "grant_type": {"refresh_token"}, + "refresh_token": {cred.RefreshToken}, + "scope": {"openid profile email"}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + return parseTokenResponse(body, cred.Provider) +} + +func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + return buildAuthorizeURL(cfg, pkce, state, redirectURI) +} + +func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {redirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + return cfg.Issuer + "/authorize?" + params.Encode() +} + +func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {cfg.ClientID}, + "code_verifier": {codeVerifier}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("exchanging code for tokens: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: %s", string(body)) + } + + return parseTokenResponse(body, "openai") +} + +func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in response") + } + + var expiresAt time.Time + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + cred := &AuthCredential{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAt, + Provider: provider, + AuthMethod: "oauth", + } + + if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + cred.AccountID = accountID + } + + return cred, nil +} + +func extractAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) < 2 { + return "" + } + + payload := parts[1] + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64URLDecode(payload) + if err != nil { + return "" + } + + var claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { + return accountID + } + } + + return "" +} + +func base64URLDecode(s string) ([]byte, error) { + s = strings.NewReplacer("-", "+", "_", "/").Replace(s) + return base64.StdEncoding.DecodeString(s) +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 0000000..00b4c60 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,199 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBuildAuthorizeURL(t *testing.T) { + cfg := OAuthProviderConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Port: 1455, + } + pkce := PKCECodes{ + CodeVerifier: "test-verifier", + CodeChallenge: "test-challenge", + } + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + + if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + t.Errorf("URL does not start with expected prefix: %s", u) + } + if !strings.Contains(u, "client_id=test-client-id") { + t.Error("URL missing client_id") + } + if !strings.Contains(u, "code_challenge=test-challenge") { + t.Error("URL missing code_challenge") + } + if !strings.Contains(u, "code_challenge_method=S256") { + t.Error("URL missing code_challenge_method") + } + if !strings.Contains(u, "state=test-state") { + t.Error("URL missing state") + } + if !strings.Contains(u, "response_type=code") { + t.Error("URL missing response_type") + } +} + +func TestParseTokenResponse(t *testing.T) { + resp := map[string]interface{}{ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": "test-id-token", + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") + } + if cred.RefreshToken != "test-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") + } + if cred.Provider != "openai" { + t.Errorf("Provider = %q, want %q", cred.Provider, "openai") + } + if cred.AuthMethod != "oauth" { + t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + if cred.ExpiresAt.IsZero() { + t.Error("ExpiresAt should not be zero") + } +} + +func TestParseTokenResponseNoAccessToken(t *testing.T) { + body := []byte(`{"refresh_token": "test"}`) + _, err := parseTokenResponse(body, "openai") + if err == nil { + t.Error("expected error for missing access_token") + } +} + +func TestExchangeCodeForTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: 1455, + } + + cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") + if err != nil { + t.Fatalf("exchangeCodeForTokens() error: %v", err) + } + + if cred.AccessToken != "mock-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestRefreshAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "refresh_token" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "refreshed-access-token", + "refresh_token": "refreshed-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + } + + cred := &AuthCredential{ + AccessToken: "old-token", + RefreshToken: "old-refresh-token", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + + if refreshed.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") + } + if refreshed.RefreshToken != "refreshed-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") + } +} + +func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { + cfg := OpenAIOAuthConfig() + cred := &AuthCredential{ + AccessToken: "old-token", + Provider: "openai", + AuthMethod: "oauth", + } + + _, err := RefreshAccessToken(cred, cfg) + if err == nil { + t.Error("expected error for missing refresh token") + } +} + +func TestOpenAIOAuthConfig(t *testing.T) { + cfg := OpenAIOAuthConfig() + if cfg.Issuer != "https://auth.openai.com" { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") + } + if cfg.ClientID == "" { + t.Error("ClientID is empty") + } + if cfg.Port != 1455 { + t.Errorf("Port = %d, want 1455", cfg.Port) + } +} diff --git a/pkg/auth/pkce.go b/pkg/auth/pkce.go new file mode 100644 index 0000000..499daf8 --- /dev/null +++ b/pkg/auth/pkce.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +func GeneratePKCE() (PKCECodes, error) { + buf := make([]byte, 64) + if _, err := rand.Read(buf); err != nil { + return PKCECodes{}, err + } + + verifier := base64.RawURLEncoding.EncodeToString(buf) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} diff --git a/pkg/auth/pkce_test.go b/pkg/auth/pkce_test.go new file mode 100644 index 0000000..74ed573 --- /dev/null +++ b/pkg/auth/pkce_test.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + codes, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes.CodeVerifier == "" { + t.Fatal("CodeVerifier is empty") + } + if codes.CodeChallenge == "" { + t.Fatal("CodeChallenge is empty") + } + + verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) + if err != nil { + t.Fatalf("CodeVerifier is not valid base64url: %v", err) + } + if len(verifierBytes) != 64 { + t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) + } + + hash := sha256.Sum256([]byte(codes.CodeVerifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if codes.CodeChallenge != expectedChallenge { + t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) + } +} + +func TestGeneratePKCEUniqueness(t *testing.T) { + codes1, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + codes2, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes1.CodeVerifier == codes2.CodeVerifier { + t.Error("two GeneratePKCE() calls produced identical verifiers") + } +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go new file mode 100644 index 0000000..2072492 --- /dev/null +++ b/pkg/auth/store.go @@ -0,0 +1,112 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +type AuthCredential struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + AccountID string `json:"account_id,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Provider string `json:"provider"` + AuthMethod string `json:"auth_method"` +} + +type AuthStore struct { + Credentials map[string]*AuthCredential `json:"credentials"` +} + +func (c *AuthCredential) IsExpired() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().After(c.ExpiresAt) +} + +func (c *AuthCredential) NeedsRefresh() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +func authFilePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "auth.json") +} + +func LoadStore() (*AuthStore, error) { + path := authFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil + } + return nil, err + } + + var store AuthStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + } + return &store, nil +} + +func SaveStore(store *AuthStore) error { + path := authFilePath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func GetCredential(provider string) (*AuthCredential, error) { + store, err := LoadStore() + if err != nil { + return nil, err + } + cred, ok := store.Credentials[provider] + if !ok { + return nil, nil + } + return cred, nil +} + +func SetCredential(provider string, cred *AuthCredential) error { + store, err := LoadStore() + if err != nil { + return err + } + store.Credentials[provider] = cred + return SaveStore(store) +} + +func DeleteCredential(provider string) error { + store, err := LoadStore() + if err != nil { + return err + } + delete(store.Credentials, provider) + return SaveStore(store) +} + +func DeleteAllCredentials() error { + path := authFilePath() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go new file mode 100644 index 0000000..d96b460 --- /dev/null +++ b/pkg/auth/store_test.go @@ -0,0 +1,189 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestAuthCredentialIsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"future", time.Now().Add(time.Hour), false}, + {"past", time.Now().Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthCredentialNeedsRefresh(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"far future", time.Now().Add(time.Hour), false}, + {"within 5 min", time.Now().Add(3 * time.Minute), true}, + {"already expired", time.Now().Add(-time.Minute), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.NeedsRefresh(); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStoreRoundtrip(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccountID: "acct-123", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + Provider: "openai", + AuthMethod: "oauth", + } + + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded == nil { + t.Fatal("GetCredential() returned nil") + } + if loaded.AccessToken != cred.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken) + } + if loaded.RefreshToken != cred.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken) + } + if loaded.Provider != cred.Provider { + t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider) + } +} + +func TestStoreFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "secret-token", + Provider: "openai", + AuthMethod: "oauth", + } + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } +} + +func TestStoreMultiProvider(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} + anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} + + if err := SetCredential("openai", openaiCred); err != nil { + t.Fatalf("SetCredential(openai) error: %v", err) + } + if err := SetCredential("anthropic", anthropicCred); err != nil { + t.Fatalf("SetCredential(anthropic) error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential(openai) error: %v", err) + } + if loaded.AccessToken != "openai-token" { + t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token") + } + + loaded, err = GetCredential("anthropic") + if err != nil { + t.Fatalf("GetCredential(anthropic) error: %v", err) + } + if loaded.AccessToken != "anthropic-token" { + t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token") + } +} + +func TestDeleteCredential(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + if err := DeleteCredential("openai"); err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded != nil { + t.Error("expected nil after delete") + } +} + +func TestLoadStoreEmpty(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + store, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if store == nil { + t.Fatal("LoadStore() returned nil") + } + if len(store.Credentials) != 0 { + t.Errorf("expected empty credentials, got %d", len(store.Credentials)) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..a5a13ff --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,43 @@ +package auth + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { + fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider)) + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input received") + } + + token := strings.TrimSpace(scanner.Text()) + if token == "" { + return nil, fmt.Errorf("token cannot be empty") + } + + return &AuthCredential{ + AccessToken: token, + Provider: provider, + AuthMethod: "token", + }, nil +} + +func providerDisplayName(provider string) string { + switch provider { + case "anthropic": + return "console.anthropic.com" + case "openai": + return "platform.openai.com" + default: + return provider + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 5b9c2b5..7fc6253 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -99,8 +99,9 @@ type ProvidersConfig struct { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } type GatewayConfig struct { diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 12909df..dab6132 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,13 +15,16 @@ import ( "net/http" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + apiKey string + apiBase string + httpClient *http.Client + tokenSource func() (string, error) + accountID string } func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider { @@ -73,9 +76,17 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) + if p.tokenSource != nil { + token, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("failed to get auth token: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + if p.accountID != "" { + req.Header.Set("Chatgpt-Account-Id", p.accountID) + } + } else if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -170,6 +181,47 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } +func createOAuthTokenSource(provider string) func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential(provider) + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider) + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential(provider, refreshed); err != nil { + return "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, nil + } + + return cred.AccessToken, nil + } +} + +func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) { + cred, err := auth.GetCredential(providerName) + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName) + } + + p := NewHTTPProvider(cred.AccessToken, apiBase) + p.tokenSource = createOAuthTokenSource(providerName) + p.accountID = cred.AccountID + return p, nil +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model @@ -186,14 +238,28 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://openrouter.ai/api/v1" } - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "": + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + ab := cfg.Providers.Anthropic.APIBase + if ab == "" { + ab = "https://api.anthropic.com/v1" + } + return createAuthProvider("anthropic", ab) + } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "": + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + ab := cfg.Providers.OpenAI.APIBase + if ab == "" { + ab = "https://api.openai.com/v1" + } + return createAuthProvider("openai", ab) + } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase if apiBase == "" { From fbad753b2ac4df2b39fd77ec20ba044167349350 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Wed, 11 Feb 2026 13:27:59 -0600 Subject: [PATCH 2/3] feat(providers): add SDK-based providers for subscription OAuth login Add ClaudeProvider (anthropic-sdk-go) and CodexProvider (openai-go) that use the correct subscription endpoints and API formats: - CodexProvider: chatgpt.com/backend-api/codex/responses (Responses API) with OAuth Bearer auth and Chatgpt-Account-Id header - ClaudeProvider: api.anthropic.com/v1/messages (Messages API) with Authorization: Bearer token auth Update CreateProvider() routing to use new SDK-based providers when auth_method is "oauth" or "token", removing the stopgap that sent subscription tokens to pay-per-token endpoints. Closes #18 Co-Authored-By: Claude Opus 4.6 --- go.mod | 3 + go.sum | 8 + pkg/providers/claude_provider.go | 207 ++++++++++++++++++++ pkg/providers/claude_provider_test.go | 210 ++++++++++++++++++++ pkg/providers/codex_provider.go | 248 ++++++++++++++++++++++++ pkg/providers/codex_provider_test.go | 264 ++++++++++++++++++++++++++ pkg/providers/http_provider.go | 78 ++------ 7 files changed, 960 insertions(+), 58 deletions(-) create mode 100644 pkg/providers/claude_provider.go create mode 100644 pkg/providers/claude_provider_test.go create mode 100644 pkg/providers/codex_provider.go create mode 100644 pkg/providers/codex_provider_test.go diff --git a/go.mod b/go.mod index 832f1e8..54c73fc 100644 --- a/go.mod +++ b/go.mod @@ -16,12 +16,15 @@ require ( ) require ( + github.com/anthropics/anthropic-sdk-go v1.22.1 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/openai/openai-go v1.12.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index f1ce926..96ff62e 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= +github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= +github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= @@ -72,6 +74,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -86,9 +90,11 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= @@ -97,6 +103,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file mode 100644 index 0000000..ae6aca9 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type ClaudeProvider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewClaudeProvider(token string) *ClaudeProvider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &ClaudeProvider{client: &client} +} + +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + p := NewClaudeProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildClaudeParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseClaudeResponse(resp), nil +} + +func (p *ClaudeProvider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForClaude(tools) + } + + return params, nil +} + +func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func createClaudeTokenSource() func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return cred.AccessToken, nil + } +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..bbad2d2 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,210 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildClaudeParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildClaudeParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseClaudeResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseClaudeResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseClaudeResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseClaudeResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestClaudeProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewClaudeProvider("test-token") + provider.client = createAnthropicTestClient(server.URL, "test-token") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestClaudeProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..a17ae22 --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/responses" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type CodexProvider struct { + client *openai.Client + accountID string + tokenSource func() (string, string, error) +} + +func NewCodexProvider(token, accountID string) *CodexProvider { + opts := []option.RequestOption{ + option.WithBaseURL("https://chatgpt.com/backend-api/codex"), + option.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } + client := openai.NewClient(opts...) + return &CodexProvider{ + client: &client, + accountID: accountID, + } +} + +func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { + p := NewCodexProvider(token, accountID) + p.tokenSource = tokenSource + return p +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, accID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAPIKey(tok)) + if accID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + } + } + + params := buildCodexParams(messages, tools, model, options) + + resp, err := p.client.Responses.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("codex API call: %w", err) + } + + return parseCodexResponse(resp), nil +} + +func (p *CodexProvider) GetDefaultModel() string { + return "gpt-4o" +} + +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { + var inputItems responses.ResponseInputParam + var instructions string + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: msg.Content, + }, + }) + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: tc.Name, + Arguments: string(argsJSON), + }, + }) + } + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: msg.Content, + }, + }) + } + } + + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + params.Instructions = openai.Opt(instructions) + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = openai.Opt(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForCodex(tools) + } + + return params +} + +func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { + result := make([]responses.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + return result +} + +func parseCodexResponse(resp *responses.Response) *LLMResponse { + var content strings.Builder + var toolCalls []ToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content.WriteString(c.Text) + } + } + case "function_call": + var args map[string]interface{} + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]interface{}{"raw": item.Arguments} + } + toolCalls = append(toolCalls, ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if resp.Status == "incomplete" { + finishReason = "length" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func createCodexTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential("openai", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, refreshed.AccountID, nil + } + + return cred.AccessToken, cred.AccountID, nil + } +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..e68a70b --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openai/openai-go" + openaiopt "github.com/openai/openai-go/option" + "github.com/openai/openai-go/responses" +) + +func TestBuildCodexParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ + "max_tokens": 2048, + }) + if params.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") + } +} + +func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != "You are helpful" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful") + } +} + +func TestBuildCodexParams_ToolCallConversation(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } +} + +func TestBuildCodexParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil { + t.Fatal("Tool should be a function tool") + } + if params.Tools[0].OfFunction.Name != "get_weather" { + t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather") + } +} + +func TestBuildCodexParams_StoreIsFalse(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + if !params.Store.Valid() || params.Store.Or(true) != false { + t.Error("Store should be explicitly set to false") + } +} + +func TestParseCodexResponse_TextOutput(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "Hello there!"} + ] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if result.Content != "Hello there!" { + t.Errorf("Content = %q, want %q", result.Content, "Hello there!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseCodexResponse_FunctionCall(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "fc_1", + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + "status": "completed" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + tc := result.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather") + } + if tc.ID != "call_abc" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc") + } + if tc.Arguments["city"] != "SF" { + t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"]) + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestCodexProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens) + } +} + +func TestCodexProvider_GetDefaultModel(t *testing.T) { + p := NewCodexProvider("test-token", "") + if got := p.GetDefaultModel(); got != "gpt-4o" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + } +} + +func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { + opts := []openaiopt.RequestOption{ + openaiopt.WithBaseURL(baseURL), + openaiopt.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID)) + } + c := openai.NewClient(opts...) + return &c +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index dab6132..f63c68c 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -20,11 +20,9 @@ import ( ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client - tokenSource func() (string, error) - accountID string + apiKey string + apiBase string + httpClient *http.Client } func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider { @@ -76,16 +74,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } req.Header.Set("Content-Type", "application/json") - if p.tokenSource != nil { - token, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("failed to get auth token: %w", err) - } - req.Header.Set("Authorization", "Bearer "+token) - if p.accountID != "" { - req.Header.Set("Chatgpt-Account-Id", p.accountID) - } - } else if p.apiKey != "" { + if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } @@ -181,45 +170,26 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } -func createOAuthTokenSource(provider string) func() (string, error) { - return func() (string, error) { - cred, err := auth.GetCredential(provider) - if err != nil { - return "", fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider) - } - - if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { - oauthCfg := auth.OpenAIOAuthConfig() - refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) - if err != nil { - return "", fmt.Errorf("refreshing token: %w", err) - } - if err := auth.SetCredential(provider, refreshed); err != nil { - return "", fmt.Errorf("saving refreshed token: %w", err) - } - return refreshed.AccessToken, nil - } - - return cred.AccessToken, nil - } -} - -func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) { - cred, err := auth.GetCredential(providerName) +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } if cred == nil { - return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName) + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} - p := NewHTTPProvider(cred.AccessToken, apiBase) - p.tokenSource = createOAuthTokenSource(providerName) - p.accountID = cred.AccountID - return p, nil +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil } func CreateProvider(cfg *config.Config) (LLMProvider, error) { @@ -240,11 +210,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - ab := cfg.Providers.Anthropic.APIBase - if ab == "" { - ab = "https://api.anthropic.com/v1" - } - return createAuthProvider("anthropic", ab) + return createClaudeAuthProvider() } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase @@ -254,11 +220,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - ab := cfg.Providers.OpenAI.APIBase - if ab == "" { - ab = "https://api.openai.com/v1" - } - return createAuthProvider("openai", ab) + return createCodexAuthProvider() } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase From 83f6e44b024f8b4c77168ecdd2566d04ea215be4 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Wed, 11 Feb 2026 13:39:19 -0600 Subject: [PATCH 3/3] chore(deps): upgrade openai-go from v1.12.0 to v3.21.0 Update to latest major version of the official OpenAI Go SDK. Fix breaking change: FunctionCallOutput.Output is now a union type (ResponseInputItemFunctionCallOutputOutputUnionParam) instead of string. Co-Authored-By: Claude Opus 4.6 --- go.mod | 4 ++-- go.sum | 7 +++---- pkg/providers/codex_provider.go | 10 +++++----- pkg/providers/codex_provider_test.go | 6 +++--- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 54c73fc..1765efd 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( github.com/adhocore/gronx v1.19.6 + github.com/anthropics/anthropic-sdk-go v1.22.1 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 @@ -11,16 +12,15 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 + github.com/openai/openai-go/v3 v3.21.0 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) require ( - github.com/anthropics/anthropic-sdk-go v1.22.1 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/openai/openai-go v1.12.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 96ff62e..459e661 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= -github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= -github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc= +github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -88,9 +88,8 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index a17ae22..3463389 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -6,9 +6,9 @@ import ( "fmt" "strings" - "github.com/openai/openai-go" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/responses" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" "github.com/sipeed/picoclaw/pkg/auth" ) @@ -79,7 +79,7 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ CallID: msg.ToolCallID, - Output: msg.Content, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, }, }) } else { @@ -122,7 +122,7 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ CallID: msg.ToolCallID, - Output: msg.Content, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, }, }) } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index e68a70b..605183d 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -6,9 +6,9 @@ import ( "net/http/httptest" "testing" - "github.com/openai/openai-go" - openaiopt "github.com/openai/openai-go/option" - "github.com/openai/openai-go/responses" + "github.com/openai/openai-go/v3" + openaiopt "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" ) func TestBuildCodexParams_BasicMessage(t *testing.T) {