Merge pull request #32 from corylanou/issue-18-add-support-for-openai-anthropic-oauth-based-login
feat(auth): add OAuth login with SDK-based subscription providers
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/chzyer/readline"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -85,6 +86,8 @@ func main() {
|
||||
gatewayCmd()
|
||||
case "status":
|
||||
statusCmd()
|
||||
case "auth":
|
||||
authCmd()
|
||||
case "cron":
|
||||
cronCmd()
|
||||
case "skills":
|
||||
@@ -152,6 +155,7 @@ func printHelp() {
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
|
||||
fmt.Println(" agent Interact with the agent directly")
|
||||
fmt.Println(" auth Manage authentication (login, logout, status)")
|
||||
fmt.Println(" gateway Start picoclaw gateway")
|
||||
fmt.Println(" status Show picoclaw status")
|
||||
fmt.Println(" cron Manage scheduled tasks")
|
||||
@@ -682,6 +686,239 @@ func statusCmd() {
|
||||
} else {
|
||||
fmt.Println("vLLM/Local: not set")
|
||||
}
|
||||
|
||||
store, _ := auth.LoadStore()
|
||||
if store != nil && len(store.Credentials) > 0 {
|
||||
fmt.Println("\nOAuth/Token Auth:")
|
||||
for provider, cred := range store.Credentials {
|
||||
status := "authenticated"
|
||||
if cred.IsExpired() {
|
||||
status = "expired"
|
||||
} else if cred.NeedsRefresh() {
|
||||
status = "needs refresh"
|
||||
}
|
||||
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func authCmd() {
|
||||
if len(os.Args) < 3 {
|
||||
authHelp()
|
||||
return
|
||||
}
|
||||
|
||||
switch os.Args[2] {
|
||||
case "login":
|
||||
authLoginCmd()
|
||||
case "logout":
|
||||
authLogoutCmd()
|
||||
case "status":
|
||||
authStatusCmd()
|
||||
default:
|
||||
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
|
||||
authHelp()
|
||||
}
|
||||
}
|
||||
|
||||
func authHelp() {
|
||||
fmt.Println("\nAuth commands:")
|
||||
fmt.Println(" login Login via OAuth or paste token")
|
||||
fmt.Println(" logout Remove stored credentials")
|
||||
fmt.Println(" status Show current auth status")
|
||||
fmt.Println()
|
||||
fmt.Println("Login options:")
|
||||
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
|
||||
fmt.Println(" --device-code Use device code flow (for headless environments)")
|
||||
fmt.Println()
|
||||
fmt.Println("Examples:")
|
||||
fmt.Println(" picoclaw auth login --provider openai")
|
||||
fmt.Println(" picoclaw auth login --provider openai --device-code")
|
||||
fmt.Println(" picoclaw auth login --provider anthropic")
|
||||
fmt.Println(" picoclaw auth logout --provider openai")
|
||||
fmt.Println(" picoclaw auth status")
|
||||
}
|
||||
|
||||
func authLoginCmd() {
|
||||
provider := ""
|
||||
useDeviceCode := false
|
||||
|
||||
args := os.Args[3:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--provider", "-p":
|
||||
if i+1 < len(args) {
|
||||
provider = args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--device-code":
|
||||
useDeviceCode = true
|
||||
}
|
||||
}
|
||||
|
||||
if provider == "" {
|
||||
fmt.Println("Error: --provider is required")
|
||||
fmt.Println("Supported providers: openai, anthropic")
|
||||
return
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "openai":
|
||||
authLoginOpenAI(useDeviceCode)
|
||||
case "anthropic":
|
||||
authLoginPasteToken(provider)
|
||||
default:
|
||||
fmt.Printf("Unsupported provider: %s\n", provider)
|
||||
fmt.Println("Supported providers: openai, anthropic")
|
||||
}
|
||||
}
|
||||
|
||||
func authLoginOpenAI(useDeviceCode bool) {
|
||||
cfg := auth.OpenAIOAuthConfig()
|
||||
|
||||
var cred *auth.AuthCredential
|
||||
var err error
|
||||
|
||||
if useDeviceCode {
|
||||
cred, err = auth.LoginDeviceCode(cfg)
|
||||
} else {
|
||||
cred, err = auth.LoginBrowser(cfg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential("openai", cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Login successful!")
|
||||
if cred.AccountID != "" {
|
||||
fmt.Printf("Account: %s\n", cred.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func authLoginPasteToken(provider string) {
|
||||
cred, err := auth.LoginPasteToken(provider, os.Stdin)
|
||||
if err != nil {
|
||||
fmt.Printf("Login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential(provider, cred); err != nil {
|
||||
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
appCfg.Providers.Anthropic.AuthMethod = "token"
|
||||
case "openai":
|
||||
appCfg.Providers.OpenAI.AuthMethod = "token"
|
||||
}
|
||||
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Token saved for %s!\n", provider)
|
||||
}
|
||||
|
||||
func authLogoutCmd() {
|
||||
provider := ""
|
||||
|
||||
args := os.Args[3:]
|
||||
for i := 0; i < len(args); i++ {
|
||||
switch args[i] {
|
||||
case "--provider", "-p":
|
||||
if i+1 < len(args) {
|
||||
provider = args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if provider != "" {
|
||||
if err := auth.DeleteCredential(provider); err != nil {
|
||||
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
switch provider {
|
||||
case "openai":
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
case "anthropic":
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
}
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
|
||||
fmt.Printf("Logged out from %s\n", provider)
|
||||
} else {
|
||||
if err := auth.DeleteAllCredentials(); err != nil {
|
||||
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
appCfg, err := loadConfig()
|
||||
if err == nil {
|
||||
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||
config.SaveConfig(getConfigPath(), appCfg)
|
||||
}
|
||||
|
||||
fmt.Println("Logged out from all providers")
|
||||
}
|
||||
}
|
||||
|
||||
func authStatusCmd() {
|
||||
store, err := auth.LoadStore()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading auth store: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(store.Credentials) == 0 {
|
||||
fmt.Println("No authenticated providers.")
|
||||
fmt.Println("Run: picoclaw auth login --provider <name>")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("\nAuthenticated Providers:")
|
||||
fmt.Println("------------------------")
|
||||
for provider, cred := range store.Credentials {
|
||||
status := "active"
|
||||
if cred.IsExpired() {
|
||||
status = "expired"
|
||||
} else if cred.NeedsRefresh() {
|
||||
status = "needs refresh"
|
||||
}
|
||||
|
||||
fmt.Printf(" %s:\n", provider)
|
||||
fmt.Printf(" Method: %s\n", cred.AuthMethod)
|
||||
fmt.Printf(" Status: %s\n", status)
|
||||
if cred.AccountID != "" {
|
||||
fmt.Printf(" Account: %s\n", cred.AccountID)
|
||||
}
|
||||
if !cred.ExpiresAt.IsZero() {
|
||||
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
3
go.mod
3
go.mod
@@ -4,6 +4,7 @@ go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/adhocore/gronx v1.19.6
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/caarlos0/env/v11 v11.3.1
|
||||
github.com/chzyer/readline v1.5.1
|
||||
@@ -11,6 +12,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.21.0
|
||||
github.com/tencent-connect/botgo v0.2.1
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
)
|
||||
@@ -22,6 +24,7 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
|
||||
11
go.sum
11
go.sum
@@ -1,6 +1,8 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
||||
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||
@@ -72,6 +74,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
|
||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
|
||||
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -84,11 +88,12 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
||||
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
||||
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
@@ -97,6 +102,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
|
||||
358
pkg/auth/oauth.go
Normal file
358
pkg/auth/oauth.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OAuthProviderConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Port int
|
||||
}
|
||||
|
||||
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
return OAuthProviderConfig{
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Port: 1455,
|
||||
}
|
||||
}
|
||||
|
||||
func generateState() (string, error) {
|
||||
buf := make([]byte, 32)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
pkce, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating PKCE: %w", err)
|
||||
}
|
||||
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating state: %w", err)
|
||||
}
|
||||
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
|
||||
|
||||
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||
|
||||
resultCh := make(chan callbackResult, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("state") != state {
|
||||
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
|
||||
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
errMsg := r.URL.Query().Get("error")
|
||||
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
|
||||
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
|
||||
resultCh <- callbackResult{code: code}
|
||||
})
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
|
||||
}
|
||||
|
||||
server := &http.Server{Handler: mux}
|
||||
go server.Serve(listener)
|
||||
defer func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
server.Shutdown(ctx)
|
||||
}()
|
||||
|
||||
if err := openBrowser(authURL); err != nil {
|
||||
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for authentication in browser...")
|
||||
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||
case <-time.After(5 * time.Minute):
|
||||
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
||||
}
|
||||
}
|
||||
|
||||
type callbackResult struct {
|
||||
code string
|
||||
err error
|
||||
}
|
||||
|
||||
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
reqBody, _ := json.Marshal(map[string]string{
|
||||
"client_id": cfg.ClientID,
|
||||
})
|
||||
|
||||
resp, err := http.Post(
|
||||
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
||||
"application/json",
|
||||
strings.NewReader(string(reqBody)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("requesting device code: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||
}
|
||||
|
||||
var deviceResp struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &deviceResp); err != nil {
|
||||
return nil, fmt.Errorf("parsing device code response: %w", err)
|
||||
}
|
||||
|
||||
if deviceResp.Interval < 1 {
|
||||
deviceResp.Interval = 5
|
||||
}
|
||||
|
||||
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
|
||||
cfg.Issuer, deviceResp.UserCode)
|
||||
|
||||
deadline := time.After(15 * time.Minute)
|
||||
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-deadline:
|
||||
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
|
||||
case <-ticker.C:
|
||||
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if cred != nil {
|
||||
return cred, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
||||
reqBody, _ := json.Marshal(map[string]string{
|
||||
"device_auth_id": deviceAuthID,
|
||||
"user_code": userCode,
|
||||
})
|
||||
|
||||
resp, err := http.Post(
|
||||
cfg.Issuer+"/api/accounts/deviceauth/token",
|
||||
"application/json",
|
||||
strings.NewReader(string(reqBody)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("pending")
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var tokenResp struct {
|
||||
AuthorizationCode string `json:"authorization_code"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
||||
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
||||
}
|
||||
|
||||
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||
if cred.RefreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {cfg.ClientID},
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {cred.RefreshToken},
|
||||
"scope": {"openid profile email"},
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, cred.Provider)
|
||||
}
|
||||
|
||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||
}
|
||||
|
||||
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
}
|
||||
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||
}
|
||||
|
||||
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {code},
|
||||
"redirect_uri": {redirectURI},
|
||||
"client_id": {cfg.ClientID},
|
||||
"code_verifier": {codeVerifier},
|
||||
}
|
||||
|
||||
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, "openai")
|
||||
}
|
||||
|
||||
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("parsing token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("no access token in response")
|
||||
}
|
||||
|
||||
var expiresAt time.Time
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
}
|
||||
|
||||
cred := &AuthCredential{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Provider: provider,
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func extractAccountID(accessToken string) string {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func base64URLDecode(s string) ([]byte, error) {
|
||||
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
func openBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
case "linux":
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
return exec.Command("cmd", "/c", "start", url).Start()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
199
pkg/auth/oauth_test.go
Normal file
199
pkg/auth/oauth_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Port: 1455,
|
||||
}
|
||||
pkce := PKCECodes{
|
||||
CodeVerifier: "test-verifier",
|
||||
CodeChallenge: "test-challenge",
|
||||
}
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||
}
|
||||
if !strings.Contains(u, "client_id=test-client-id") {
|
||||
t.Error("URL missing client_id")
|
||||
}
|
||||
if !strings.Contains(u, "code_challenge=test-challenge") {
|
||||
t.Error("URL missing code_challenge")
|
||||
}
|
||||
if !strings.Contains(u, "code_challenge_method=S256") {
|
||||
t.Error("URL missing code_challenge_method")
|
||||
}
|
||||
if !strings.Contains(u, "state=test-state") {
|
||||
t.Error("URL missing state")
|
||||
}
|
||||
if !strings.Contains(u, "response_type=code") {
|
||||
t.Error("URL missing response_type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": "test-id-token",
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccessToken != "test-access-token" {
|
||||
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
|
||||
}
|
||||
if cred.RefreshToken != "test-refresh-token" {
|
||||
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
|
||||
}
|
||||
if cred.Provider != "openai" {
|
||||
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
|
||||
}
|
||||
if cred.AuthMethod != "oauth" {
|
||||
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
|
||||
}
|
||||
if cred.ExpiresAt.IsZero() {
|
||||
t.Error("ExpiresAt should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
body := []byte(`{"refresh_token": "test"}`)
|
||||
_, err := parseTokenResponse(body, "openai")
|
||||
if err == nil {
|
||||
t.Error("expected error for missing access_token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokens(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
r.ParseForm()
|
||||
if r.FormValue("grant_type") != "authorization_code" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "mock-access-token",
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: server.URL,
|
||||
ClientID: "test-client",
|
||||
Scopes: "openid",
|
||||
Port: 1455,
|
||||
}
|
||||
|
||||
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
|
||||
if err != nil {
|
||||
t.Fatalf("exchangeCodeForTokens() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccessToken != "mock-access-token" {
|
||||
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessToken(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
r.ParseForm()
|
||||
if r.FormValue("grant_type") != "refresh_token" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "refreshed-access-token",
|
||||
"refresh_token": "refreshed-refresh-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: server.URL,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-token",
|
||||
RefreshToken: "old-refresh-token",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||
}
|
||||
|
||||
if refreshed.AccessToken != "refreshed-access-token" {
|
||||
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
|
||||
}
|
||||
if refreshed.RefreshToken != "refreshed-refresh-token" {
|
||||
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-token",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
_, err := RefreshAccessToken(cred, cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
if cfg.Issuer != "https://auth.openai.com" {
|
||||
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
t.Error("ClientID is empty")
|
||||
}
|
||||
if cfg.Port != 1455 {
|
||||
t.Errorf("Port = %d, want 1455", cfg.Port)
|
||||
}
|
||||
}
|
||||
29
pkg/auth/pkce.go
Normal file
29
pkg/auth/pkce.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
func GeneratePKCE() (PKCECodes, error) {
|
||||
buf := make([]byte, 64)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return PKCECodes{}, err
|
||||
}
|
||||
|
||||
verifier := base64.RawURLEncoding.EncodeToString(buf)
|
||||
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
|
||||
return PKCECodes{
|
||||
CodeVerifier: verifier,
|
||||
CodeChallenge: challenge,
|
||||
}, nil
|
||||
}
|
||||
51
pkg/auth/pkce_test.go
Normal file
51
pkg/auth/pkce_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeneratePKCE(t *testing.T) {
|
||||
codes, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||
}
|
||||
|
||||
if codes.CodeVerifier == "" {
|
||||
t.Fatal("CodeVerifier is empty")
|
||||
}
|
||||
if codes.CodeChallenge == "" {
|
||||
t.Fatal("CodeChallenge is empty")
|
||||
}
|
||||
|
||||
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
|
||||
if err != nil {
|
||||
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
|
||||
}
|
||||
if len(verifierBytes) != 64 {
|
||||
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(codes.CodeVerifier))
|
||||
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
if codes.CodeChallenge != expectedChallenge {
|
||||
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePKCEUniqueness(t *testing.T) {
|
||||
codes1, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||
}
|
||||
|
||||
codes2, err := GeneratePKCE()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||
}
|
||||
|
||||
if codes1.CodeVerifier == codes2.CodeVerifier {
|
||||
t.Error("two GeneratePKCE() calls produced identical verifiers")
|
||||
}
|
||||
}
|
||||
112
pkg/auth/store.go
Normal file
112
pkg/auth/store.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthCredential struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
Provider string `json:"provider"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
}
|
||||
|
||||
type AuthStore struct {
|
||||
Credentials map[string]*AuthCredential `json:"credentials"`
|
||||
}
|
||||
|
||||
func (c *AuthCredential) IsExpired() bool {
|
||||
if c.ExpiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(c.ExpiresAt)
|
||||
}
|
||||
|
||||
func (c *AuthCredential) NeedsRefresh() bool {
|
||||
if c.ExpiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw", "auth.json")
|
||||
}
|
||||
|
||||
func LoadStore() (*AuthStore, error) {
|
||||
path := authFilePath()
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var store AuthStore
|
||||
if err := json.Unmarshal(data, &store); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if store.Credentials == nil {
|
||||
store.Credentials = make(map[string]*AuthCredential)
|
||||
}
|
||||
return &store, nil
|
||||
}
|
||||
|
||||
func SaveStore(store *AuthStore) error {
|
||||
path := authFilePath()
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(store, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func GetCredential(provider string) (*AuthCredential, error) {
|
||||
store, err := LoadStore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cred, ok := store.Credentials[provider]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func SetCredential(provider string, cred *AuthCredential) error {
|
||||
store, err := LoadStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.Credentials[provider] = cred
|
||||
return SaveStore(store)
|
||||
}
|
||||
|
||||
func DeleteCredential(provider string) error {
|
||||
store, err := LoadStore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
delete(store.Credentials, provider)
|
||||
return SaveStore(store)
|
||||
}
|
||||
|
||||
func DeleteAllCredentials() error {
|
||||
path := authFilePath()
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
189
pkg/auth/store_test.go
Normal file
189
pkg/auth/store_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAuthCredentialIsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt time.Time
|
||||
want bool
|
||||
}{
|
||||
{"zero time", time.Time{}, false},
|
||||
{"future", time.Now().Add(time.Hour), false},
|
||||
{"past", time.Now().Add(-time.Hour), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||
if got := c.IsExpired(); got != tt.want {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthCredentialNeedsRefresh(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt time.Time
|
||||
want bool
|
||||
}{
|
||||
{"zero time", time.Time{}, false},
|
||||
{"far future", time.Now().Add(time.Hour), false},
|
||||
{"within 5 min", time.Now().Add(3 * time.Minute), true},
|
||||
{"already expired", time.Now().Add(-time.Minute), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||
if got := c.NeedsRefresh(); got != tt.want {
|
||||
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreRoundtrip(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origHome := os.Getenv("HOME")
|
||||
t.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", origHome)
|
||||
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "test-access-token",
|
||||
RefreshToken: "test-refresh-token",
|
||||
AccountID: "acct-123",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if err := SetCredential("openai", cred); err != nil {
|
||||
t.Fatalf("SetCredential() error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := GetCredential("openai")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredential() error: %v", err)
|
||||
}
|
||||
if loaded == nil {
|
||||
t.Fatal("GetCredential() returned nil")
|
||||
}
|
||||
if loaded.AccessToken != cred.AccessToken {
|
||||
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
|
||||
}
|
||||
if loaded.RefreshToken != cred.RefreshToken {
|
||||
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
|
||||
}
|
||||
if loaded.Provider != cred.Provider {
|
||||
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFilePermissions(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origHome := os.Getenv("HOME")
|
||||
t.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", origHome)
|
||||
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "secret-token",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
if err := SetCredential("openai", cred); err != nil {
|
||||
t.Fatalf("SetCredential() error: %v", err)
|
||||
}
|
||||
|
||||
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat() error: %v", err)
|
||||
}
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("file permissions = %o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreMultiProvider(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origHome := os.Getenv("HOME")
|
||||
t.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", origHome)
|
||||
|
||||
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
|
||||
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
|
||||
|
||||
if err := SetCredential("openai", openaiCred); err != nil {
|
||||
t.Fatalf("SetCredential(openai) error: %v", err)
|
||||
}
|
||||
if err := SetCredential("anthropic", anthropicCred); err != nil {
|
||||
t.Fatalf("SetCredential(anthropic) error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := GetCredential("openai")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredential(openai) error: %v", err)
|
||||
}
|
||||
if loaded.AccessToken != "openai-token" {
|
||||
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
|
||||
}
|
||||
|
||||
loaded, err = GetCredential("anthropic")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredential(anthropic) error: %v", err)
|
||||
}
|
||||
if loaded.AccessToken != "anthropic-token" {
|
||||
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCredential(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origHome := os.Getenv("HOME")
|
||||
t.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", origHome)
|
||||
|
||||
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
|
||||
if err := SetCredential("openai", cred); err != nil {
|
||||
t.Fatalf("SetCredential() error: %v", err)
|
||||
}
|
||||
|
||||
if err := DeleteCredential("openai"); err != nil {
|
||||
t.Fatalf("DeleteCredential() error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := GetCredential("openai")
|
||||
if err != nil {
|
||||
t.Fatalf("GetCredential() error: %v", err)
|
||||
}
|
||||
if loaded != nil {
|
||||
t.Error("expected nil after delete")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadStoreEmpty(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
origHome := os.Getenv("HOME")
|
||||
t.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", origHome)
|
||||
|
||||
store, err := LoadStore()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadStore() error: %v", err)
|
||||
}
|
||||
if store == nil {
|
||||
t.Fatal("LoadStore() returned nil")
|
||||
}
|
||||
if len(store.Credentials) != 0 {
|
||||
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
|
||||
}
|
||||
}
|
||||
43
pkg/auth/token.go
Normal file
43
pkg/auth/token.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
|
||||
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
|
||||
fmt.Print("> ")
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("reading token: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("no input received")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(scanner.Text())
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("token cannot be empty")
|
||||
}
|
||||
|
||||
return &AuthCredential{
|
||||
AccessToken: token,
|
||||
Provider: provider,
|
||||
AuthMethod: "token",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func providerDisplayName(provider string) string {
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return "console.anthropic.com"
|
||||
case "openai":
|
||||
return "platform.openai.com"
|
||||
default:
|
||||
return provider
|
||||
}
|
||||
}
|
||||
@@ -101,6 +101,7 @@ type ProvidersConfig struct {
|
||||
type ProviderConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
|
||||
207
pkg/providers/claude_provider.go
Normal file
207
pkg/providers/claude_provider.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
)
|
||||
|
||||
type ClaudeProvider struct {
|
||||
client *anthropic.Client
|
||||
tokenSource func() (string, error)
|
||||
}
|
||||
|
||||
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||
client := anthropic.NewClient(
|
||||
option.WithAuthToken(token),
|
||||
option.WithBaseURL("https://api.anthropic.com"),
|
||||
)
|
||||
return &ClaudeProvider{client: &client}
|
||||
}
|
||||
|
||||
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||
p := NewClaudeProvider(token)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAuthToken(tok))
|
||||
}
|
||||
|
||||
params, err := buildClaudeParams(messages, tools, model, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("claude API call: %w", err)
|
||||
}
|
||||
|
||||
return parseClaudeResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
|
||||
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||
var system []anthropic.TextBlockParam
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
var blocks []anthropic.ContentBlockParamUnion
|
||||
if msg.Content != "" {
|
||||
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||
}
|
||||
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||
} else {
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||
)
|
||||
}
|
||||
case "tool":
|
||||
anthropicMessages = append(anthropicMessages,
|
||||
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
maxTokens := int64(4096)
|
||||
if mt, ok := options["max_tokens"].(int); ok {
|
||||
maxTokens = int64(mt)
|
||||
}
|
||||
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: anthropic.Model(model),
|
||||
Messages: anthropicMessages,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
if len(system) > 0 {
|
||||
params.System = system
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = anthropic.Float(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForClaude(tools)
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
tool := anthropic.ToolParam{
|
||||
Name: t.Function.Name,
|
||||
InputSchema: anthropic.ToolInputSchemaParam{
|
||||
Properties: t.Function.Parameters["properties"],
|
||||
},
|
||||
}
|
||||
if desc := t.Function.Description; desc != "" {
|
||||
tool.Description = anthropic.String(desc)
|
||||
}
|
||||
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(req))
|
||||
for _, r := range req {
|
||||
if s, ok := r.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
tool.InputSchema.Required = required
|
||||
}
|
||||
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||
var content string
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
tb := block.AsText()
|
||||
content += tb.Text
|
||||
case "tool_use":
|
||||
tu := block.AsToolUse()
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: tu.ID,
|
||||
Name: tu.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
switch resp.StopReason {
|
||||
case anthropic.StopReasonToolUse:
|
||||
finishReason = "tool_calls"
|
||||
case anthropic.StopReasonMaxTokens:
|
||||
finishReason = "length"
|
||||
case anthropic.StopReasonEndTurn:
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createClaudeTokenSource() func() (string, error) {
|
||||
return func() (string, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return cred.AccessToken, nil
|
||||
}
|
||||
}
|
||||
210
pkg/providers/claude_provider_test.go
Normal file
210
pkg/providers/claude_provider_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||
)
|
||||
|
||||
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||
"max_tokens": 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
if params.MaxTokens != 1024 {
|
||||
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.System) != 1 {
|
||||
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||
}
|
||||
if params.System[0].Text != "You are helpful" {
|
||||
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||
}
|
||||
if len(params.Messages) != 1 {
|
||||
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]interface{}{"city": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Messages) != 3 {
|
||||
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a city",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []interface{}{"city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||
}
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||
resp := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{},
|
||||
Usage: anthropic.Usage{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
},
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||
}
|
||||
if result.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||
tests := []struct {
|
||||
stopReason anthropic.StopReason
|
||||
want string
|
||||
}{
|
||||
{anthropic.StopReasonEndTurn, "stop"},
|
||||
{anthropic.StopReasonMaxTokens, "length"},
|
||||
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
resp := &anthropic.Message{
|
||||
StopReason: tt.stopReason,
|
||||
}
|
||||
result := parseClaudeResponse(resp)
|
||||
if result.FinishReason != tt.want {
|
||||
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/messages" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "msg_test",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": reqBody["model"],
|
||||
"stop_reason": "end_turn",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "text", "text": "Hello! How can I help you?"},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 15,
|
||||
"output_tokens": 8,
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewClaudeProvider("test-token")
|
||||
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello! How can I help you?" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 15 {
|
||||
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewClaudeProvider("test-token")
|
||||
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||
}
|
||||
}
|
||||
|
||||
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||
c := anthropic.NewClient(
|
||||
anthropicoption.WithAuthToken(token),
|
||||
anthropicoption.WithBaseURL(baseURL),
|
||||
)
|
||||
return &c
|
||||
}
|
||||
248
pkg/providers/codex_provider.go
Normal file
248
pkg/providers/codex_provider.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
)
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
}
|
||||
|
||||
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||
option.WithAPIKey(token),
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
}
|
||||
client := openai.NewClient(opts...)
|
||||
return &CodexProvider{
|
||||
client: &client,
|
||||
accountID: accountID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
|
||||
p := NewCodexProvider(token, accountID)
|
||||
p.tokenSource = tokenSource
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
if p.tokenSource != nil {
|
||||
tok, accID, err := p.tokenSource()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
opts = append(opts, option.WithAPIKey(tok))
|
||||
if accID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||
}
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, model, options)
|
||||
|
||||
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codex API call: %w", err)
|
||||
}
|
||||
|
||||
return parseCodexResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
return "gpt-4o"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
var inputItems responses.ResponseInputParam
|
||||
var instructions string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
instructions = msg.Content
|
||||
case "user":
|
||||
if msg.ToolCallID != "" {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleUser,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "assistant":
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if msg.Content != "" {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfMessage: &responses.EasyInputMessageParam{
|
||||
Role: responses.EasyInputMessageRoleAssistant,
|
||||
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
case "tool":
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||
CallID: msg.ToolCallID,
|
||||
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
params := responses.ResponseNewParams{
|
||||
Model: model,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: inputItems,
|
||||
},
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
params.Instructions = openai.Opt(instructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = openai.Opt(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
ft := responses.FunctionToolParam{
|
||||
Name: t.Function.Name,
|
||||
Parameters: t.Function.Parameters,
|
||||
Strict: openai.Opt(false),
|
||||
}
|
||||
if t.Function.Description != "" {
|
||||
ft.Description = openai.Opt(t.Function.Description)
|
||||
}
|
||||
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseCodexResponse(resp *responses.Response) *LLMResponse {
|
||||
var content strings.Builder
|
||||
var toolCalls []ToolCall
|
||||
|
||||
for _, item := range resp.Output {
|
||||
switch item.Type {
|
||||
case "message":
|
||||
for _, c := range item.Content {
|
||||
if c.Type == "output_text" {
|
||||
content.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||
args = map[string]interface{}{"raw": item.Arguments}
|
||||
}
|
||||
toolCalls = append(toolCalls, ToolCall{
|
||||
ID: item.CallID,
|
||||
Name: item.Name,
|
||||
Arguments: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
if resp.Status == "incomplete" {
|
||||
finishReason = "length"
|
||||
}
|
||||
|
||||
var usage *UsageInfo
|
||||
if resp.Usage.TotalTokens > 0 {
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: int(resp.Usage.InputTokens),
|
||||
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||
TotalTokens: int(resp.Usage.TotalTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
func createCodexTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
|
||||
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||
oauthCfg := auth.OpenAIOAuthConfig()
|
||||
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||
}
|
||||
return refreshed.AccessToken, refreshed.AccountID, nil
|
||||
}
|
||||
|
||||
return cred.AccessToken, cred.AccountID, nil
|
||||
}
|
||||
}
|
||||
264
pkg/providers/codex_provider_test.go
Normal file
264
pkg/providers/codex_provider_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
openaiopt "github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
})
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
if params.Instructions.Or("") != "You are helpful" {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "What's the weather?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{
|
||||
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
if len(params.Input.OfInputItemList) != 3 {
|
||||
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||
if len(params.Tools) != 1 {
|
||||
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||
}
|
||||
if params.Tools[0].OfFunction == nil {
|
||||
t.Fatal("Tool should be a function tool")
|
||||
}
|
||||
if params.Tools[0].OfFunction.Name != "get_weather" {
|
||||
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||
t.Error("Store should be explicitly set to false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||
respJSON := `{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [
|
||||
{"type": "output_text", "text": "Hello there!"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}
|
||||
}
|
||||
}`
|
||||
|
||||
var resp responses.Response
|
||||
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
result := parseCodexResponse(&resp)
|
||||
if result.Content != "Hello there!" {
|
||||
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
|
||||
}
|
||||
if result.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||
}
|
||||
if result.Usage.TotalTokens != 15 {
|
||||
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexResponse_FunctionCall(t *testing.T) {
|
||||
respJSON := `{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": [
|
||||
{
|
||||
"id": "fc_1",
|
||||
"type": "function_call",
|
||||
"call_id": "call_abc",
|
||||
"name": "get_weather",
|
||||
"arguments": "{\"city\":\"SF\"}",
|
||||
"status": "completed"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 8,
|
||||
"total_tokens": 18,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens_details": {"reasoning_tokens": 0}
|
||||
}
|
||||
}`
|
||||
|
||||
var resp responses.Response
|
||||
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
|
||||
result := parseCodexResponse(&resp)
|
||||
if len(result.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||
}
|
||||
tc := result.ToolCalls[0]
|
||||
if tc.Name != "get_weather" {
|
||||
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
|
||||
}
|
||||
if tc.ID != "call_abc" {
|
||||
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
|
||||
}
|
||||
if tc.Arguments["city"] != "SF" {
|
||||
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
|
||||
}
|
||||
if result.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 6,
|
||||
"total_tokens": 18,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage.TotalTokens != 18 {
|
||||
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexProvider("test-token", "")
|
||||
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||
}
|
||||
}
|
||||
|
||||
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||
opts := []openaiopt.RequestOption{
|
||||
openaiopt.WithBaseURL(baseURL),
|
||||
openaiopt.WithAPIKey(token),
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
}
|
||||
c := openai.NewClient(opts...)
|
||||
return &c
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -74,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
authHeader := "Bearer " + p.apiKey
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
@@ -170,6 +170,28 @@ func (p *HTTPProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("anthropic")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||
}
|
||||
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||
}
|
||||
|
||||
func createCodexAuthProvider() (LLMProvider, error) {
|
||||
cred, err := auth.GetCredential("openai")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||
}
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||
}
|
||||
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
model := cfg.Agents.Defaults.Model
|
||||
|
||||
@@ -186,14 +208,20 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
apiBase = "https://openrouter.ai/api/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
|
||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||
return createClaudeAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.Anthropic.APIKey
|
||||
apiBase = cfg.Providers.Anthropic.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.anthropic.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
|
||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
apiKey = cfg.Providers.OpenAI.APIKey
|
||||
apiBase = cfg.Providers.OpenAI.APIBase
|
||||
if apiBase == "" {
|
||||
|
||||
Reference in New Issue
Block a user