From 5efe8a202010fc9b6660be14ee6f2ce6e8d3df11 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Wed, 11 Feb 2026 11:41:13 -0600 Subject: [PATCH] feat(auth): add OAuth and token-based login for OpenAI and Anthropic Add `picoclaw auth` CLI command supporting: - OpenAI OAuth2 (PKCE + browser callback or device code flow) - Anthropic paste-token flow - Token storage at ~/.picoclaw/auth.json with 0600 permissions - Auto-refresh for expired OAuth tokens in provider Closes #18 Co-Authored-By: Claude Opus 4.6 --- cmd/picoclaw/main.go | 237 ++++++++++++++++++++++ pkg/auth/oauth.go | 358 +++++++++++++++++++++++++++++++++ pkg/auth/oauth_test.go | 199 ++++++++++++++++++ pkg/auth/pkce.go | 29 +++ pkg/auth/pkce_test.go | 51 +++++ pkg/auth/store.go | 112 +++++++++++ pkg/auth/store_test.go | 189 +++++++++++++++++ pkg/auth/token.go | 43 ++++ pkg/config/config.go | 5 +- pkg/providers/http_provider.go | 82 +++++++- 10 files changed, 1295 insertions(+), 10 deletions(-) create mode 100644 pkg/auth/oauth.go create mode 100644 pkg/auth/oauth_test.go create mode 100644 pkg/auth/pkce.go create mode 100644 pkg/auth/pkce_test.go create mode 100644 pkg/auth/store.go create mode 100644 pkg/auth/store_test.go create mode 100644 pkg/auth/token.go diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index c14ec58..e1128fe 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -19,6 +19,7 @@ import ( "github.com/chzyer/readline" "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -85,6 +86,8 @@ func main() { gatewayCmd() case "status": statusCmd() + case "auth": + authCmd() case "cron": cronCmd() case "skills": @@ -152,6 +155,7 @@ func printHelp() { fmt.Println("Commands:") fmt.Println(" onboard Initialize picoclaw configuration and workspace") fmt.Println(" agent Interact with the agent directly") + fmt.Println(" auth Manage authentication (login, logout, status)") fmt.Println(" gateway Start picoclaw gateway") fmt.Println(" status Show picoclaw status") fmt.Println(" cron Manage scheduled tasks") @@ -682,6 +686,239 @@ func statusCmd() { } else { fmt.Println("vLLM/Local: not set") } + + store, _ := auth.LoadStore() + if store != nil && len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } +} + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println("Supported providers: openai, anthropic") + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + default: + fmt.Printf("Unsupported provider: %s\n", provider) + fmt.Println("Supported providers: openai, anthropic") + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "oauth" + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } } } diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..94a79a6 --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,358 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os/exec" + "runtime" + "strings" + "time" +) + +type OAuthProviderConfig struct { + Issuer string + ClientID string + Scopes string + Port int +} + +func OpenAIOAuthConfig() OAuthProviderConfig { + return OAuthProviderConfig{ + Issuer: "https://auth.openai.com", + ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + Port: 1455, + } +} + +func generateState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { + pkce, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("generating PKCE: %w", err) + } + + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port) + + authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI) + + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + resultCh <- callbackResult{err: fmt.Errorf("state mismatch")} + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + errMsg := r.URL.Query().Get("error") + resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)} + http.Error(w, "No authorization code received", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, "

Authentication successful!

You can close this window.

