feat(auth): add OAuth and token-based login for OpenAI and Anthropic

Add `picoclaw auth` CLI command supporting:
- OpenAI OAuth2 (PKCE + browser callback or device code flow)
- Anthropic paste-token flow
- Token storage at ~/.picoclaw/auth.json with 0600 permissions
- Auto-refresh for expired OAuth tokens in provider

Closes #18

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Cory LaNou
2026-02-11 11:41:13 -06:00
parent 6ccd9d0a99
commit 5efe8a2020
10 changed files with 1295 additions and 10 deletions

View File

@@ -19,6 +19,7 @@ import (
"github.com/chzyer/readline"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
@@ -85,6 +86,8 @@ func main() {
gatewayCmd()
case "status":
statusCmd()
case "auth":
authCmd()
case "cron":
cronCmd()
case "skills":
@@ -152,6 +155,7 @@ func printHelp() {
fmt.Println("Commands:")
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
fmt.Println(" agent Interact with the agent directly")
fmt.Println(" auth Manage authentication (login, logout, status)")
fmt.Println(" gateway Start picoclaw gateway")
fmt.Println(" status Show picoclaw status")
fmt.Println(" cron Manage scheduled tasks")
@@ -682,6 +686,239 @@ func statusCmd() {
} else {
fmt.Println("vLLM/Local: not set")
}
store, _ := auth.LoadStore()
if store != nil && len(store.Credentials) > 0 {
fmt.Println("\nOAuth/Token Auth:")
for provider, cred := range store.Credentials {
status := "authenticated"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
}
}
}
}
func authCmd() {
if len(os.Args) < 3 {
authHelp()
return
}
switch os.Args[2] {
case "login":
authLoginCmd()
case "logout":
authLogoutCmd()
case "status":
authStatusCmd()
default:
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
authHelp()
}
}
func authHelp() {
fmt.Println("\nAuth commands:")
fmt.Println(" login Login via OAuth or paste token")
fmt.Println(" logout Remove stored credentials")
fmt.Println(" status Show current auth status")
fmt.Println()
fmt.Println("Login options:")
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
fmt.Println(" --device-code Use device code flow (for headless environments)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw auth login --provider openai")
fmt.Println(" picoclaw auth login --provider openai --device-code")
fmt.Println(" picoclaw auth login --provider anthropic")
fmt.Println(" picoclaw auth logout --provider openai")
fmt.Println(" picoclaw auth status")
}
func authLoginCmd() {
provider := ""
useDeviceCode := false
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
case "--device-code":
useDeviceCode = true
}
}
if provider == "" {
fmt.Println("Error: --provider is required")
fmt.Println("Supported providers: openai, anthropic")
return
}
switch provider {
case "openai":
authLoginOpenAI(useDeviceCode)
case "anthropic":
authLoginPasteToken(provider)
default:
fmt.Printf("Unsupported provider: %s\n", provider)
fmt.Println("Supported providers: openai, anthropic")
}
}
func authLoginOpenAI(useDeviceCode bool) {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
var err error
if useDeviceCode {
cred, err = auth.LoginDeviceCode(cfg)
} else {
cred, err = auth.LoginBrowser(cfg)
}
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential("openai", cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = "oauth"
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Println("Login successful!")
if cred.AccountID != "" {
fmt.Printf("Account: %s\n", cred.AccountID)
}
}
func authLoginPasteToken(provider string) {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential(provider, cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = "token"
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
}
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
}
func authLogoutCmd() {
provider := ""
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
}
}
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "openai":
appCfg.Providers.OpenAI.AuthMethod = ""
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = ""
}
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Printf("Logged out from %s\n", provider)
} else {
if err := auth.DeleteAllCredentials(); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = ""
appCfg.Providers.Anthropic.AuthMethod = ""
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Println("Logged out from all providers")
}
}
func authStatusCmd() {
store, err := auth.LoadStore()
if err != nil {
fmt.Printf("Error loading auth store: %v\n", err)
return
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider <name>")
return
}
fmt.Println("\nAuthenticated Providers:")
fmt.Println("------------------------")
for provider, cred := range store.Credentials {
status := "active"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s:\n", provider)
fmt.Printf(" Method: %s\n", cred.AuthMethod)
fmt.Printf(" Status: %s\n", status)
if cred.AccountID != "" {
fmt.Printf(" Account: %s\n", cred.AccountID)
}
if !cred.ExpiresAt.IsZero() {
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
}
}
}

358
pkg/auth/oauth.go Normal file
View File

