planned-date: 2026-02-12
why: Device-code login was failing because the `interval` field can arrive as a quoted number, which breaks strict integer decoding and blocks login in shell-only environments.
what: Added flexible interval parsing for numeric or quoted values, wired LoginDeviceCode to the parser, printed the browser auth URL before waiting, and added parser tests for numeric, quoted, and invalid interval payloads.
verification: c:\projects\toolchains\go\bin\go.exe test ./pkg/auth -run "Test(ParseDeviceCodeResponse|BuildAuthorizeURL)"
410 lines
10 KiB
Go
410 lines
10 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os/exec"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// OAuthProviderConfig holds the settings that the login flows in this
// package need in order to talk to one OAuth provider.
type OAuthProviderConfig struct {
	// Issuer is the provider base URL; endpoint paths such as
	// "/authorize" and "/oauth/token" are appended to it.
	Issuer string
	// ClientID is sent as the "client_id" parameter on every request.
	ClientID string
	// Scopes is the value of the "scope" parameter in the authorize URL.
	Scopes string
	// Port is the local TCP port used for the browser callback listener.
	Port int
}
|
|
|
|
func OpenAIOAuthConfig() OAuthProviderConfig {
|
|
return OAuthProviderConfig{
|
|
Issuer: "https://auth.openai.com",
|
|
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
|
Scopes: "openid profile email offline_access",
|
|
Port: 1455,
|
|
}
|
|
}
|
|
|
|
// generateState returns 32 bytes from crypto/rand encoded as a 64-character
// lowercase hex string. LoginBrowser uses it as the OAuth "state" value.
// An error is returned only if the random source cannot be read.
func generateState() (string, error) {
	var nonce [32]byte
	_, err := rand.Read(nonce[:])
	if err != nil {
		return "", err
	}
	encoded := hex.EncodeToString(nonce[:])
	return encoded, nil
}
|
|
|
|
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
pkce, err := GeneratePKCE()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating PKCE: %w", err)
|
|
}
|
|
|
|
state, err := generateState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating state: %w", err)
|
|
}
|
|
|
|
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
|
|
|
|
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
|
|
|
resultCh := make(chan callbackResult, 1)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("state") != state {
|
|
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
|
|
http.Error(w, "State mismatch", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
if code == "" {
|
|
errMsg := r.URL.Query().Get("error")
|
|
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
|
|
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "text/html")
|
|
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
|
|
resultCh <- callbackResult{code: code}
|
|
})
|
|
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
|
|
}
|
|
|
|
server := &http.Server{Handler: mux}
|
|
go server.Serve(listener)
|
|
defer func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
server.Shutdown(ctx)
|
|
}()
|
|
|
|
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
|
|
|
|
if err := openBrowser(authURL); err != nil {
|
|
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
|
}
|
|
|
|
fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
|
|
fmt.Println("Waiting for authentication in browser...")
|
|
|
|
select {
|
|
case result := <-resultCh:
|
|
if result.err != nil {
|
|
return nil, result.err
|
|
}
|
|
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
|
case <-time.After(5 * time.Minute):
|
|
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
|
}
|
|
}
|
|
|
|
// callbackResult is the outcome the /auth/callback handler hands back to
// LoginBrowser: either the authorization code or the error that ended the
// flow.
type callbackResult struct {
	code string // authorization code taken from the "code" query parameter
	err  error  // set when the callback was rejected
}
|
|
|
|
// deviceCodeResponse is the decoded reply of the device-code request, as
// produced by parseDeviceCodeResponse.
type deviceCodeResponse struct {
	// DeviceAuthID is the "device_auth_id" value; it is sent back on every poll.
	DeviceAuthID string
	// UserCode is the "user_code" value the user is asked to enter.
	UserCode string
	// Interval is the polling interval in seconds; 0 when absent or empty.
	Interval int
}
|
|
|
|
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
|
|
var raw struct {
|
|
DeviceAuthID string `json:"device_auth_id"`
|
|
UserCode string `json:"user_code"`
|
|
Interval json.RawMessage `json:"interval"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &raw); err != nil {
|
|
return deviceCodeResponse{}, err
|
|
}
|
|
|
|
interval, err := parseFlexibleInt(raw.Interval)
|
|
if err != nil {
|
|
return deviceCodeResponse{}, err
|
|
}
|
|
|
|
return deviceCodeResponse{
|
|
DeviceAuthID: raw.DeviceAuthID,
|
|
UserCode: raw.UserCode,
|
|
Interval: interval,
|
|
}, nil
|
|
}
|
|
|
|
// parseFlexibleInt converts a raw JSON value to an int, accepting either a
// JSON number or a JSON string that contains a decimal integer.
//
// An absent value, JSON null, and an empty or whitespace-only string all
// yield 0 with no error. A string that is not an integer yields the
// strconv.Atoi error; any other JSON value yields an "invalid integer
// value" error.
func parseFlexibleInt(raw json.RawMessage) (int, error) {
	text := string(raw)
	if text == "" || text == "null" {
		return 0, nil
	}

	// First choice: a plain JSON integer.
	var number int
	if json.Unmarshal(raw, &number) == nil {
		return number, nil
	}

	// Second choice: a JSON string wrapping the integer.
	var quoted string
	if json.Unmarshal(raw, &quoted) != nil {
		return 0, fmt.Errorf("invalid integer value: %s", text)
	}

	if quoted = strings.TrimSpace(quoted); quoted == "" {
		return 0, nil
	}
	return strconv.Atoi(quoted)
}
|
|
|
|
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
reqBody, _ := json.Marshal(map[string]string{
|
|
"client_id": cfg.ClientID,
|
|
})
|
|
|
|
resp, err := http.Post(
|
|
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
|
"application/json",
|
|
strings.NewReader(string(reqBody)),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("requesting device code: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
|
}
|
|
|
|
deviceResp, err := parseDeviceCodeResponse(body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing device code response: %w", err)
|
|
}
|
|
|
|
if deviceResp.Interval < 1 {
|
|
deviceResp.Interval = 5
|
|
}
|
|
|
|
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
|
|
cfg.Issuer, deviceResp.UserCode)
|
|
|
|
deadline := time.After(15 * time.Minute)
|
|
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-deadline:
|
|
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
|
|
case <-ticker.C:
|
|
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if cred != nil {
|
|
return cred, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
|
reqBody, _ := json.Marshal(map[string]string{
|
|
"device_auth_id": deviceAuthID,
|
|
"user_code": userCode,
|
|
})
|
|
|
|
resp, err := http.Post(
|
|
cfg.Issuer+"/api/accounts/deviceauth/token",
|
|
"application/json",
|
|
strings.NewReader(string(reqBody)),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("pending")
|
|
}
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
var tokenResp struct {
|
|
AuthorizationCode string `json:"authorization_code"`
|
|
CodeChallenge string `json:"code_challenge"`
|
|
CodeVerifier string `json:"code_verifier"`
|
|
}
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
|
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
|
}
|
|
|
|
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
|
if cred.RefreshToken == "" {
|
|
return nil, fmt.Errorf("no refresh token available")
|
|
}
|
|
|
|
data := url.Values{
|
|
"client_id": {cfg.ClientID},
|
|
"grant_type": {"refresh_token"},
|
|
"refresh_token": {cred.RefreshToken},
|
|
"scope": {"openid profile email"},
|
|
}
|
|
|
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
|
}
|
|
|
|
return parseTokenResponse(body, cred.Provider)
|
|
}
|
|
|
|
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
|
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
|
}
|
|
|
|
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
|
params := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {cfg.ClientID},
|
|
"redirect_uri": {redirectURI},
|
|
"scope": {cfg.Scopes},
|
|
"code_challenge": {pkce.CodeChallenge},
|
|
"code_challenge_method": {"S256"},
|
|
"state": {state},
|
|
}
|
|
return cfg.Issuer + "/authorize?" + params.Encode()
|
|
}
|
|
|
|
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
|
data := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {code},
|
|
"redirect_uri": {redirectURI},
|
|
"client_id": {cfg.ClientID},
|
|
"code_verifier": {codeVerifier},
|
|
}
|
|
|
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
|
}
|
|
|
|
return parseTokenResponse(body, "openai")
|
|
}
|
|
|
|
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
|
var tokenResp struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
|
return nil, fmt.Errorf("parsing token response: %w", err)
|
|
}
|
|
|
|
if tokenResp.AccessToken == "" {
|
|
return nil, fmt.Errorf("no access token in response")
|
|
}
|
|
|
|
var expiresAt time.Time
|
|
if tokenResp.ExpiresIn > 0 {
|
|
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
|
}
|
|
|
|
cred := &AuthCredential{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
ExpiresAt: expiresAt,
|
|
Provider: provider,
|
|
AuthMethod: "oauth",
|
|
}
|
|
|
|
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
|
cred.AccountID = accountID
|
|
}
|
|
|
|
return cred, nil
|
|
}
|
|
|
|
func extractAccountID(accessToken string) string {
|
|
parts := strings.Split(accessToken, ".")
|
|
if len(parts) < 2 {
|
|
return ""
|
|
}
|
|
|
|
payload := parts[1]
|
|
switch len(payload) % 4 {
|
|
case 2:
|
|
payload += "=="
|
|
case 3:
|
|
payload += "="
|
|
}
|
|
|
|
decoded, err := base64URLDecode(payload)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
|
return ""
|
|
}
|
|
|
|
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
|
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
|
return accountID
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// base64URLDecode decodes padded base64 text written in the URL-safe
// alphabet. It rewrites '-' to '+' and '_' to '/' and then decodes with the
// standard encoding, so input already in the standard alphabet is accepted
// too. Padding must already be present.
func base64URLDecode(s string) ([]byte, error) {
	standard := strings.ReplaceAll(s, "-", "+")
	standard = strings.ReplaceAll(standard, "_", "/")
	return base64.StdEncoding.DecodeString(standard)
}
|
|
|
|
// openBrowser asks the operating system to open url in the default browser.
// It starts the launcher process and returns without waiting for it.
//
// It returns the error from starting the launcher, or an "unsupported
// platform" error on operating systems other than macOS, Linux and Windows.
func openBrowser(url string) error {
	switch runtime.GOOS {
	case "darwin":
		return exec.Command("open", url).Start()
	case "linux":
		return exec.Command("xdg-open", url).Start()
	case "windows":
		// Do not go through "cmd /c start": cmd.exe treats '&' as a command
		// separator, which cuts the URL at its first query-parameter
		// separator. rundll32 receives the URL as a plain argument.
		return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
	default:
		return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
	}
}
|