diff --git a/README.md b/README.md index 0bfac28..b1ccf6b 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ Search across your company's knowledge, chat with Glean Assistant, manage the fu - [Quick Start](#quick-start) - [Why Glean CLI?](#why-glean-cli) - [Authentication](#authentication) - - [OAuth (recommended)](#oauth-recommended) - [API Token (CI/CD)](#api-token-cicd) + - [Credential resolution order](#credential-resolution-order) - [Interactive TUI](#interactive-tui) - [Keyboard Shortcuts](#keyboard-shortcuts) - [Slash Commands](#slash-commands) @@ -95,21 +95,25 @@ glean search "engineering docs" --output ndjson | jq .title ## Authentication -### OAuth (recommended) - ```bash snippet=readme/snippet-04.sh -glean auth login # opens browser, completes PKCE flow +glean auth login # interactive login (detects the best method automatically) glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials ``` -OAuth uses PKCE with Dynamic Client Registration — no client ID required. Tokens are stored securely in the system keyring and refreshed automatically. +`glean auth login` detects the right authentication method for your environment automatically: + +| Method | When it's used | What happens | +| --- | --- | --- | +| **Browser login** | Default for most Glean instances | Opens your browser, you approve, done | +| **Device code login** | Organizations using an external IdP (e.g. Okta) | Prints a URL and code — open the URL, enter the code | +| **API token** | Instances without OAuth support | Prompts you to paste a token from Glean Admin | -For instances that don't support OAuth, `auth login` falls back to prompting for an API token. +You don't need to choose — `auth login` tries each method in order and uses the first one that works. Tokens are stored securely in the system keyring and refreshed automatically. ### API Token (CI/CD) -Set credentials via environment variables — no interactive login needed: +For non-interactive environments, set credentials via environment variables: ```bash snippet=readme/snippet-05.sh export GLEAN_API_TOKEN=your-token @@ -117,6 +121,10 @@ export GLEAN_HOST=your-company-be.glean.com glean search "test" ``` +API tokens are scoped to an individual user account. Generate one from **Glean Admin → Settings → API Tokens**. + +### Credential resolution order + Credentials are resolved in this order: environment variables → system keyring → `~/.glean/config.json`. ## Interactive TUI diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 3afac78..d9cb4db 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,6 +6,7 @@ import ( _ "embed" "encoding/base64" "encoding/json" + "errors" "fmt" "net" "net/http" @@ -30,14 +31,24 @@ var ( dcrLog = debug.New("auth:dcr") tokenLog = debug.New("auth:token") emailLog = debug.New("auth:email") + deviceLog = debug.New("auth:device") ) //go:embed success.html var successHTML string -// Login performs the full OAuth 2.0 PKCE login flow for the configured Glean host. -// If the host is not configured, prompts for a work email and auto-discovers it. -// If the instance doesn't support OAuth, falls back to an inline API token prompt. +// errNoOAuthClient is returned by dcrOrStaticClient when neither DCR nor a +// static client is available. Login uses this to decide whether device flow +// is an appropriate fallback (as opposed to transient failures like network +// timeouts or the user closing their browser). +var errNoOAuthClient = errors.New("no OAuth client available") + +// Login performs the full OAuth 2.0 login flow for the configured Glean host. +// +// Strategy (in order): +// 1. Authorization Code + PKCE via DCR or static client +// 2. Device Authorization Grant (RFC 8628) using the Glean-advertised client ID +// 3. Inline API token prompt when OAuth is not available at all func Login(ctx context.Context) error { loginLog.Log("starting login flow") @@ -47,45 +58,61 @@ func Login(ctx context.Context) error { } loginLog.Log("host resolved: %s", host) - provider, endpoint, registrationEndpoint, err := discover(ctx, host) + disc, err := discover(ctx, host) if err != nil { loginLog.Log("OAuth discovery failed, falling back to API token: %v", err) fmt.Fprintf(os.Stderr, "\nOAuth discovery failed: %v\n", err) return promptForAPIToken(host) } - loginLog.Log("OAuth discovery succeeded: auth=%s token=%s registration=%s", endpoint.AuthURL, endpoint.TokenURL, registrationEndpoint) + loginLog.Log("OAuth discovery succeeded: auth=%s token=%s registration=%s", disc.Endpoint.AuthURL, disc.Endpoint.TokenURL, disc.RegistrationEndpoint) + + // Try DCR / static client first (standard authorization code flow). + loginLog.Log("attempting authorization code + PKCE flow") + authCodeErr := tryAuthCodeLogin(ctx, host, disc) + if authCodeErr == nil { + return nil + } + loginLog.Log("auth code flow failed: %v", authCodeErr) + + // Only fall back to device flow when the auth code flow failed because no + // OAuth client could be obtained (DCR unsupported + no static client). + // Transient failures (network, user closing browser, port conflicts) should + // not silently switch to a different grant type. + canDeviceFlow := disc.DeviceFlowClientID != "" && disc.DeviceAuthEndpoint != "" + if errors.Is(authCodeErr, errNoOAuthClient) && canDeviceFlow { + loginLog.Log("falling back to device authorization grant (client_id=%s)", disc.DeviceFlowClientID) + fmt.Fprintf(os.Stderr, "\nYour SSO provider requires device-based login.\n") + return deviceFlowLogin(ctx, host, disc) + } + + return fmt.Errorf("authentication failed: %w", authCodeErr) +} - // Find a free port for the local callback server. - // This must happen before DCR so we register the exact redirect URI - // that oauth2cli will use — a mismatch causes a silent hang. +// tryAuthCodeLogin attempts the Authorization Code + PKCE flow via DCR or static client. +func tryAuthCodeLogin(ctx context.Context, host string, disc *discoveryResult) error { port, err := findFreePort() if err != nil { return fmt.Errorf("finding callback port: %w", err) } redirectURI := fmt.Sprintf("http://127.0.0.1:%d/glean-cli-callback", port) - // Always do fresh DCR per login — the redirect URI (port) changes each time. - clientID, clientSecret, err := dcrOrStaticClient(ctx, host, registrationEndpoint, redirectURI) + clientID, clientSecret, err := dcrOrStaticClient(ctx, host, disc.RegistrationEndpoint, redirectURI) if err != nil { - return fmt.Errorf("resolving OAuth client: %w", err) + return err } verifier := oauth2.GenerateVerifier() - scopes := resolveScopes(provider) + scopes := resolveScopes(disc.Provider) loginLog.Log("requesting scopes: %v", scopes) oauthCfg := oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, - Endpoint: endpoint, + Endpoint: disc.Endpoint, Scopes: scopes, RedirectURL: redirectURI, } - // oauth2cli v1.15.1 does not open the browser itself — the caller must do it. - // LocalServerReadyChan receives the local server URL once the callback server - // is ready. We open the browser to that URL (which the local server redirects - // to the real OAuth page), and also print the direct auth URL as a fallback. state := oauth2.GenerateVerifier()[:20] authURL := oauthCfg.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) @@ -97,7 +124,6 @@ func Login(ctx context.Context) error { fmt.Printf("If your browser doesn't open, visit:\n %s\n\n", authURL) fmt.Printf("Waiting for you to complete login in the browser…\n") if err := browser.OpenURL(localURL); err != nil { - // Browser failed to open — the printed URL is the fallback. fmt.Printf("(Could not open browser automatically: %v)\n", err) } case <-ctx.Done(): @@ -121,7 +147,14 @@ func Login(ctx context.Context) error { return fmt.Errorf("OAuth login failed: %w", err) } - email := extractEmailFromToken(ctx, provider, clientID, token) + return saveAndPrintToken(ctx, host, disc, oauthCfg.ClientID, token) +} + +// saveAndPrintToken persists the OAuth token and client, then prints a success message. +func saveAndPrintToken(ctx context.Context, host string, disc *discoveryResult, clientID string, token *oauth2.Token) error { + _ = SaveClient(host, &StoredClient{ClientID: clientID}) + + email := extractEmailFromToken(ctx, disc.Provider, clientID, token) stored := &StoredTokens{ AccessToken: token.AccessToken, @@ -129,7 +162,7 @@ func Login(ctx context.Context) error { Expiry: token.Expiry, Email: email, TokenType: token.TokenType, - TokenEndpoint: oauthCfg.Endpoint.TokenURL, // enables future token refresh + TokenEndpoint: disc.Endpoint.TokenURL, } if err := persistLoginState(host, stored); err != nil { return err @@ -356,6 +389,15 @@ func resolveHost(ctx context.Context) (string, error) { return host, nil } +// discoveryResult holds all OAuth metadata discovered for a Glean backend. +type discoveryResult struct { + Provider *oidc.Provider + Endpoint oauth2.Endpoint + RegistrationEndpoint string + DeviceFlowClientID string + DeviceAuthEndpoint string +} + // discover resolves the OAuth2 endpoint and registration endpoint for the Glean backend. // // Strategy: @@ -363,16 +405,13 @@ func resolveHost(ctx context.Context) (string, error) { // 2. Try OIDC discovery (oidc.NewProvider) for full OIDC support // 3. Fall back to RFC 8414 auth server metadata when OIDC is unavailable // (Glean uses RFC 8414 but does not serve /.well-known/openid-configuration) -// -// Returns (provider, oauth2Endpoint, registrationEndpoint, error). -// provider is nil when only RFC 8414 discovery succeeded. -func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint, string, error) { +func discover(ctx context.Context, host string) (*discoveryResult, error) { baseURL := "https://" + host discoveryLog.Log("fetching protected resource metadata: %s", baseURL) meta, err := fetchProtectedResource(ctx, baseURL) if err != nil { discoveryLog.Log("protected resource metadata failed: %v", err) - return nil, oauth2.Endpoint{}, "", err + return nil, err } issuer := meta.AuthorizationServers[0] @@ -383,31 +422,55 @@ func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint provider, err := oidc.NewProvider(ctx, issuer) if err == nil { discoveryLog.Log("OIDC discovery succeeded") - // Still need registration_endpoint, which oidc.Provider doesn't expose. - authMeta, _ := fetchAuthServerMetadata(ctx, issuer) - regEndpoint := "" - if authMeta != nil { - regEndpoint = authMeta.RegistrationEndpoint + res := &discoveryResult{Provider: provider, Endpoint: provider.Endpoint()} + res.DeviceFlowClientID = meta.GleanDeviceFlowClientID + + // Extract device_authorization_endpoint from OIDC provider claims + // (RFC 8414 metadata may omit it even when OIDC metadata includes it). + var providerClaims struct { + RegistrationEndpoint string `json:"registration_endpoint"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` + } + if err := provider.Claims(&providerClaims); err == nil { + res.RegistrationEndpoint = providerClaims.RegistrationEndpoint + res.DeviceAuthEndpoint = providerClaims.DeviceAuthorizationEndpoint } - return provider, provider.Endpoint(), regEndpoint, nil + + // Supplement from RFC 8414 if OIDC claims were incomplete. + if res.RegistrationEndpoint == "" || res.DeviceAuthEndpoint == "" { + if authMeta, err := fetchAuthServerMetadata(ctx, issuer); err == nil { + if res.RegistrationEndpoint == "" { + res.RegistrationEndpoint = authMeta.RegistrationEndpoint + } + if res.DeviceAuthEndpoint == "" { + res.DeviceAuthEndpoint = authMeta.DeviceAuthorizationEndpoint + } + } + } + return res, nil } discoveryLog.Log("OIDC discovery failed: %v, falling back to RFC 8414", err) // Fall back to RFC 8414 auth server metadata. authMeta, err := fetchAuthServerMetadata(ctx, issuer) if err != nil { - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) + return nil, fmt.Errorf("OAuth discovery failed for %s: %w", issuer, err) } if authMeta.AuthorizationEndpoint == "" || authMeta.TokenEndpoint == "" { discoveryLog.Log("RFC 8414 metadata incomplete: auth=%q token=%q", authMeta.AuthorizationEndpoint, authMeta.TokenEndpoint) - return nil, oauth2.Endpoint{}, "", fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) + return nil, fmt.Errorf("OAuth metadata missing required endpoints for %s", issuer) } discoveryLog.Log("RFC 8414 discovery succeeded: auth=%s token=%s", authMeta.AuthorizationEndpoint, authMeta.TokenEndpoint) - return nil, oauth2.Endpoint{ - AuthURL: authMeta.AuthorizationEndpoint, - TokenURL: authMeta.TokenEndpoint, - }, authMeta.RegistrationEndpoint, nil + return &discoveryResult{ + Endpoint: oauth2.Endpoint{ + AuthURL: authMeta.AuthorizationEndpoint, + TokenURL: authMeta.TokenEndpoint, + }, + RegistrationEndpoint: authMeta.RegistrationEndpoint, + DeviceFlowClientID: meta.GleanDeviceFlowClientID, + DeviceAuthEndpoint: authMeta.DeviceAuthorizationEndpoint, + }, nil } // dcrOrStaticClient resolves the OAuth client_id/secret for a login session. @@ -416,6 +479,7 @@ func discover(ctx context.Context, host string) (*oidc.Provider, oauth2.Endpoint // credentials can be reused for token refresh later. // Falls back to a static client configured via glean config --oauth-client-id. func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirectURI string) (string, string, error) { + var dcrErr error if registrationEndpoint != "" { dcrLog.Log("registering client at %s with redirect %s", registrationEndpoint, redirectURI) cl, err := registerClient(ctx, registrationEndpoint, redirectURI) @@ -426,7 +490,7 @@ func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirect } return cl.ClientID, cl.ClientSecret, nil } - // DCR failed — log and fall through to static client + dcrErr = err dcrLog.Log("DCR failed: %v, trying static client", err) fmt.Printf("Note: dynamic client registration failed (%v), trying static client\n", err) } else { @@ -439,7 +503,10 @@ func dcrOrStaticClient(ctx context.Context, host, registrationEndpoint, redirect return cfg.OAuthClientID, cfg.OAuthClientSecret, nil } - return "", "", fmt.Errorf("no OAuth client available — dynamic client registration failed and no static client is configured") + if dcrErr != nil { + return "", "", fmt.Errorf("%w: dynamic client registration failed (%v) and no static client is configured", errNoOAuthClient, dcrErr) + } + return "", "", fmt.Errorf("%w: no registration endpoint and no static client configured", errNoOAuthClient) } // resolveScopes returns the appropriate OAuth scopes for the given provider. @@ -517,10 +584,11 @@ func fetchAuthServerMetadata(ctx context.Context, issuer string) (*authServerMet } type authServerMeta struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"` } // extractEmailFromToken pulls the user email from the token. diff --git a/internal/auth/auth_fallback_test.go b/internal/auth/auth_fallback_test.go new file mode 100644 index 0000000..61efca7 --- /dev/null +++ b/internal/auth/auth_fallback_test.go @@ -0,0 +1,94 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/gleanwork/glean-cli/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDcrOrStaticClient_NoClientAvailable(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", "", "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.True(t, errors.Is(err, errNoOAuthClient), "expected errNoOAuthClient, got: %v", err) +} + +func TestDcrOrStaticClient_DCRFails_NoStaticClient(t *testing.T) { + t.Setenv("GLEAN_HOST", "") + config.ConfigPath = t.TempDir() + "/config.json" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + _, _, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.Error(t, err) + assert.True(t, errors.Is(err, errNoOAuthClient), + "DCR rejection (403) with no static client means no OAuth client is available") + assert.Contains(t, err.Error(), "dynamic client registration failed") +} + +func TestDcrOrStaticClient_DCRFails_StaticClientFallback(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + t.Setenv("GLEAN_HOST", "test-host") + + cfgData, _ := json.Marshal(map[string]string{ + "host": "test-host", + "oauth_client_id": "static-id", + "oauth_client_secret": "static-secret", + }) + require.NoError(t, os.WriteFile(config.ConfigPath, cfgData, 0o600)) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer srv.Close() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "static-id", clientID) + assert.Equal(t, "static-secret", clientSecret) +} + +func TestDcrOrStaticClient_DCRSucceeds(t *testing.T) { + dir := t.TempDir() + config.ConfigPath = dir + "/config.json" + setStoragePath(t, dir) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(map[string]string{ + "client_id": "dcr-id", + "client_secret": "dcr-secret", + }) + })) + defer srv.Close() + + clientID, clientSecret, err := dcrOrStaticClient(context.Background(), "test-host", srv.URL, "http://127.0.0.1:9999/callback") + require.NoError(t, err) + assert.Equal(t, "dcr-id", clientID) + assert.Equal(t, "dcr-secret", clientSecret) +} + +func TestErrNoOAuthClient_NotMatchedByOtherErrors(t *testing.T) { + other := errors.New("finding callback port: address already in use") + assert.False(t, errors.Is(other, errNoOAuthClient)) +} + +// setStoragePath points token/client storage at a temp directory. +func setStoragePath(t *testing.T, dir string) { + t.Helper() + t.Setenv("GLEAN_AUTH_DIR", dir) +} diff --git a/internal/auth/device.go b/internal/auth/device.go new file mode 100644 index 0000000..70cc82b --- /dev/null +++ b/internal/auth/device.go @@ -0,0 +1,257 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gleanwork/glean-cli/internal/httputil" + "github.com/pkg/browser" + "golang.org/x/oauth2" +) + +const ( + defaultPollInterval = 5 * time.Second + maxPollInterval = 60 * time.Second + defaultExpiresIn = 900 // 15 minutes + maxExpiresIn = 1800 +) + +// deviceAuthResponse is the response from the device authorization endpoint (RFC 8628 §3.2). +type deviceAuthResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// deviceTokenError is the error response from the token endpoint during polling. +type deviceTokenError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` +} + +// deviceFlowLogin performs the OAuth 2.0 Device Authorization Grant (RFC 8628). +func deviceFlowLogin(ctx context.Context, host string, disc *discoveryResult) error { + scopes := resolveScopes(disc.Provider) + deviceLog.Log("requesting device code from %s (client_id=%s)", disc.DeviceAuthEndpoint, disc.DeviceFlowClientID) + + authResp, err := requestDeviceCode(ctx, disc.DeviceAuthEndpoint, disc.DeviceFlowClientID, scopes) + if err != nil { + return fmt.Errorf("device authorization request failed: %w", err) + } + deviceLog.Log("device code received: user_code=%s verification_uri=%s expires_in=%d", authResp.UserCode, authResp.VerificationURI, authResp.ExpiresIn) + + verificationURL := authResp.VerificationURIComplete + if verificationURL == "" { + verificationURL = authResp.VerificationURI + } + + parsed, err := url.Parse(verificationURL) + if err != nil || parsed.Host == "" { + return fmt.Errorf("device authorization returned invalid verification URL: %q", verificationURL) + } + if parsed.Scheme != "https" { + return fmt.Errorf("device authorization returned non-HTTPS verification URL: %q", verificationURL) + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s\n\n", verificationURL) + if authResp.VerificationURIComplete == "" { + fmt.Printf("Then enter code: %s\n\n", authResp.UserCode) + } else { + fmt.Printf("Your code: %s\n\n", authResp.UserCode) + } + fmt.Printf("Waiting for you to complete login in the browser…\n") + + _ = browser.OpenURL(verificationURL) + + deviceLog.Log("polling token endpoint %s (interval=%ds)", disc.Endpoint.TokenURL, authResp.Interval) + token, err := pollForToken(ctx, disc.Endpoint.TokenURL, disc.DeviceFlowClientID, authResp) + if err != nil { + deviceLog.Log("device flow failed: %v", err) + return fmt.Errorf("device flow login failed: %w", err) + } + deviceLog.Log("device flow token received") + + return saveAndPrintToken(ctx, host, disc, disc.DeviceFlowClientID, token) +} + +// requestDeviceCode sends the initial device authorization request (RFC 8628 §3.1). +func requestDeviceCode(ctx context.Context, endpoint, clientID string, scopes []string) (*deviceAuthResponse, error) { + data := url.Values{ + "client_id": {clientID}, + "scope": {strings.Join(scopes, " ")}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("building device authorization request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) + if err != nil { + return nil, fmt.Errorf("device authorization HTTP request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var errResp deviceTokenError + _ = json.NewDecoder(resp.Body).Decode(&errResp) + desc := errResp.ErrorDescription + if desc == "" { + desc = errResp.Error + } + if errResp.Error == "unauthorized_client" { + deviceLog.Log("IdP rejected device code grant for client %s: %s", clientID, desc) + return nil, fmt.Errorf("%s\n\nAsk your IdP administrator to add the device_code grant type\nto OAuth app %s", desc, clientID) + } + if desc != "" { + return nil, fmt.Errorf("device authorization failed: %s", desc) + } + return nil, fmt.Errorf("device authorization endpoint returned HTTP %d", resp.StatusCode) + } + + var authResp deviceAuthResponse + if err := json.NewDecoder(resp.Body).Decode(&authResp); err != nil { + return nil, fmt.Errorf("parsing device authorization response: %w", err) + } + if authResp.DeviceCode == "" { + return nil, fmt.Errorf("device authorization response missing device_code") + } + if authResp.VerificationURI == "" && authResp.VerificationURIComplete == "" { + return nil, fmt.Errorf("device authorization response missing verification_uri") + } + authResp.Interval = clampInt(authResp.Interval, int(defaultPollInterval/time.Second), int(maxPollInterval/time.Second)) + if authResp.ExpiresIn <= 0 { + authResp.ExpiresIn = defaultExpiresIn + } else if authResp.ExpiresIn > maxExpiresIn { + authResp.ExpiresIn = maxExpiresIn + } + return &authResp, nil +} + +func clampInt(v, min, max int) int { + if v < min { + return min + } + if v > max { + return max + } + return v +} + +// pollForToken polls the token endpoint until the user completes authorization (RFC 8628 §3.4–3.5). +func pollForToken(ctx context.Context, tokenURL, clientID string, authResp *deviceAuthResponse) (*oauth2.Token, error) { + interval := time.Duration(authResp.Interval) * time.Second + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + + if time.Now().After(deadline) { + return nil, fmt.Errorf("device code expired — run 'glean auth login' to try again") + } + + token, status, err := exchangeDeviceCode(ctx, tokenURL, clientID, authResp.DeviceCode) + if err != nil { + return nil, err + } + if status == pollSlowDown { + interval += 5 * time.Second + continue + } + if status == pollPending { + continue + } + return token, nil + } +} + +type pollStatus int + +const ( + pollDone pollStatus = iota + pollPending + pollSlowDown +) + +// exchangeDeviceCode attempts a single token exchange for a device code. +func exchangeDeviceCode(ctx context.Context, tokenURL, clientID, deviceCode string) (*oauth2.Token, pollStatus, error) { + data := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "client_id": {clientID}, + "device_code": {deviceCode}, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, pollDone, fmt.Errorf("building token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := httputil.NewHTTPClient(10 * time.Second).Do(req) + if err != nil { + return nil, pollDone, fmt.Errorf("token exchange HTTP request: %w", err) + } + defer resp.Body.Close() + + body := json.NewDecoder(resp.Body) + + if resp.StatusCode != http.StatusOK { + var tokenErr deviceTokenError + _ = body.Decode(&tokenErr) + switch tokenErr.Error { + case "authorization_pending": + return nil, pollPending, nil + case "slow_down": + return nil, pollSlowDown, nil + case "expired_token": + return nil, pollDone, fmt.Errorf("device code expired — run 'glean auth login' to try again") + case "access_denied": + return nil, pollDone, fmt.Errorf("authorization denied by user") + default: + desc := tokenErr.ErrorDescription + if desc == "" { + desc = tokenErr.Error + } + return nil, pollDone, fmt.Errorf("token request failed: %s", desc) + } + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + } + if err := body.Decode(&tokenResp); err != nil { + return nil, pollDone, fmt.Errorf("parsing token response: %w", err) + } + if tokenResp.AccessToken == "" { + return nil, pollDone, fmt.Errorf("token response missing access_token") + } + + token := &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + RefreshToken: tokenResp.RefreshToken, + } + if tokenResp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + return token, pollDone, nil +} diff --git a/internal/auth/device_test.go b/internal/auth/device_test.go new file mode 100644 index 0000000..13974de --- /dev/null +++ b/internal/auth/device_test.go @@ -0,0 +1,294 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequestDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + _ = r.ParseForm() + assert.Equal(t, "client-id", r.FormValue("client_id")) + assert.Contains(t, r.FormValue("scope"), "openid") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dev-code", + "user_code": "USER-1", + "verification_uri": "https://idp.example/verify", + }) + })) + defer srv.Close() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "client-id", []string{"openid", "profile"}) + require.NoError(t, err) + assert.Equal(t, "dev-code", resp.DeviceCode) + assert.Equal(t, "USER-1", resp.UserCode) + assert.Equal(t, "https://idp.example/verify", resp.VerificationURI) +} + +func TestRequestDeviceCode_UnauthorizedClient(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "unauthorized_client", + "error_description": "client cannot use this grant", + }) + })) + defer srv.Close() + + _, err := requestDeviceCode(context.Background(), srv.URL, "my-client", []string{"openid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "client cannot use this grant") + assert.Contains(t, err.Error(), "device_code grant type") + assert.Contains(t, err.Error(), "my-client") +} + +func TestRequestDeviceCode_MissingDeviceCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "user_code": "U", + "verification_uri": "https://x", + }) + })) + defer srv.Close() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing device_code") +} + +func TestRequestDeviceCode_IntervalAndExpiresIn(t *testing.T) { + cases := []struct { + name string + rawInterval int + rawExpiresIn int + wantInterval int + wantExpiresIn int + }{ + {"interval_below_min_clamped_to_5", 1, 100, 5, 100}, + {"interval_above_max_clamped_to_60", 100, 100, 60, 100}, + {"interval_in_range_unchanged", 30, 100, 30, 100}, + {"expires_in_zero_defaults_to_900", 5, 0, 5, defaultExpiresIn}, + {"expires_in_capped_at_1800", 5, 99999, 5, maxExpiresIn}, + {"expires_in_in_range_unchanged", 5, 600, 5, 600}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + "verification_uri": "https://v", + "interval": tc.rawInterval, + "expires_in": tc.rawExpiresIn, + }) + })) + defer srv.Close() + + resp, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.NoError(t, err) + assert.Equal(t, tc.wantInterval, resp.Interval) + assert.Equal(t, tc.wantExpiresIn, resp.ExpiresIn) + }) + } +} + +func TestRequestDeviceCode_MissingVerificationURI(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "device_code": "dc", + }) + })) + defer srv.Close() + + _, err := requestDeviceCode(context.Background(), srv.URL, "cid", []string{"s"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing verification_uri") +} + +func TestExchangeDeviceCode_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + _ = r.ParseForm() + assert.Equal(t, "urn:ietf:params:oauth:grant-type:device_code", r.FormValue("grant_type")) + assert.Equal(t, "cid", r.FormValue("client_id")) + assert.Equal(t, "dev", r.FormValue("device_code")) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-123", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Equal(t, pollDone, status) + require.NotNil(t, tok) + assert.Equal(t, "tok-123", tok.AccessToken) + assert.Equal(t, "Bearer", tok.TokenType) +} + +func TestExchangeDeviceCode_AuthorizationPending(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + })) + defer srv.Close() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollPending, status) +} + +func TestExchangeDeviceCode_SlowDown(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + })) + defer srv.Close() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.NoError(t, err) + assert.Nil(t, tok) + assert.Equal(t, pollSlowDown, status) +} + +func TestExchangeDeviceCode_AccessDenied(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"}) + })) + defer srv.Close() + + tok, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization denied") + assert.Nil(t, tok) + assert.Equal(t, pollDone, status) +} + +func TestExchangeDeviceCode_EmptyAccessToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{ + "access_token": "", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + _, status, err := exchangeDeviceCode(context.Background(), srv.URL, "cid", "dev") + require.Error(t, err) + assert.Contains(t, err.Error(), "missing access_token") + assert.Equal(t, pollDone, status) +} + +func TestPollForToken_PendingThenSuccess(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + if n == 1 { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + return + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "final", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 60, + VerificationURI: "https://v", + } + + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "final", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) +} + +func TestPollForToken_SlowDownIncreasesInterval(t *testing.T) { + if testing.Short() { + t.Skip("timing-based test waits ~5s after slow_down") + } + var calls atomic.Int32 + var firstAt, secondAt atomic.Int64 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := calls.Add(1) + now := time.Now().UnixNano() + if n == 1 { + firstAt.Store(now) + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "slow_down"}) + return + } + secondAt.Store(now) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "ok", + "token_type": "Bearer", + }) + })) + defer srv.Close() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 0, + ExpiresIn: 120, + VerificationURI: "https://v", + } + + start := time.Now() + tok, err := pollForToken(context.Background(), srv.URL, "cid", auth) + elapsed := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, tok) + assert.Equal(t, "ok", tok.AccessToken) + assert.Equal(t, int32(2), calls.Load()) + + gap := time.Duration(secondAt.Load() - firstAt.Load()) + assert.GreaterOrEqual(t, gap, 4*time.Second, + "expected ~5s wait after slow_down increased interval; gap=%v", gap) + assert.GreaterOrEqual(t, elapsed, 4*time.Second) +} + +func TestPollForToken_ContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + auth := &deviceAuthResponse{ + DeviceCode: "dc", + Interval: 10, + ExpiresIn: 3600, + VerificationURI: "https://v", + } + + _, err := pollForToken(ctx, "http://unused.example/token", "cid", auth) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/internal/auth/discovery.go b/internal/auth/discovery.go index b512e17..e467410 100644 --- a/internal/auth/discovery.go +++ b/internal/auth/discovery.go @@ -22,8 +22,9 @@ func (e *ErrOAuthNotSupported) Error() string { } type protectedResourceMetadata struct { - Resource string `json:"resource"` - AuthorizationServers []string `json:"authorization_servers"` + Resource string `json:"resource"` + AuthorizationServers []string `json:"authorization_servers"` + GleanDeviceFlowClientID string `json:"glean_device_flow_client_id,omitempty"` } // fetchProtectedResource fetches RFC 9728 protected resource metadata. diff --git a/internal/auth/discovery_test.go b/internal/auth/discovery_test.go index fac1af6..41eee4a 100644 --- a/internal/auth/discovery_test.go +++ b/internal/auth/discovery_test.go @@ -27,6 +27,22 @@ func TestFetchProtectedResource_Success(t *testing.T) { assert.Equal(t, []string{"https://auth.example.com"}, result.AuthorizationServers) } +func TestFetchProtectedResource_DeviceFlowClientID(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/.well-known/oauth-protected-resource", r.URL.Path) + json.NewEncoder(w).Encode(map[string]any{ + "resource": "https://example.glean.com", + "authorization_servers": []string{"https://auth.example.com"}, + "glean_device_flow_client_id": "device-flow-client-123", + }) + })) + defer srv.Close() + + result, err := fetchProtectedResource(context.Background(), srv.URL) + require.NoError(t, err) + assert.Equal(t, "device-flow-client-123", result.GleanDeviceFlowClientID) +} + func TestFetchProtectedResource_NotFound(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) diff --git a/snippets/readme/snippet-04.sh b/snippets/readme/snippet-04.sh index 029400d..5de97c3 100644 --- a/snippets/readme/snippet-04.sh +++ b/snippets/readme/snippet-04.sh @@ -1,3 +1,3 @@ -glean auth login # opens browser, completes PKCE flow +glean auth login # interactive login (detects the best method automatically) glean auth status # verify credentials, host, and token expiry glean auth logout # remove all stored credentials