diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 94a79a6..ecd9ba2 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -13,6 +13,7 @@ import ( "net/url" "os/exec" "runtime" + "strconv" "strings" "time" ) @@ -92,10 +93,13 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { server.Shutdown(ctx) }() + fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL) + if err := openBrowser(authURL); err != nil { fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) } + fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code") fmt.Println("Waiting for authentication in browser...") select { @@ -114,6 +118,57 @@ type callbackResult struct { err error } +type deviceCodeResponse struct { + DeviceAuthID string + UserCode string + Interval int +} + +func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) { + var raw struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval json.RawMessage `json:"interval"` + } + + if err := json.Unmarshal(body, &raw); err != nil { + return deviceCodeResponse{}, err + } + + interval, err := parseFlexibleInt(raw.Interval) + if err != nil { + return deviceCodeResponse{}, err + } + + return deviceCodeResponse{ + DeviceAuthID: raw.DeviceAuthID, + UserCode: raw.UserCode, + Interval: interval, + }, nil +} + +func parseFlexibleInt(raw json.RawMessage) (int, error) { + if len(raw) == 0 || string(raw) == "null" { + return 0, nil + } + + var interval int + if err := json.Unmarshal(raw, &interval); err == nil { + return interval, nil + } + + var intervalStr string + if err := json.Unmarshal(raw, &intervalStr); err == nil { + intervalStr = strings.TrimSpace(intervalStr) + if intervalStr == "" { + return 0, nil + } + return strconv.Atoi(intervalStr) + } + + return 0, fmt.Errorf("invalid integer value: %s", string(raw)) +} + func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { reqBody, _ := json.Marshal(map[string]string{ "client_id": cfg.ClientID, @@ -134,12 +189,8 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { return nil, fmt.Errorf("device code request failed: %s", string(body)) } - var deviceResp struct { - DeviceAuthID string `json:"device_auth_id"` - UserCode string `json:"user_code"` - Interval int `json:"interval"` - } - if err := json.Unmarshal(body, &deviceResp); err != nil { + deviceResp, err := parseDeviceCodeResponse(body) + if err != nil { return nil, fmt.Errorf("parsing device code response: %w", err) } diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 00b4c60..9f80132 100644 --- a/pkg/auth/oauth_test.go +++ b/pkg/auth/oauth_test.go @@ -197,3 +197,43 @@ func TestOpenAIOAuthConfig(t *testing.T) { t.Errorf("Port = %d, want 1455", cfg.Port) } } + +func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.DeviceAuthID != "abc" { + t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc") + } + if resp.UserCode != "DEF-1234" { + t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234") + } + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`) + + if _, err := parseDeviceCodeResponse(body); err == nil { + t.Fatal("expected error for invalid interval") + } +}