420 lines
11 KiB
Go
420 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os/exec"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type OAuthProviderConfig struct {
|
|
Issuer string
|
|
ClientID string
|
|
Scopes string
|
|
Originator string
|
|
Port int
|
|
}
|
|
|
|
func OpenAIOAuthConfig() OAuthProviderConfig {
|
|
return OAuthProviderConfig{
|
|
Issuer: "https://auth.openai.com",
|
|
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
|
Scopes: "openid profile email offline_access",
|
|
Originator: "codex_cli_rs",
|
|
Port: 1455,
|
|
}
|
|
}
|
|
|
|
func generateState() (string, error) {
|
|
buf := make([]byte, 32)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return "", err
|
|
}
|
|
return hex.EncodeToString(buf), nil
|
|
}
|
|
|
|
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
pkce, err := GeneratePKCE()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating PKCE: %w", err)
|
|
}
|
|
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating state: %w", err)
|
|
}
|
|
|
|
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
|
|
|
|
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
|
|
|
resultCh := make(chan callbackResult, 1)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("state") != state {
|
|
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
|
|
http.Error(w, "State mismatch", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
errMsg := r.URL.Query().Get("error")
|
|
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
|
|
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/html")
|
|
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
|
|
resultCh <- callbackResult{code: code}
|
|
})
|
|
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
|
|
}
|
|
|
|
server := &http.Server{Handler: mux}
|
|
go server.Serve(listener)
|
|
defer func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
server.Shutdown(ctx)
|
|
}()
|
|
|
|
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
|
|
|
|
if err := openBrowser(authURL); err != nil {
|
|
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
|
}
|
|
|
|
fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
|
|
fmt.Println("Waiting for authentication in browser...")
|
|
|
|
select {
|
|
case result := <-resultCh:
|
|
if result.err != nil {
|
|
return nil, result.err
|
|
}
|
|
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
|
case <-time.After(5 * time.Minute):
|
|
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
|
}
|
|
}
|
|
|
|
type callbackResult struct {
|
|
code string
|
|
err error
|
|
}
|
|
|
|
type deviceCodeResponse struct {
|
|
DeviceAuthID string
|
|
UserCode string
|
|
Interval int
|
|
}
|
|
|
|
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
|
|
var raw struct {
|
|
DeviceAuthID string `json:"device_auth_id"`
|
|
UserCode string `json:"user_code"`
|
|
Interval json.RawMessage `json:"interval"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &raw); err != nil {
|
|
return deviceCodeResponse{}, err
|
|
}
|
|
|
|
interval, err := parseFlexibleInt(raw.Interval)
|
|
if err != nil {
|
|
return deviceCodeResponse{}, err
|
|
}
|
|
|
|
return deviceCodeResponse{
|
|
DeviceAuthID: raw.DeviceAuthID,
|
|
UserCode: raw.UserCode,
|
|
Interval: interval,
|
|
}, nil
|
|
}
|
|
|
|
func parseFlexibleInt(raw json.RawMessage) (int, error) {
|
|
if len(raw) == 0 || string(raw) == "null" {
|
|
return 0, nil
|
|
}
|
|
|
|
var interval int
|
|
if err := json.Unmarshal(raw, &interval); err == nil {
|
|
return interval, nil
|
|
}
|
|
|
|
var intervalStr string
|
|
if err := json.Unmarshal(raw, &intervalStr); err == nil {
|
|
intervalStr = strings.TrimSpace(intervalStr)
|
|
if intervalStr == "" {
|
|
return 0, nil
|
|
}
|
|
return strconv.Atoi(intervalStr)
|
|
}
|
|
|
|
return 0, fmt.Errorf("invalid integer value: %s", string(raw))
|
|
}
|
|
|
|
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
reqBody, _ := json.Marshal(map[string]string{
|
|
"client_id": cfg.ClientID,
|
|
})
|
|
|
|
resp, err := http.Post(
|
|
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
|
"application/json",
|
|
strings.NewReader(string(reqBody)),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("requesting device code: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
|
}
|
|
|
|
deviceResp, err := parseDeviceCodeResponse(body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing device code response: %w", err)
|
|
}
|
|
|
|
if deviceResp.Interval < 1 {
|
|
deviceResp.Interval = 5
|
|
}
|
|
|
|
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
|
|
cfg.Issuer, deviceResp.UserCode)
|
|
|
|
deadline := time.After(15 * time.Minute)
|
|
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-deadline:
|
|
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
|
|
case <-ticker.C:
|
|
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if cred != nil {
|
|
return cred, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
|
reqBody, _ := json.Marshal(map[string]string{
|
|
"device_auth_id": deviceAuthID,
|
|
"user_code": userCode,
|
|
})
|
|
|
|
resp, err := http.Post(
|
|
cfg.Issuer+"/api/accounts/deviceauth/token",
|
|
"application/json",
|
|
strings.NewReader(string(reqBody)),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("pending")
|
|
}
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
var tokenResp struct {
|
|
AuthorizationCode string `json:"authorization_code"`
|
|
CodeChallenge string `json:"code_challenge"`
|
|
CodeVerifier string `json:"code_verifier"`
|
|
}
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
|
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
|
}
|
|
|
|
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
if cred.RefreshToken == "" {
|
|
return nil, fmt.Errorf("no refresh token available")
|
|
}
|
|
|
|
data := url.Values{
|
|
"client_id": {cfg.ClientID},
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {cred.RefreshToken},
|
|
"scope": {"openid profile email"},
|
|
}
|
|
|
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
|
}
|
|
|
|
return parseTokenResponse(body, cred.Provider)
|
|
}
|
|
|
|
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
|
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
|
}
|
|
|
|
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
|
params := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cfg.ClientID},
|
|
"redirect_uri": {redirectURI},
|
|
"scope": {cfg.Scopes},
|
|
"code_challenge": {pkce.CodeChallenge},
|
|
"code_challenge_method": {"S256"},
|
|
"id_token_add_organizations": {"true"},
|
|
"codex_cli_simplified_flow": {"true"},
|
|
"state": {state},
|
|
}
|
|
if cfg.Originator != "" {
|
|
params.Set("originator", cfg.Originator)
|
|
}
|
|
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
|
}
|
|
|
|
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
|
data := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {code},
|
|
"redirect_uri": {redirectURI},
|
|
"client_id": {cfg.ClientID},
|
|
"code_verifier": {codeVerifier},
|
|
}
|
|
|
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
|
}
|
|
|
|
return parseTokenResponse(body, "openai")
|
|
}
|
|
|
|
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("parsing token response: %w", err)
|
|
}
|
|
|
|
if tokenResp.AccessToken == "" {
|
|
return nil, fmt.Errorf("no access token in response")
|
|
}
|
|
|
|
var expiresAt time.Time
|
|
if tokenResp.ExpiresIn > 0 {
|
|
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
|
}
|
|
|
|
cred := &AuthCredential{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
ExpiresAt: expiresAt,
|
|
Provider: provider,
|
|
AuthMethod: "oauth",
|
|
}
|
|
|
|
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
|
cred.AccountID = accountID
|
|
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
|
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
|
cred.AccountID = accountID
|
|
}
|
|
|
|
return cred, nil
|
|
}
|
|
|
|
func extractAccountID(accessToken string) string {
|
|
parts := strings.Split(accessToken, ".")
|
|
if len(parts) < 2 {
|
|
return ""
|
|
}
|
|
|
|
payload := parts[1]
|
|
switch len(payload) % 4 {
|
|
case 2:
|
|
payload += "=="
|
|
case 3:
|
|
payload += "="
|
|
}
|
|
|
|
decoded, err := base64URLDecode(payload)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
|
return ""
|
|
}
|
|
|
|
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
|
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
|
return accountID
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func base64URLDecode(s string) ([]byte, error) {
|
|
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
|
|
return base64.StdEncoding.DecodeString(s)
|
|
}
|
|
|
|
func openBrowser(url string) error {
|
|
switch runtime.GOOS {
|
|
case "darwin":
|
|
return exec.Command("open", url).Start()
|
|
case "linux":
|
|
return exec.Command("xdg-open", url).Start()
|
|
case "windows":
|
|
return exec.Command("cmd", "/c", "start", url).Start()
|
|
default:
|
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
|
}
|
|
}
|