") + resultCh <- callbackResult{code: code} + }) + + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) + if err != nil { + return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) + } + + server := &http.Server{Handler: mux} + go server.Serve(listener) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + + if err := openBrowser(authURL); err != nil { + fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) + } + + fmt.Println("Waiting for authentication in browser...") + + select { + case result := <-resultCh: + if result.err != nil { + return nil, result.err + } + return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("authentication timed out after 5 minutes") + } +} + +type callbackResult struct { + code string + err error +} + +func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "client_id": cfg.ClientID, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/usercode", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device code request failed: %s", string(body)) + } + + var deviceResp struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval int `json:"interval"` + } + if err := json.Unmarshal(body, &deviceResp); err != nil { + return nil, fmt.Errorf("parsing device code response: %w", err) + } + + if deviceResp.Interval < 1 { + deviceResp.Interval = 5 + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, deviceResp.UserCode) + + deadline := time.After(15 * time.Minute) + ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + return nil, fmt.Errorf("device code authentication timed out after 15 minutes") + case <-ticker.C: + cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) + if err != nil { + continue + } + if cred != nil { + return cred, nil + } + } + } +} + +func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "device_auth_id": deviceAuthID, + "user_code": userCode, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/token", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pending") + } + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AuthorizationCode string `json:"authorization_code"` + CodeChallenge string `json:"code_challenge"` + CodeVerifier string `json:"code_verifier"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + redirectURI := cfg.Issuer + "/deviceauth/callback" + return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) +} + +func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { + if cred.RefreshToken == "" { + return nil, fmt.Errorf("no refresh token available") + } + + data := url.Values{ + "client_id": {cfg.ClientID}, + "grant_type": {"refresh_token"}, + "refresh_token": {cred.RefreshToken}, + "scope": {"openid profile email"}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + return parseTokenResponse(body, cred.Provider) +} + +func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + return buildAuthorizeURL(cfg, pkce, state, redirectURI) +} + +func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {redirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + return cfg.Issuer + "/authorize?" + params.Encode() +} + +func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {cfg.ClientID}, + "code_verifier": {codeVerifier}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("exchanging code for tokens: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: %s", string(body)) + } + + return parseTokenResponse(body, "openai") +} + +func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in response") + } + + var expiresAt time.Time + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + cred := &AuthCredential{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAt, + Provider: provider, + AuthMethod: "oauth", + } + + if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + cred.AccountID = accountID + } + + return cred, nil +} + +func extractAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) < 2 { + return "" + } + + payload := parts[1] + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64URLDecode(payload) + if err != nil { + return "" + } + + var claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { + return accountID + } + } + + return "" +} + +func base64URLDecode(s string) ([]byte, error) { + s = strings.NewReplacer("-", "+", "_", "/").Replace(s) + return base64.StdEncoding.DecodeString(s) +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 0000000..00b4c60 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,199 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBuildAuthorizeURL(t *testing.T) { + cfg := OAuthProviderConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Port: 1455, + } + pkce := PKCECodes{ + CodeVerifier: "test-verifier", + CodeChallenge: "test-challenge", + } + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + + if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + t.Errorf("URL does not start with expected prefix: %s", u) + } + if !strings.Contains(u, "client_id=test-client-id") { + t.Error("URL missing client_id") + } + if !strings.Contains(u, "code_challenge=test-challenge") { + t.Error("URL missing code_challenge") + } + if !strings.Contains(u, "code_challenge_method=S256") { + t.Error("URL missing code_challenge_method") + } + if !strings.Contains(u, "state=test-state") { + t.Error("URL missing state") + } + if !strings.Contains(u, "response_type=code") { + t.Error("URL missing response_type") + } +} + +func TestParseTokenResponse(t *testing.T) { + resp := map[string]interface{}{ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": "test-id-token", + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") + } + if cred.RefreshToken != "test-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") + } + if cred.Provider != "openai" { + t.Errorf("Provider = %q, want %q", cred.Provider, "openai") + } + if cred.AuthMethod != "oauth" { + t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + if cred.ExpiresAt.IsZero() { + t.Error("ExpiresAt should not be zero") + } +} + +func TestParseTokenResponseNoAccessToken(t *testing.T) { + body := []byte(`{"refresh_token": "test"}`) + _, err := parseTokenResponse(body, "openai") + if err == nil { + t.Error("expected error for missing access_token") + } +} + +func TestExchangeCodeForTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: 1455, + } + + cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") + if err != nil { + t.Fatalf("exchangeCodeForTokens() error: %v", err) + } + + if cred.AccessToken != "mock-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestRefreshAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "refresh_token" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "refreshed-access-token", + "refresh_token": "refreshed-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + } + + cred := &AuthCredential{ + AccessToken: "old-token", + RefreshToken: "old-refresh-token", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + + if refreshed.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") + } + if refreshed.RefreshToken != "refreshed-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") + } +} + +func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { + cfg := OpenAIOAuthConfig() + cred := &AuthCredential{ + AccessToken: "old-token", + Provider: "openai", + AuthMethod: "oauth", + } + + _, err := RefreshAccessToken(cred, cfg) + if err == nil { + t.Error("expected error for missing refresh token") + } +} + +func TestOpenAIOAuthConfig(t *testing.T) { + cfg := OpenAIOAuthConfig() + if cfg.Issuer != "https://auth.openai.com" { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") + } + if cfg.ClientID == "" { + t.Error("ClientID is empty") + } + if cfg.Port != 1455 { + t.Errorf("Port = %d, want 1455", cfg.Port) + } +} diff --git a/pkg/auth/pkce.go b/pkg/auth/pkce.go new file mode 100644 index 0000000..499daf8 --- /dev/null +++ b/pkg/auth/pkce.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +func GeneratePKCE() (PKCECodes, error) { + buf := make([]byte, 64) + if _, err := rand.Read(buf); err != nil { + return PKCECodes{}, err + } + + verifier := base64.RawURLEncoding.EncodeToString(buf) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} diff --git a/pkg/auth/pkce_test.go b/pkg/auth/pkce_test.go new file mode 100644 index 0000000..74ed573 --- /dev/null +++ b/pkg/auth/pkce_test.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + codes, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes.CodeVerifier == "" { + t.Fatal("CodeVerifier is empty") + } + if codes.CodeChallenge == "" { + t.Fatal("CodeChallenge is empty") + } + + verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) + if err != nil { + t.Fatalf("CodeVerifier is not valid base64url: %v", err) + } + if len(verifierBytes) != 64 { + t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) + } + + hash := sha256.Sum256([]byte(codes.CodeVerifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if codes.CodeChallenge != expectedChallenge { + t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) + } +} + +func TestGeneratePKCEUniqueness(t *testing.T) { + codes1, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + codes2, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes1.CodeVerifier == codes2.CodeVerifier { + t.Error("two GeneratePKCE() calls produced identical verifiers") + } +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go new file mode 100644 index 0000000..2072492 --- /dev/null +++ b/pkg/auth/store.go @@ -0,0 +1,112 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +type AuthCredential struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + AccountID string `json:"account_id,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Provider string `json:"provider"` + AuthMethod string `json:"auth_method"` +} + +type AuthStore struct { + Credentials map[string]*AuthCredential `json:"credentials"` +} + +func (c *AuthCredential) IsExpired() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().After(c.ExpiresAt) +} + +func (c *AuthCredential) NeedsRefresh() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +func authFilePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "auth.json") +} + +func LoadStore() (*AuthStore, error) { + path := authFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil + } + return nil, err + } + + var store AuthStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + } + return &store, nil +} + +func SaveStore(store *AuthStore) error { + path := authFilePath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func GetCredential(provider string) (*AuthCredential, error) { + store, err := LoadStore() + if err != nil { + return nil, err + } + cred, ok := store.Credentials[provider] + if !ok { + return nil, nil + } + return cred, nil +} + +func SetCredential(provider string, cred *AuthCredential) error { + store, err := LoadStore() + if err != nil { + return err + } + store.Credentials[provider] = cred + return SaveStore(store) +} + +func DeleteCredential(provider string) error { + store, err := LoadStore() + if err != nil { + return err + } + delete(store.Credentials, provider) + return SaveStore(store) +} + +func DeleteAllCredentials() error { + path := authFilePath() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go new file mode 100644 index 0000000..d96b460 --- /dev/null +++ b/pkg/auth/store_test.go @@ -0,0 +1,189 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestAuthCredentialIsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"future", time.Now().Add(time.Hour), false}, + {"past", time.Now().Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthCredentialNeedsRefresh(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"far future", time.Now().Add(time.Hour), false}, + {"within 5 min", time.Now().Add(3 * time.Minute), true}, + {"already expired", time.Now().Add(-time.Minute), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.NeedsRefresh(); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStoreRoundtrip(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccountID: "acct-123", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + Provider: "openai", + AuthMethod: "oauth", + } + + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded == nil { + t.Fatal("GetCredential() returned nil") + } + if loaded.AccessToken != cred.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken) + } + if loaded.RefreshToken != cred.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken) + } + if loaded.Provider != cred.Provider { + t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider) + } +} + +func TestStoreFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "secret-token", + Provider: "openai", + AuthMethod: "oauth", + } + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } +} + +func TestStoreMultiProvider(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} + anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} + + if err := SetCredential("openai", openaiCred); err != nil { + t.Fatalf("SetCredential(openai) error: %v", err) + } + if err := SetCredential("anthropic", anthropicCred); err != nil { + t.Fatalf("SetCredential(anthropic) error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential(openai) error: %v", err) + } + if loaded.AccessToken != "openai-token" { + t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token") + } + + loaded, err = GetCredential("anthropic") + if err != nil { + t.Fatalf("GetCredential(anthropic) error: %v", err) + } + if loaded.AccessToken != "anthropic-token" { + t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token") + } +} + +func TestDeleteCredential(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + if err := DeleteCredential("openai"); err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded != nil { + t.Error("expected nil after delete") + } +} + +func TestLoadStoreEmpty(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + store, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if store == nil { + t.Fatal("LoadStore() returned nil") + } + if len(store.Credentials) != 0 { + t.Errorf("expected empty credentials, got %d", len(store.Credentials)) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..a5a13ff --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,43 @@ +package auth + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { + fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider)) + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input received") + } + + token := strings.TrimSpace(scanner.Text()) + if token == "" { + return nil, fmt.Errorf("token cannot be empty") + } + + return &AuthCredential{ + AccessToken: token, + Provider: provider, + AuthMethod: "token", + }, nil +} + +func providerDisplayName(provider string) string { + switch provider { + case "anthropic": + return "console.anthropic.com" + case "openai": + return "platform.openai.com" + default: + return provider + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 5b9c2b5..7fc6253 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -99,8 +99,9 @@ type ProvidersConfig struct { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } type GatewayConfig struct { diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 12909df..dab6132 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,13 +15,16 @@ import ( "net/http" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + apiKey string + apiBase string + httpClient *http.Client + tokenSource func() (string, error) + accountID string } func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider { @@ -73,9 +76,17 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) + if p.tokenSource != nil { + token, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("failed to get auth token: %w", err) + } + req.Header.Set("Authorization", "Bearer "+token) + if p.accountID != "" { + req.Header.Set("Chatgpt-Account-Id", p.accountID) + } + } else if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -170,6 +181,47 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } +func createOAuthTokenSource(provider string) func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential(provider) + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider) + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential(provider, refreshed); err != nil { + return "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, nil + } + + return cred.AccessToken, nil + } +} + +func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) { + cred, err := auth.GetCredential(providerName) + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName) + } + + p := NewHTTPProvider(cred.AccessToken, apiBase) + p.tokenSource = createOAuthTokenSource(providerName) + p.accountID = cred.AccountID + return p, nil +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model @@ -186,14 +238,28 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://openrouter.ai/api/v1" } - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "": + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + ab := cfg.Providers.Anthropic.APIBase + if ab == "" { + ab = "https://api.anthropic.com/v1" + } + return createAuthProvider("anthropic", ab) + } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "": + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + ab := cfg.Providers.OpenAI.APIBase + if ab == "" { + ab = "https://api.openai.com/v1" + } + return createAuthProvider("openai", ab) + } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase if apiBase == "" {