@@ -0,0 +1,358 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strings"
"time"
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
}
}
func generateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
pkce, err := GeneratePKCE()
if err != nil {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("state") != state {
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
http.Error(w, "State mismatch", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
errMsg := r.URL.Query().Get("error")
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
http.Error(w, "No authorization code received", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
resultCh <- callbackResult{code: code}
})
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
if err != nil {
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
fmt.Println("Waiting for authentication in browser...")
select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
}
type callbackResult struct {
code string
err error
}
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
var deviceResp struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval int `json:"interval"`
}
if err := json.Unmarshal(body, &deviceResp); err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
cfg.Issuer, deviceResp.UserCode)
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-deadline:
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
case <-ticker.C:
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
if err != nil {
continue
}
if cred != nil {
return cred, nil
}
}
}
}
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"device_auth_id": deviceAuthID,
"user_code": userCode,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/token",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
CodeChallenge string `json:"code_challenge"`
CodeVerifier string `json:"code_verifier"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, err
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
if cred.RefreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
data := url.Values{
"client_id": {cfg.ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {cred.RefreshToken},
"scope": {"openid profile email"},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
}
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {cfg.ClientID},
"code_verifier": {codeVerifier},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
return parseTokenResponse(body, "openai")
}
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parsing token response: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("no access token in response")
}
var expiresAt time.Time
if tokenResp.ExpiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
}
cred := &AuthCredential{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: expiresAt,
Provider: provider,
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return ""
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
}
func base64URLDecode(s string) ([]byte, error) {
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("cmd", "/c", "start", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

199
pkg/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,199 @@
package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
CodeChallenge: "test-challenge",
}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
t.Error("URL missing client_id")
}
if !strings.Contains(u, "code_challenge=test-challenge") {
t.Error("URL missing code_challenge")
}
if !strings.Contains(u, "code_challenge_method=S256") {
t.Error("URL missing code_challenge_method")
}
if !strings.Contains(u, "state=test-state") {
t.Error("URL missing state")
}
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": "test-id-token",
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccessToken != "test-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
}
if cred.RefreshToken != "test-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
}
if cred.Provider != "openai" {
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
if cred.ExpiresAt.IsZero() {
t.Error("ExpiresAt should not be zero")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
if err == nil {
t.Error("expected error for missing access_token")
}
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
Scopes: "openid",
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
}
}
func TestRefreshAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "refresh_token" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
}
cred := &AuthCredential{
AccessToken: "old-token",
RefreshToken: "old-refresh-token",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
}
if refreshed.RefreshToken != "refreshed-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
}
}
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
cfg := OpenAIOAuthConfig()
cred := &AuthCredential{
AccessToken: "old-token",
Provider: "openai",
AuthMethod: "oauth",
}
_, err := RefreshAccessToken(cred, cfg)
if err == nil {
t.Error("expected error for missing refresh token")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
}
if cfg.ClientID == "" {
t.Error("ClientID is empty")
}
if cfg.Port != 1455 {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}

29
pkg/auth/pkce.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
func GeneratePKCE() (PKCECodes, error) {
buf := make([]byte, 64)
if _, err := rand.Read(buf); err != nil {
return PKCECodes{}, err
}
verifier := base64.RawURLEncoding.EncodeToString(buf)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
return PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}

51
pkg/auth/pkce_test.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"testing"
)
func TestGeneratePKCE(t *testing.T) {
codes, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes.CodeVerifier == "" {
t.Fatal("CodeVerifier is empty")
}
if codes.CodeChallenge == "" {
t.Fatal("CodeChallenge is empty")
}
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
if err != nil {
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
}
if len(verifierBytes) != 64 {
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
}
hash := sha256.Sum256([]byte(codes.CodeVerifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if codes.CodeChallenge != expectedChallenge {
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
}
}
func TestGeneratePKCEUniqueness(t *testing.T) {
codes1, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
codes2, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes1.CodeVerifier == codes2.CodeVerifier {
t.Error("two GeneratePKCE() calls produced identical verifiers")
}
}

112
pkg/auth/store.go Normal file
View File

@@ -0,0 +1,112 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
type AuthCredential struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Provider string `json:"provider"`
AuthMethod string `json:"auth_method"`
}
type AuthStore struct {
Credentials map[string]*AuthCredential `json:"credentials"`
}
func (c *AuthCredential) IsExpired() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().After(c.ExpiresAt)
}
func (c *AuthCredential) NeedsRefresh() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
}
func authFilePath() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
func LoadStore() (*AuthStore, error) {
path := authFilePath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
}
return nil, err
}
var store AuthStore
if err := json.Unmarshal(data, &store); err != nil {
return nil, err
}
if store.Credentials == nil {
store.Credentials = make(map[string]*AuthCredential)
}
return &store, nil
}
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetCredential(provider string) (*AuthCredential, error) {
store, err := LoadStore()
if err != nil {
return nil, err
}
cred, ok := store.Credentials[provider]
if !ok {
return nil, nil
}
return cred, nil
}
func SetCredential(provider string, cred *AuthCredential) error {
store, err := LoadStore()
if err != nil {
return err
}
store.Credentials[provider] = cred
return SaveStore(store)
}
func DeleteCredential(provider string) error {
store, err := LoadStore()
if err != nil {
return err
}
delete(store.Credentials, provider)
return SaveStore(store)
}
func DeleteAllCredentials() error {
path := authFilePath()
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

189
pkg/auth/store_test.go Normal file
View File

@@ -0,0 +1,189 @@
package auth
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestAuthCredentialIsExpired(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"future", time.Now().Add(time.Hour), false},
{"past", time.Now().Add(-time.Hour), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthCredentialNeedsRefresh(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"far future", time.Now().Add(time.Hour), false},
{"within 5 min", time.Now().Add(3 * time.Minute), true},
{"already expired", time.Now().Add(-time.Minute), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.NeedsRefresh(); got != tt.want {
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestStoreRoundtrip(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccountID: "acct-123",
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded == nil {
t.Fatal("GetCredential() returned nil")
}
if loaded.AccessToken != cred.AccessToken {
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
}
if loaded.RefreshToken != cred.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
}
if loaded.Provider != cred.Provider {
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
}
}
func TestStoreFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "secret-token",
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
func TestStoreMultiProvider(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
if err := SetCredential("openai", openaiCred); err != nil {
t.Fatalf("SetCredential(openai) error: %v", err)
}
if err := SetCredential("anthropic", anthropicCred); err != nil {
t.Fatalf("SetCredential(anthropic) error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential(openai) error: %v", err)
}
if loaded.AccessToken != "openai-token" {
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
}
loaded, err = GetCredential("anthropic")
if err != nil {
t.Fatalf("GetCredential(anthropic) error: %v", err)
}
if loaded.AccessToken != "anthropic-token" {
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
}
}
func TestDeleteCredential(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
if err := DeleteCredential("openai"); err != nil {
t.Fatalf("DeleteCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded != nil {
t.Error("expected nil after delete")
}
}
func TestLoadStoreEmpty(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
store, err := LoadStore()
if err != nil {
t.Fatalf("LoadStore() error: %v", err)
}
if store == nil {
t.Fatal("LoadStore() returned nil")
}
if len(store.Credentials) != 0 {
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
}
}

43
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"bufio"
"fmt"
"io"
"strings"
)
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
return &AuthCredential{
AccessToken: token,
Provider: provider,
AuthMethod: "token",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
return "console.anthropic.com"
case "openai":
return "platform.openai.com"
default:
return provider
}
}

View File

@@ -101,6 +101,7 @@ type ProvidersConfig struct {
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
}
type GatewayConfig struct {

View File

@@ -15,6 +15,7 @@ import (
"net/http"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -22,6 +23,8 @@ type HTTPProvider struct {
apiKey string
apiBase string
httpClient *http.Client
tokenSource func() (string, error)
accountID string
}
func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider {
@@ -73,9 +76,17 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
authHeader := "Bearer " + p.apiKey
req.Header.Set("Authorization", authHeader)
if p.tokenSource != nil {
token, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("failed to get auth token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
if p.accountID != "" {
req.Header.Set("Chatgpt-Account-Id", p.accountID)
}
} else if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -170,6 +181,47 @@ func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createOAuthTokenSource(provider string) func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential(provider)
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider)
}
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
oauthCfg := auth.OpenAIOAuthConfig()
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
if err != nil {
return "", fmt.Errorf("refreshing token: %w", err)
}
if err := auth.SetCredential(provider, refreshed); err != nil {
return "", fmt.Errorf("saving refreshed token: %w", err)
}
return refreshed.AccessToken, nil
}
return cred.AccessToken, nil
}
}
func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) {
cred, err := auth.GetCredential(providerName)
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName)
}
p := NewHTTPProvider(cred.AccessToken, apiBase)
p.tokenSource = createOAuthTokenSource(providerName)
p.accountID = cred.AccountID
return p, nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
@@ -186,14 +238,28 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
ab := cfg.Providers.Anthropic.APIBase
if ab == "" {
ab = "https://api.anthropic.com/v1"
}
return createAuthProvider("anthropic", ab)
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
ab := cfg.Providers.OpenAI.APIBase
if ab == "" {
ab = "https://api.openai.com/v1"
}
return createAuthProvider("openai", ab)
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {