diff --git a/charts/Chart.yaml b/charts/Chart.yaml index ab46db97..50903267 100644 --- a/charts/Chart.yaml +++ b/charts/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: hyperfleet-api description: HyperFleet API - Cluster Lifecycle Management Service type: application -version: 1.0.0 +version: 1.1.0 appVersion: "0.0.0-dev" maintainers: - name: HyperFleet Team diff --git a/charts/templates/NOTES.txt b/charts/templates/NOTES.txt index e7747aa4..185568c7 100644 --- a/charts/templates/NOTES.txt +++ b/charts/templates/NOTES.txt @@ -17,5 +17,25 @@ Validation schema validation is ENABLED. The API will fail to start if the schema is missing or invalid. {{- end }} +Caller identity (audit attribution): + config.server.jwt.identity_claim: {{ .Values.config.server.jwt.identity_claim | default "" | quote }} + config.server.identity_header: {{ .Values.config.server.identity_header | default "" | quote }} + +When identity_header is set and a request includes a non-empty header with that name, +the header value overrides the JWT claim for audit fields (created_by, updated_by). +When identity_claim is set, the named JWT claim is used as the caller identity. +Only trusted gateways should set the identity header in production. + +Override in values.yaml: + config: + server: + jwt: + identity_claim: preferred_username + identity_header: X-HyperFleet-Identity + +Or at install/upgrade: + --set config.server.jwt.identity_claim=preferred_username + --set config.server.identity_header=X-HyperFleet-Identity + Documentation: https://github.com/openshift-hyperfleet/hyperfleet-api/blob/main/docs/deployment.md diff --git a/charts/templates/configmap.yaml b/charts/templates/configmap.yaml index 7b37626a..7b18ce13 100644 --- a/charts/templates/configmap.yaml +++ b/charts/templates/configmap.yaml @@ -30,6 +30,9 @@ data: enabled: {{ .Values.config.server.jwt.enabled }} issuer_url: {{ .Values.config.server.jwt.issuer_url | quote }} audience: {{ .Values.config.server.jwt.audience | quote }} + identity_claim: {{ .Values.config.server.jwt.identity_claim | quote }} + + identity_header: {{ .Values.config.server.identity_header | quote }} jwk: cert_file: {{ .Values.config.server.jwk.cert_file | quote }} diff --git a/charts/values.yaml b/charts/values.yaml index c5501547..755a4862 100644 --- a/charts/values.yaml +++ b/charts/values.yaml @@ -52,6 +52,9 @@ config: enabled: false issuer_url: "" audience: "" + identity_claim: email + + identity_header: "" jwk: cert_file: "" @@ -104,6 +107,7 @@ config: - Cookie - X-Auth-Token - X-Forwarded-Authorization + - X-HyperFleet-Identity fields: - password - secret diff --git a/cmd/hyperfleet-api/environments/types.go b/cmd/hyperfleet-api/environments/types.go index cf267930..35e71396 100755 --- a/cmd/hyperfleet-api/environments/types.go +++ b/cmd/hyperfleet-api/environments/types.go @@ -3,7 +3,6 @@ package environments import ( "sync" - "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/config" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" ) @@ -37,9 +36,7 @@ type Database struct { SessionFactory db.SessionFactory } -type Handlers struct { - AuthMiddleware auth.JWTMiddleware -} +type Handlers struct{} type Services struct { serviceRegistry map[string]interface{} diff --git a/cmd/hyperfleet-api/server/routes.go b/cmd/hyperfleet-api/server/routes.go index 4a35bbdd..7f961e5a 100755 --- a/cmd/hyperfleet-api/server/routes.go +++ b/cmd/hyperfleet-api/server/routes.go @@ -25,7 +25,6 @@ type ServicesInterface interface { type RouteRegistrationFunc func( apiV1Router *mux.Router, services ServicesInterface, - authMiddleware auth.JWTMiddleware, ) var routeRegistry = make(map[string]RouteRegistrationFunc) @@ -40,10 +39,9 @@ func RegisterRoutes(name string, registrationFunc RouteRegistrationFunc) { func LoadDiscoveredRoutes( apiV1Router *mux.Router, services ServicesInterface, - authMiddleware auth.JWTMiddleware, ) { for name, registrationFunc := range routeRegistry { - registrationFunc(apiV1Router, services, authMiddleware) + registrationFunc(apiV1Router, services) _ = name // prevent unused variable warning } } @@ -53,17 +51,6 @@ func (s *apiServer) routes(tracingEnabled bool) *mux.Router { metadataHandler := handlers.NewMetadataHandler() - var authMiddleware auth.JWTMiddleware - authMiddleware = &auth.MiddlewareMock{} - if env().Config.Server.JWT.Enabled { - var err error - authMiddleware, err = auth.NewAuthMiddleware() - check(err, "Unable to create auth middleware") - } - if authMiddleware == nil { - check(fmt.Errorf("auth middleware is nil"), "Unable to create auth middleware: missing middleware") - } - // mainRouter is top level "/" mainRouter := mux.NewRouter() mainRouter.NotFoundHandler = http.HandlerFunc(api.SendNotFound) @@ -99,8 +86,20 @@ func (s *apiServer) routes(tracingEnabled bool) *mux.Router { err = registerAPIMiddleware(apiV1Router) check(err, "Failed to initialize API middleware") + identityCfg := auth.CallerIdentityConfig{ + HeaderName: env().Config.Server.IdentityHeader, + } + if env().Config.Server.JWT.Enabled && env().Config.Server.JWT.IdentityClaim != "" { + identityCfg.JWTIdentityClaim = env().Config.Server.JWT.IdentityClaim + } + if identityCfg.JWTIdentityClaim != "" || identityCfg.HeaderName != "" { + callerIdentityMW, mwErr := auth.NewCallerIdentityMiddleware(identityCfg) + check(mwErr, "Unable to create caller identity middleware") + apiV1Router.Use(callerIdentityMW.ResolveCallerIdentity) + } + // Auto-discovered routes (no manual editing needed) - LoadDiscoveredRoutes(apiV1Router, services, authMiddleware) + LoadDiscoveredRoutes(apiV1Router, services) return mainRouter } diff --git a/configs/config.yaml.example b/configs/config.yaml.example index b8247d17..a1266703 100644 --- a/configs/config.yaml.example +++ b/configs/config.yaml.example @@ -21,6 +21,9 @@ server: enabled: true # Enable JWT authentication issuer_url: "" # JWT issuer URL (required when jwt.enabled=true) audience: "" # JWT audience claim (optional) + identity_claim: email # JWT claim used as request identity for audit (e.g. email, preferred_username, sub) + + identity_header: "" # HTTP header name for caller identity; leave empty to disable (e.g. X-HyperFleet-Identity) jwk: cert_file: "" # JWK certificate file path diff --git a/docs/authentication.md b/docs/authentication.md index a39b6fb1..6fbc50d5 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -4,10 +4,11 @@ This document describes authentication mechanisms for the HyperFleet API. ## Overview -HyperFleet API supports two authentication modes: +HyperFleet API supports the following authentication modes: -1. **Development Mode (No Auth)**: For local development and testing -2. **Production Mode (JWT Auth)**: JWT-based authentication with configurable issuer +1. **Development Mode (No Auth)**: For local development and testing without authentication +2. **Development with JWT (Google Cloud)**: Local development with real JWT validation using Google identity tokens +3. **Production Mode (JWT Auth)**: JWT-based authentication with configurable issuer ## Development Mode (No Auth) @@ -30,8 +31,99 @@ export HYPERFLEET_SERVER_JWT_ENABLED=false ./bin/hyperfleet-api serve ``` +### Caller identity in development mode + +When JWT is disabled and no `identity_header` is configured, caller identity resolution is inactive. Audit fields (`created_by`, `updated_by`, `deleted_by`) fall back to `system@hyperfleet.local`. + +To get proper caller attribution without JWT, configure an identity header: + +```bash +./bin/hyperfleet-api serve \ + --server-jwt-enabled=false \ + --server-identity-header=X-HyperFleet-Identity +``` + +Then pass the header in requests: + +```bash +# Create with attribution +curl -X POST http://localhost:8000/api/hyperfleet/v1/clusters \ + -H "Content-Type: application/json" \ + -H "X-HyperFleet-Identity: dev-user@local" \ + -d '{"kind":"Cluster","name":"my-cluster","spec":{}}' + +# Read requests work without the header +curl http://localhost:8000/api/hyperfleet/v1/clusters | jq +``` + +When `identity_header` or `identity_claim` is configured, mutating requests (POST, PATCH, PUT, DELETE) that cannot resolve a caller identity are rejected with `401 Unauthorized`. Read requests (GET, LIST) are allowed without identity. + **Important**: Never disable authentication in production environments. +## Development with JWT (Google Cloud) + +For local development with real JWT validation, you can use Google Cloud identity tokens. This gives you proper authentication and caller identity without deploying a dedicated identity provider. + +### Prerequisites + +- [Google Cloud SDK](https://cloud.google.com/sdk/docs/install) installed +- Authenticated with `gcloud auth login` + +### Start the server + +```bash +./bin/hyperfleet-api serve \ + --server-jwt-enabled=true \ + --server-jwt-issuer-url="https://accounts.google.com" \ + --server-jwk-cert-url="https://www.googleapis.com/oauth2/v3/certs" \ + --server-jwt-audience="32555940559.apps.googleusercontent.com" \ + --server-jwt-identity-claim="email" \ + --server-identity-header=X-HyperFleet-Identity \ + --db-host localhost --db-port 5432 --db-name hyperfleet --db-username hyperfleet +``` + +The audience `32555940559.apps.googleusercontent.com` is the default gcloud CLI OAuth client ID. It matches the `aud` claim in tokens generated by `gcloud auth print-identity-token`. + +### Generate a token and make requests + +```bash +# Generate an identity token (valid for ~1 hour) +TOKEN=$(gcloud auth print-identity-token) + +# List clusters +curl -H "Authorization: Bearer $TOKEN" \ + http://localhost:8000/api/hyperfleet/v1/clusters | jq + +# Create a cluster (created_by will be your Google email) +curl -X POST http://localhost:8000/api/hyperfleet/v1/clusters \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -d '{"kind":"Cluster","name":"my-cluster","spec":{}}' + +# Override identity via header (header takes precedence over JWT) +curl -X POST http://localhost:8000/api/hyperfleet/v1/clusters \ + -H "Authorization: Bearer $TOKEN" \ + -H "Content-Type: application/json" \ + -H "X-HyperFleet-Identity: gateway-user@corp.com" \ + -d '{"kind":"Cluster","name":"my-cluster-2","spec":{}}' +``` + +### How it works + +Google identity tokens are standard OIDC JWTs signed by Google's keys. The server validates them like any other JWT: + +1. Fetches Google's public keys from the `jwk_cert_url` +2. Verifies the RS256 signature +3. Checks `iss` matches `https://accounts.google.com` +4. Checks `aud` matches the configured audience +5. Extracts the `email` claim for caller identity + +You can inspect your token with: + +```bash +gcloud auth print-identity-token | cut -d. -f2 | base64 -d 2>/dev/null | jq +``` + ## Production Mode (JWT Auth) Production deployments use JWT-based authentication with a configurable issuer. @@ -69,6 +161,46 @@ curl -H "Authorization: Bearer ${TOKEN}" \ http://localhost:8000/api/hyperfleet/v1/clusters ``` +## Caller identity for audit + +Authentication (JWT validation) and caller identity (audit attribution) are separate concerns. Identity resolution is enabled by setting `identity_claim` (in the JWT config) and/or `identity_header`. When neither is set, no identity middleware is registered and audit fields fall back to `system@hyperfleet.local`. + +| Layer | Component | Responsibility | +|-------|-----------|----------------| +| Outer | `JWTHandler` | Validates `Authorization: Bearer` token | +| Inner | `ResolveCallerIdentity` middleware | Resolves who is recorded as the actor | + +The resolved identity is written to `created_by` on create, `updated_by` on update, and `deleted_by` on delete. Precedence: identity header > JWT claim. + +When identity resolution is configured, mutating requests (POST, PATCH, PUT, DELETE) that cannot resolve a caller identity are rejected with `401 Unauthorized`. Read requests (GET, LIST) are allowed without identity. + +### JWT claim + +Configure which JWT claim is used as the caller identity: + +```yaml +server: + jwt: + identity_claim: email # or preferred_username, sub, etc. +``` + +### HTTP identity header (optional) + +When set, a trusted gateway can set the caller identity via HTTP header. **If the header is present and non-empty, it overrides the JWT claim** for audit fields. JWT validation is still required when `jwt.enabled=true`. + +```yaml +server: + identity_header: X-HyperFleet-Identity +``` + +**Security:** Clients must not be able to set this header directly. Configure your ingress/gateway to strip the header from external requests and set it from the authenticated upstream user. + +```bash +export HYPERFLEET_SERVER_IDENTITY_HEADER=X-HyperFleet-Identity +``` + +Identity values from both sources are validated: trimmed of whitespace, limited to 256 characters, and rejected if they contain control characters. + ## Configuration ### Environment Variables diff --git a/docs/config.md b/docs/config.md index aff0003a..88b5416d 100644 --- a/docs/config.md +++ b/docs/config.md @@ -257,6 +257,8 @@ HTTP server settings for the API endpoint. | `server.jwt.enabled` | bool | `true` | Enable JWT authentication | | `server.jwt.issuer_url` | string | `""` | Expected JWT issuer URL for token validation (required when JWT is enabled) | | `server.jwt.audience` | string | `""` | Expected JWT audience claim (optional) | +| `server.jwt.identity_claim` | string | `email` | JWT claim used as request identity for audit (e.g. `email`, `preferred_username`, `sub`) | +| `server.identity_header` | string | `""` | HTTP header name for caller identity; when set and non-empty, overrides JWT claim for audit attribution | | `server.jwk.cert_file` | string | `""` | JWK certificate file path (optional) | | `server.jwk.cert_url` | string | `""` | JWK certificate URL (required when JWT is enabled and cert_file is not set) | @@ -351,6 +353,8 @@ Complete table of all configuration properties, their environment variables, and | `server.jwt.enabled` | `HYPERFLEET_SERVER_JWT_ENABLED` | bool | `true` | | `server.jwt.issuer_url` | `HYPERFLEET_SERVER_JWT_ISSUER_URL` | string | `""` | | `server.jwt.audience` | `HYPERFLEET_SERVER_JWT_AUDIENCE` | string | `""` | +| `server.jwt.identity_claim` | `HYPERFLEET_SERVER_JWT_IDENTITY_CLAIM` | string | `email` | +| `server.identity_header` | `HYPERFLEET_SERVER_IDENTITY_HEADER` | string | `""` | | `server.jwk.cert_file` | `HYPERFLEET_SERVER_JWK_CERT_FILE` | string | `""` | | `server.jwk.cert_url` | `HYPERFLEET_SERVER_JWK_CERT_URL` | string | `""` | | **Database** | | | | @@ -413,6 +417,8 @@ All CLI flags and their corresponding configuration paths. | `--server-jwt-enabled` | `server.jwt.enabled` | bool | | `--server-jwt-issuer-url` | `server.jwt.issuer_url` | string | | `--server-jwt-audience` | `server.jwt.audience` | string | +| `--server-jwt-identity-claim` | `server.jwt.identity_claim` | string | +| `--server-identity-header` | `server.identity_header` | string | | `--server-jwk-cert-file` | `server.jwk.cert_file` | string | | `--server-jwk-cert-url` | `server.jwk.cert_url` | string | | **Database** | | | diff --git a/pkg/auth/auth_middleware.go b/pkg/auth/auth_middleware.go index 47302300..0e2d2338 100755 --- a/pkg/auth/auth_middleware.go +++ b/pkg/auth/auth_middleware.go @@ -5,37 +5,56 @@ import ( "net/http" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/errors" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/validation" ) -type JWTMiddleware interface { - AuthenticateAccountJWT(next http.Handler) http.Handler +// CallerIdentityMiddleware resolves and attaches the caller identity used for audit fields. +type CallerIdentityMiddleware interface { + ResolveCallerIdentity(next http.Handler) http.Handler } -type Middleware struct{} +type callerIdentityMiddleware struct { + cfg CallerIdentityConfig +} -var _ JWTMiddleware = &Middleware{} +var _ CallerIdentityMiddleware = &callerIdentityMiddleware{} -func NewAuthMiddleware() (*Middleware, error) { - middleware := Middleware{} - return &middleware, nil +func NewCallerIdentityMiddleware(cfg CallerIdentityConfig) (CallerIdentityMiddleware, error) { + if cfg.HeaderName != "" { + if validation.IsForbiddenIdentityHeaderName(cfg.HeaderName) { + return nil, fmt.Errorf("identity header name %q is not allowed", cfg.HeaderName) + } + } + return &callerIdentityMiddleware{cfg: cfg}, nil } -// AuthenticateAccountJWT Middleware handler to validate JWT tokens and authenticate users -func (a *Middleware) AuthenticateAccountJWT(next http.Handler) http.Handler { +// ResolveCallerIdentity attaches the resolved caller identity to the request context. +// JWT validation is performed by JWTHandler; this middleware only resolves attribution. +func (m *callerIdentityMiddleware) ResolveCallerIdentity(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if shouldSkipCallerIdentity(r.URL.Path) { + next.ServeHTTP(w, r) + return + } + ctx := r.Context() - payload, err := GetAuthPayload(r) - if err != nil { - handleError( - ctx, w, r, errors.CodeAuthNoCredentials, - fmt.Sprintf("Unable to get payload details from JWT token: %s", err), - ) + identity, err := CallerIdentityFromRequest(ctx, r, m.cfg) + + if identity != "" { + ctx = SetUsernameContext(ctx, identity) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) return } - // Append the username to the request context - ctx = SetUsernameContext(ctx, payload.Username) - *r = *r.WithContext(ctx) + if isMutatingMethod(r.Method) { + msg := "Caller identity is required for mutating requests but could not be resolved" + if err != nil { + msg = fmt.Sprintf("Unable to resolve caller identity: %s", err) + } + handleError(ctx, w, r, errors.CodeAuthNoCredentials, msg) + return + } next.ServeHTTP(w, r) }) diff --git a/pkg/auth/auth_middleware_mock.go b/pkg/auth/auth_middleware_mock.go deleted file mode 100755 index 5deef8e6..00000000 --- a/pkg/auth/auth_middleware_mock.go +++ /dev/null @@ -1,16 +0,0 @@ -package auth - -import ( - "net/http" -) - -type MiddlewareMock struct{} - -var _ JWTMiddleware = &MiddlewareMock{} - -func (a *MiddlewareMock) AuthenticateAccountJWT(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO need to append a username to the request context - next.ServeHTTP(w, r) - }) -} diff --git a/pkg/auth/context.go b/pkg/auth/context.go index 2fbcf409..fc466174 100755 --- a/pkg/auth/context.go +++ b/pkg/auth/context.go @@ -14,6 +14,9 @@ type contextKey string const ( ContextUsernameKey contextKey = "username" ContextJWTTokenKey contextKey = "jwt_token" + + // DefaultJWTIdentityClaim is used when server.jwt.identity_claim is unset. + DefaultJWTIdentityClaim = "email" ) // Payload defines the structure of the JWT payload we expect @@ -117,6 +120,70 @@ func GetAuthPayloadFromContext(ctx context.Context) (*Payload, error) { return payload, nil } +// GetIdentityFromContext returns the configured JWT claim value used as the request identity. +func GetIdentityFromContext(ctx context.Context, identityClaim string) (string, error) { + if identityClaim == "" { + identityClaim = DefaultJWTIdentityClaim + } + + userToken := GetJWTTokenFromContext(ctx) + if userToken == nil { + return "", fmt.Errorf("JWT token in context is nil, unauthorized") + } + + claims, ok := userToken.Claims.(jwt.MapClaims) + if !ok { + return "", fmt.Errorf("unable to parse JWT token claims: unexpected type %T", userToken.Claims) + } + + if identity, ok := stringClaim(claims, identityClaim); ok { + return identity, nil + } + + payload, err := GetAuthPayloadFromContext(ctx) + if err != nil { + return "", fmt.Errorf("identity claim %q not found: %w", identityClaim, err) + } + + switch identityClaim { + case "email": + if payload.Email != "" { + return payload.Email, nil + } + case "username", "preferred_username": + if payload.Username != "" { + return payload.Username, nil + } + case "first_name", "given_name": + if payload.FirstName != "" { + return payload.FirstName, nil + } + case "last_name", "family_name": + if payload.LastName != "" { + return payload.LastName, nil + } + case "clientId": + if payload.ClientID != "" { + return payload.ClientID, nil + } + case "iss": + if payload.Issuer != "" { + return payload.Issuer, nil + } + } + + return "", fmt.Errorf("identity claim %q not found or empty", identityClaim) +} + +func stringClaim(claims jwt.MapClaims, key string) (string, bool) { + val, ok := claims[key] + if !ok { + return "", false + } + s, ok := val.(string) + return s, ok && s != "" +} + func GetAuthPayload(r *http.Request) (*Payload, error) { return GetAuthPayloadFromContext(r.Context()) } diff --git a/pkg/auth/context_test.go b/pkg/auth/context_test.go new file mode 100644 index 00000000..0ea9a0da --- /dev/null +++ b/pkg/auth/context_test.go @@ -0,0 +1,72 @@ +package auth + +import ( + "context" + "testing" + + "github.com/golang-jwt/jwt/v5" + . "github.com/onsi/gomega" +) + +func TestGetIdentityFromContext(t *testing.T) { + tests := []struct { + claims jwt.MapClaims + name string + identityField string + want string + errSubstring string + wantErr bool + }{ + { + name: "reads configured claim directly", + claims: jwt.MapClaims{ + "email": "user@example.com", + "sub": "subject-id", + }, + identityField: "sub", + want: "subject-id", + }, + { + name: "defaults to email when field is empty", + claims: jwt.MapClaims{"email": "user@example.com"}, + identityField: "", + want: "user@example.com", + }, + { + name: "falls back to preferred_username via payload normalization", + claims: jwt.MapClaims{"preferred_username": "jdoe"}, + identityField: "preferred_username", + want: "jdoe", + }, + { + name: "returns error when claim is missing", + claims: jwt.MapClaims{"email": "user@example.com"}, + identityField: "missing_claim", + wantErr: true, + errSubstring: "missing_claim", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + RegisterTestingT(t) + ctx := contextWithClaims(tc.claims) + + identity, err := GetIdentityFromContext(ctx, tc.identityField) + if tc.wantErr { + Expect(err).To(HaveOccurred()) + if tc.errSubstring != "" { + Expect(err.Error()).To(ContainSubstring(tc.errSubstring)) + } + return + } + Expect(err).NotTo(HaveOccurred()) + Expect(identity).To(Equal(tc.want)) + }) + } +} + +func contextWithClaims(claims jwt.MapClaims) context.Context { + token := &jwt.Token{Claims: claims} + return SetJWTTokenContext(context.Background(), token) +} diff --git a/pkg/auth/identity.go b/pkg/auth/identity.go new file mode 100644 index 00000000..d5a17140 --- /dev/null +++ b/pkg/auth/identity.go @@ -0,0 +1,77 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +const maxCallerIdentityLen = 256 + +// CallerIdentityConfig controls how the caller identity is resolved for audit fields. +// Identity resolution is enabled by setting the relevant fields: +// - HeaderName: when non-empty, the named HTTP header is checked first +// - JWTIdentityClaim: when non-empty, the JWT claim is used as fallback (or primary when no header is configured) +type CallerIdentityConfig struct { + JWTIdentityClaim string + HeaderName string +} + +// CallerIdentityFromRequest resolves the caller identity with header-primary precedence. +// When the identity header is configured and present, it overrides the JWT claim. +// Both header and JWT identity values are normalized: trimmed, length-checked, and +// validated for control characters before being accepted. +func CallerIdentityFromRequest(ctx context.Context, r *http.Request, cfg CallerIdentityConfig) (string, error) { + if cfg.HeaderName != "" { + raw := r.Header.Get(cfg.HeaderName) + if raw != "" { + identity, err := normalizeIdentity(raw, "identity header") + if err != nil { + return "", err + } + if identity != "" { + return identity, nil + } + } + } + + if cfg.JWTIdentityClaim != "" { + raw, err := GetIdentityFromContext(ctx, cfg.JWTIdentityClaim) + if err != nil { + return "", err + } + return normalizeIdentity(raw, fmt.Sprintf("JWT claim %q", cfg.JWTIdentityClaim)) + } + + return "", nil +} + +// normalizeIdentity trims, length-checks, and validates a caller identity value +// regardless of source (HTTP header or JWT claim). The source parameter is used +// in error messages to distinguish the origin. +func normalizeIdentity(raw string, source string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", nil + } + if len(value) > maxCallerIdentityLen { + return "", fmt.Errorf("%s value exceeds maximum length %d", source, maxCallerIdentityLen) + } + for _, r := range value { + if r < 0x20 || r == 0x7f { + return "", fmt.Errorf("%s value contains invalid characters", source) + } + } + return value, nil +} + +func shouldSkipCallerIdentity(path string) bool { + return strings.HasPrefix(path, "/api/hyperfleet/v1/openapi") || + strings.HasPrefix(path, "/api/hyperfleet/v1/errors") +} + +func isMutatingMethod(method string) bool { + return method == http.MethodPost || method == http.MethodPatch || method == http.MethodDelete || + method == http.MethodPut +} diff --git a/pkg/auth/identity_test.go b/pkg/auth/identity_test.go new file mode 100644 index 00000000..9d6c7a7c --- /dev/null +++ b/pkg/auth/identity_test.go @@ -0,0 +1,379 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/golang-jwt/jwt/v5" + . "github.com/onsi/gomega" +) + +func TestCallerIdentityFromRequest(t *testing.T) { + tests := []struct { + claims jwt.MapClaims + name string + headerValue string + want string + cfg CallerIdentityConfig + setHeader bool + wantErr bool + }{ + { + name: "header overrides JWT claim", + claims: jwt.MapClaims{"email": "jwt@example.com"}, + setHeader: true, + headerValue: "gateway-user@example.com", + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }, + want: "gateway-user@example.com", + }, + { + name: "falls back to JWT when header absent", + claims: jwt.MapClaims{"email": "jwt@example.com"}, + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }, + want: "jwt@example.com", + }, + { + name: "rejects invalid header value", + setHeader: true, + headerValue: "bad\x00value", + cfg: CallerIdentityConfig{ + HeaderName: "X-HyperFleet-Identity", + }, + wantErr: true, + }, + { + name: "rejects header value exceeding max length", + setHeader: true, + headerValue: strings.Repeat("a", maxCallerIdentityLen+1), + cfg: CallerIdentityConfig{ + HeaderName: "X-HyperFleet-Identity", + }, + wantErr: true, + }, + { + name: "header only without JWT claim", + setHeader: true, + headerValue: "dev-user", + cfg: CallerIdentityConfig{ + HeaderName: "X-HyperFleet-Identity", + }, + want: "dev-user", + }, + { + name: "no resolution when nothing configured", + claims: jwt.MapClaims{"email": "jwt@example.com"}, + cfg: CallerIdentityConfig{}, + want: "", + }, + { + name: "empty header falls back to JWT", + claims: jwt.MapClaims{"email": "jwt@example.com"}, + setHeader: true, + headerValue: "", + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }, + want: "jwt@example.com", + }, + { + name: "whitespace-only header falls back to JWT", + claims: jwt.MapClaims{"email": "jwt@example.com"}, + setHeader: true, + headerValue: " ", + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }, + want: "jwt@example.com", + }, + { + name: "header at exact max length accepted", + setHeader: true, + headerValue: strings.Repeat("a", maxCallerIdentityLen), + cfg: CallerIdentityConfig{ + HeaderName: "X-HyperFleet-Identity", + }, + want: strings.Repeat("a", maxCallerIdentityLen), + }, + { + name: "rejects oversized JWT claim value", + claims: jwt.MapClaims{"email": strings.Repeat("x", maxCallerIdentityLen+1)}, + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + }, + wantErr: true, + }, + { + name: "trims whitespace from JWT claim value", + claims: jwt.MapClaims{"email": " user@example.com "}, + cfg: CallerIdentityConfig{ + JWTIdentityClaim: "email", + }, + want: "user@example.com", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + RegisterTestingT(t) + r := httptest.NewRequest(http.MethodGet, "/api/hyperfleet/v1/clusters", nil) + if tc.claims != nil { + r = r.WithContext(contextWithClaims(tc.claims)) + } + if tc.setHeader { + headerName := tc.cfg.HeaderName + if headerName == "" { + headerName = "X-HyperFleet-Identity" + } + r.Header.Set(headerName, tc.headerValue) + } + + identity, err := CallerIdentityFromRequest(r.Context(), r, tc.cfg) + if tc.wantErr { + Expect(err).To(HaveOccurred()) + return + } + Expect(err).NotTo(HaveOccurred()) + Expect(identity).To(Equal(tc.want)) + }) + } +} + +func TestNewCallerIdentityMiddleware(t *testing.T) { + RegisterTestingT(t) + + t.Run("rejects forbidden header name", func(t *testing.T) { + RegisterTestingT(t) + _, err := NewCallerIdentityMiddleware(CallerIdentityConfig{ + HeaderName: "Authorization", + }) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not allowed")) + }) + + t.Run("returns middleware when header validation passes", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{ + HeaderName: "X-HyperFleet-Identity", + }) + Expect(err).NotTo(HaveOccurred()) + Expect(mw).NotTo(BeNil()) + }) + + t.Run("returns middleware with no config", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(mw).NotTo(BeNil()) + }) +} + +func TestResolveCallerIdentityMiddleware(t *testing.T) { + RegisterTestingT(t) + + t.Run("skips openapi paths", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{JWTIdentityClaim: "email"}) + Expect(err).NotTo(HaveOccurred()) + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + Expect(GetUsernameFromContext(r.Context())).To(BeEmpty()) + }) + + r := httptest.NewRequest(http.MethodGet, "/api/hyperfleet/v1/openapi", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(called).To(BeTrue()) + Expect(w.Code).To(Equal(http.StatusOK)) + }) + + t.Run("allows GET without identity", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{JWTIdentityClaim: "email"}) + Expect(err).NotTo(HaveOccurred()) + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + Expect(GetUsernameFromContext(r.Context())).To(BeEmpty()) + }) + + r := httptest.NewRequest(http.MethodGet, "/api/hyperfleet/v1/clusters", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(called).To(BeTrue()) + Expect(w.Code).To(Equal(http.StatusOK)) + }) + + t.Run("returns 401 on POST without identity", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{JWTIdentityClaim: "email"}) + Expect(err).NotTo(HaveOccurred()) + + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + }) + + r := httptest.NewRequest(http.MethodPost, "/api/hyperfleet/v1/clusters", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + Expect(nextCalled).To(BeFalse()) + }) + + t.Run("returns 401 on PATCH without identity", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{HeaderName: "X-HyperFleet-Identity"}) + Expect(err).NotTo(HaveOccurred()) + + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + }) + + r := httptest.NewRequest(http.MethodPatch, "/api/hyperfleet/v1/clusters/123", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + Expect(nextCalled).To(BeFalse()) + }) + + t.Run("returns 401 on DELETE without identity", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{HeaderName: "X-HyperFleet-Identity"}) + Expect(err).NotTo(HaveOccurred()) + + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + }) + + r := httptest.NewRequest(http.MethodDelete, "/api/hyperfleet/v1/clusters/123", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + Expect(nextCalled).To(BeFalse()) + }) + + t.Run("allows POST with identity from header", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{HeaderName: "X-HyperFleet-Identity"}) + Expect(err).NotTo(HaveOccurred()) + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + Expect(GetUsernameFromContext(r.Context())).To(Equal("user@example.com")) + }) + + r := httptest.NewRequest(http.MethodPost, "/api/hyperfleet/v1/clusters", nil) + r.Header.Set("X-HyperFleet-Identity", "user@example.com") + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(called).To(BeTrue()) + Expect(w.Code).To(Equal(http.StatusOK)) + }) + + t.Run("returns 401 on PUT without identity", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{HeaderName: "X-HyperFleet-Identity"}) + Expect(err).NotTo(HaveOccurred()) + + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + }) + + r := httptest.NewRequest(http.MethodPut, "/api/hyperfleet/v1/clusters/123", nil) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + Expect(nextCalled).To(BeFalse()) + }) + + t.Run("returns 401 on POST with oversized header", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{HeaderName: "X-HyperFleet-Identity"}) + Expect(err).NotTo(HaveOccurred()) + + nextCalled := false + next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + nextCalled = true + }) + + r := httptest.NewRequest(http.MethodPost, "/api/hyperfleet/v1/clusters", nil) + r.Header.Set("X-HyperFleet-Identity", strings.Repeat("a", maxCallerIdentityLen+1)) + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(w.Code).To(Equal(http.StatusUnauthorized)) + Expect(nextCalled).To(BeFalse()) + }) + + t.Run("POST with empty header falls back to JWT", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }) + Expect(err).NotTo(HaveOccurred()) + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + Expect(GetUsernameFromContext(r.Context())).To(Equal("jwt@example.com")) + }) + + r := httptest.NewRequest(http.MethodPost, "/api/hyperfleet/v1/clusters", nil) + r = r.WithContext(contextWithClaims(jwt.MapClaims{"email": "jwt@example.com"})) + r.Header.Set("X-HyperFleet-Identity", "") + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(called).To(BeTrue()) + Expect(w.Code).To(Equal(http.StatusOK)) + }) + + t.Run("header identity takes precedence over JWT on POST", func(t *testing.T) { + RegisterTestingT(t) + mw, err := NewCallerIdentityMiddleware(CallerIdentityConfig{ + JWTIdentityClaim: "email", + HeaderName: "X-HyperFleet-Identity", + }) + Expect(err).NotTo(HaveOccurred()) + + called := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + Expect(GetUsernameFromContext(r.Context())).To(Equal("header@gateway.com")) + }) + + r := httptest.NewRequest(http.MethodPost, "/api/hyperfleet/v1/clusters", nil) + r = r.WithContext(contextWithClaims(jwt.MapClaims{"email": "jwt@example.com"})) + r.Header.Set("X-HyperFleet-Identity", "header@gateway.com") + w := httptest.NewRecorder() + mw.ResolveCallerIdentity(next).ServeHTTP(w, r) + + Expect(called).To(BeTrue()) + Expect(w.Code).To(Equal(http.StatusOK)) + }) +} diff --git a/pkg/config/flags.go b/pkg/config/flags.go index 8c30c9a9..0bb4cef1 100644 --- a/pkg/config/flags.go +++ b/pkg/config/flags.go @@ -31,6 +31,16 @@ func AddServerFlags(cmd *cobra.Command) { cmd.Flags().Bool("server-jwt-enabled", defaults.JWT.Enabled, "Enable JWT authentication") cmd.Flags().String("server-jwt-issuer-url", defaults.JWT.IssuerURL, "Expected JWT issuer URL for token validation") cmd.Flags().String("server-jwt-audience", defaults.JWT.Audience, "Expected JWT audience (optional)") + cmd.Flags().String( + "server-jwt-identity-claim", + defaults.JWT.IdentityClaim, + "JWT claim used as request identity for audit", + ) + cmd.Flags().String( + "server-identity-header", + defaults.IdentityHeader, + "HTTP header name for caller identity (overrides JWT claim when set); leave empty to disable", + ) cmd.Flags().String("server-jwk-cert-file", defaults.JWK.CertFile, "JWK certificate file path") cmd.Flags().String("server-jwk-cert-url", defaults.JWK.CertURL, "JWK certificate URL") } diff --git a/pkg/config/loader.go b/pkg/config/loader.go index 2b1daca3..0d173822 100644 --- a/pkg/config/loader.go +++ b/pkg/config/loader.go @@ -179,6 +179,9 @@ func (l *ConfigLoader) validateConfig(config *ApplicationConfig) error { if valErr := config.Server.JWT.Validate(); valErr != nil { return fmt.Errorf("server JWT validation failed: %w", valErr) } + if valErr := config.Server.ValidateIdentityHeader(); valErr != nil { + return fmt.Errorf("server identity header validation failed: %w", valErr) + } if config.Server.JWT.Enabled && config.Server.JWK.CertFile == "" && config.Server.JWK.CertURL == "" { @@ -315,6 +318,8 @@ func (l *ConfigLoader) bindAllEnvVars() { l.bindEnv("server.jwt.enabled") l.bindEnv("server.jwt.issuer_url") l.bindEnv("server.jwt.audience") + l.bindEnv("server.jwt.identity_claim") + l.bindEnv("server.identity_header") l.bindEnv("server.jwk.cert_file") l.bindEnv("server.jwk.cert_url") // Database config @@ -384,6 +389,8 @@ func (l *ConfigLoader) bindFlags(cmd *cobra.Command) { l.bindPFlag("server.jwt.enabled", cmd.Flags().Lookup("server-jwt-enabled")) l.bindPFlag("server.jwt.issuer_url", cmd.Flags().Lookup("server-jwt-issuer-url")) l.bindPFlag("server.jwt.audience", cmd.Flags().Lookup("server-jwt-audience")) + l.bindPFlag("server.jwt.identity_claim", cmd.Flags().Lookup("server-jwt-identity-claim")) + l.bindPFlag("server.identity_header", cmd.Flags().Lookup("server-identity-header")) l.bindPFlag("server.jwk.cert_file", cmd.Flags().Lookup("server-jwk-cert-file")) l.bindPFlag("server.jwk.cert_url", cmd.Flags().Lookup("server-jwk-cert-url")) // Database flags: --db-* -> database.* diff --git a/pkg/config/loader_test.go b/pkg/config/loader_test.go index f3a1130b..eec2fd27 100644 --- a/pkg/config/loader_test.go +++ b/pkg/config/loader_test.go @@ -348,6 +348,7 @@ func TestConfigLoader_DefaultValues(t *testing.T) { Expect(cfg.Server.Timeouts.Write.Seconds()).To(Equal(float64(30)), "Default write timeout") Expect(cfg.Server.TLS.Enabled).To(BeFalse(), "Default TLS disabled") Expect(cfg.Server.JWT.Enabled).To(BeTrue(), "Default JWT enabled") + Expect(cfg.Server.JWT.IdentityClaim).To(Equal("email"), "Default JWT identity claim") Expect(cfg.Database.Dialect).To(Equal("postgres"), "Default database dialect") Expect(cfg.Database.Port).To(Equal(5432), "Default database port") Expect(cfg.Logging.Level).To(Equal("info"), "Default log level") diff --git a/pkg/config/server.go b/pkg/config/server.go index fd15f428..45a1aa08 100755 --- a/pkg/config/server.go +++ b/pkg/config/server.go @@ -5,6 +5,8 @@ import ( "net" "strconv" "time" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/validation" ) // ServerConfig holds HTTP/HTTPS server configuration @@ -13,6 +15,7 @@ type ServerConfig struct { Hostname string `mapstructure:"hostname" json:"hostname" validate:"omitempty,hostname|ip"` Host string `mapstructure:"host" json:"host" validate:"required,hostname|ip"` OpenAPISchemaPath string `mapstructure:"openapi_schema_path" json:"openapi_schema_path"` + IdentityHeader string `mapstructure:"identity_header" json:"identity_header"` JWK JWKConfig `mapstructure:"jwk" json:"jwk" validate:"required"` TLS TLSConfig `mapstructure:"tls" json:"tls" validate:"required"` JWT JWTConfig `mapstructure:"jwt" json:"jwt" validate:"required"` @@ -61,9 +64,10 @@ func (c *TLSConfig) Validate() error { // JWTConfig holds JWT authentication configuration type JWTConfig struct { - IssuerURL string `mapstructure:"issuer_url" json:"issuer_url" validate:"omitempty,url"` - Audience string `mapstructure:"audience" json:"audience"` - Enabled bool `mapstructure:"enabled" json:"enabled"` + IssuerURL string `mapstructure:"issuer_url" json:"issuer_url" validate:"omitempty,url"` + Audience string `mapstructure:"audience" json:"audience"` + IdentityClaim string `mapstructure:"identity_claim" json:"identity_claim"` + Enabled bool `mapstructure:"enabled" json:"enabled"` } func (c *JWTConfig) Validate() error { @@ -73,6 +77,20 @@ func (c *JWTConfig) Validate() error { if c.IssuerURL == "" { return fmt.Errorf("server.jwt.issuer_url is required when jwt is enabled") } + if c.IdentityClaim == "" { + return fmt.Errorf("server.jwt.identity_claim is required when jwt is enabled") + } + return nil +} + +// ValidateIdentityHeader validates the identity header name if set. +func (s *ServerConfig) ValidateIdentityHeader() error { + if s.IdentityHeader == "" { + return nil + } + if validation.IsForbiddenIdentityHeaderName(s.IdentityHeader) { + return fmt.Errorf("server.identity_header %q is not allowed", s.IdentityHeader) + } return nil } @@ -100,9 +118,10 @@ func NewServerConfig() *ServerConfig { KeyFile: "", }, JWT: JWTConfig{ - Enabled: true, - IssuerURL: "", - Audience: "", + Enabled: true, + IssuerURL: "", + Audience: "", + IdentityClaim: "email", }, JWK: JWKConfig{ CertFile: "", diff --git a/pkg/config/server_test.go b/pkg/config/server_test.go index dd7eb521..dabb1e6b 100644 --- a/pkg/config/server_test.go +++ b/pkg/config/server_test.go @@ -25,9 +25,45 @@ func TestJWTConfig_Validate(t *testing.T) { t.Run("enabled JWT with issuer URL passes", func(t *testing.T) { RegisterTestingT(t) - cfg := JWTConfig{Enabled: true, IssuerURL: "https://sso.example.com/auth/realms/test"} + cfg := JWTConfig{ + Enabled: true, + IssuerURL: "https://sso.example.com/auth/realms/test", + IdentityClaim: "email", + } Expect(cfg.Validate()).To(Succeed()) }) + + t.Run("enabled JWT without identity claim fails", func(t *testing.T) { + RegisterTestingT(t) + cfg := JWTConfig{Enabled: true, IssuerURL: "https://sso.example.com/auth/realms/test", IdentityClaim: ""} + err := cfg.Validate() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("identity_claim")) + }) +} + +func TestServerConfig_ValidateIdentityHeader(t *testing.T) { + RegisterTestingT(t) + + t.Run("empty identity header requires nothing", func(t *testing.T) { + RegisterTestingT(t) + cfg := &ServerConfig{} + Expect(cfg.ValidateIdentityHeader()).To(Succeed()) + }) + + t.Run("forbidden header name fails", func(t *testing.T) { + RegisterTestingT(t) + cfg := &ServerConfig{IdentityHeader: "Authorization"} + err := cfg.ValidateIdentityHeader() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not allowed")) + }) + + t.Run("valid header name passes", func(t *testing.T) { + RegisterTestingT(t) + cfg := &ServerConfig{IdentityHeader: "X-HyperFleet-Identity"} + Expect(cfg.ValidateIdentityHeader()).To(Succeed()) + }) } func TestTimeoutsConfig_Validate(t *testing.T) { diff --git a/pkg/services/cluster.go b/pkg/services/cluster.go index d3a0f4e0..2d72e81f 100644 --- a/pkg/services/cluster.go +++ b/pkg/services/cluster.go @@ -71,10 +71,10 @@ func (s *sqlClusterService) Get(ctx context.Context, id string) (*api.Cluster, * func (s *sqlClusterService) Create(ctx context.Context, cluster *api.Cluster) (*api.Cluster, *errors.ServiceError) { if cluster.CreatedBy == "" { - cluster.CreatedBy = defaultSystemUser + cluster.CreatedBy = actorFromContext(ctx) } if cluster.UpdatedBy == "" { - cluster.UpdatedBy = defaultSystemUser + cluster.UpdatedBy = actorFromContext(ctx) } if cluster.Generation == 0 { cluster.Generation = 1 @@ -117,6 +117,7 @@ func (s *sqlClusterService) Patch( } cluster.IncrementGeneration() + cluster.UpdatedBy = actorFromContext(ctx) if saveErr := s.clusterDao.Save(ctx, cluster); saveErr != nil { return nil, handleUpdateError(api.ResourceTypeCluster, saveErr) @@ -130,7 +131,7 @@ func (s *sqlClusterService) Patch( } // SoftDelete marks a cluster for deletion by setting DeletedTime and -// DeletedBy to the current time and system@hyperfleet.local. +// DeletedBy to the current time and the caller identity from context. // If already marked, it returns the cluster unchanged. Cascades the deletion timestamp to all child nodepools. // Actual removal is handled by adapters detecting the new generation and triggering hard deletion asynchronously. func (s *sqlClusterService) SoftDelete(ctx context.Context, id string) (*api.Cluster, *errors.ServiceError) { @@ -145,7 +146,7 @@ func (s *sqlClusterService) SoftDelete(ctx context.Context, id string) (*api.Clu } deletedTime := time.Now().UTC().Truncate(time.Microsecond) - deletedBy := defaultSystemUser + deletedBy := actorFromContext(ctx) cluster.MarkDeleted(deletedBy, deletedTime) cluster.IncrementGeneration() diff --git a/pkg/services/node_pool.go b/pkg/services/node_pool.go index a2d77c3c..6c43845b 100644 --- a/pkg/services/node_pool.go +++ b/pkg/services/node_pool.go @@ -81,10 +81,10 @@ func (s *sqlNodePoolService) Get(ctx context.Context, id string) (*api.NodePool, func (s *sqlNodePoolService) Create(ctx context.Context, nodePool *api.NodePool) (*api.NodePool, *errors.ServiceError) { if nodePool.CreatedBy == "" { - nodePool.CreatedBy = defaultSystemUser + nodePool.CreatedBy = actorFromContext(ctx) } if nodePool.UpdatedBy == "" { - nodePool.UpdatedBy = defaultSystemUser + nodePool.UpdatedBy = actorFromContext(ctx) } if nodePool.Generation == 0 { nodePool.Generation = 1 @@ -127,6 +127,7 @@ func (s *sqlNodePoolService) Patch( } nodePool.IncrementGeneration() + nodePool.UpdatedBy = actorFromContext(ctx) if saveErr := s.nodePoolDao.Save(ctx, nodePool); saveErr != nil { return nil, handleUpdateError(api.ResourceTypeNodePool, saveErr) @@ -186,7 +187,7 @@ func (s *sqlNodePoolService) SoftDelete(ctx context.Context, nodePoolID string) } t := time.Now().UTC().Truncate(time.Microsecond) - deletedBy := defaultSystemUser + deletedBy := actorFromContext(ctx) nodePool.MarkDeleted(deletedBy, t) nodePool.IncrementGeneration() diff --git a/pkg/services/util.go b/pkg/services/util.go index 88281495..02194cbc 100755 --- a/pkg/services/util.go +++ b/pkg/services/util.go @@ -10,6 +10,7 @@ import ( "gorm.io/gorm" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/db" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/errors" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/logger" @@ -113,3 +114,10 @@ func buildAdapterSummaries(ctx context.Context, statuses api.AdapterStatusList) } return summaries } + +func actorFromContext(ctx context.Context) string { + if caller := auth.GetUsernameFromContext(ctx); caller != "" { + return caller + } + return defaultSystemUser +} diff --git a/pkg/validation/identity_header.go b/pkg/validation/identity_header.go new file mode 100644 index 00000000..b077f5f3 --- /dev/null +++ b/pkg/validation/identity_header.go @@ -0,0 +1,19 @@ +package validation + +import "net/http" + +var forbiddenIdentityHeaderNames = map[string]struct{}{ + "Authorization": {}, + "Cookie": {}, + "Set-Cookie": {}, + "X-Api-Key": {}, + "X-Auth-Token": {}, + "X-Forwarded-Authorization": {}, + "Proxy-Authorization": {}, +} + +// IsForbiddenIdentityHeaderName reports whether name must not be used as the caller identity header. +func IsForbiddenIdentityHeaderName(name string) bool { + _, ok := forbiddenIdentityHeaderNames[http.CanonicalHeaderKey(name)] + return ok +} diff --git a/pkg/validation/identity_header_test.go b/pkg/validation/identity_header_test.go new file mode 100644 index 00000000..cb9e0167 --- /dev/null +++ b/pkg/validation/identity_header_test.go @@ -0,0 +1,13 @@ +package validation + +import ( + "testing" + + . "github.com/onsi/gomega" +) + +func TestIsForbiddenIdentityHeaderName(t *testing.T) { + RegisterTestingT(t) + Expect(IsForbiddenIdentityHeaderName("Authorization")).To(BeTrue()) + Expect(IsForbiddenIdentityHeaderName("X-HyperFleet-Identity")).To(BeFalse()) +} diff --git a/plugins/channels/plugin.go b/plugins/channels/plugin.go index 450301b3..681f2e84 100644 --- a/plugins/channels/plugin.go +++ b/plugins/channels/plugin.go @@ -7,7 +7,6 @@ import ( "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/environments" "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/server" - "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/handlers" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/registry" "github.com/openshift-hyperfleet/hyperfleet-api/plugins/resources" @@ -27,11 +26,7 @@ func init() { SearchDisallowedFields: []string{"spec"}, }) - server.RegisterRoutes(pluralChannels, func( - apiV1Router *mux.Router, - svc server.ServicesInterface, - authMiddleware auth.JWTMiddleware, - ) { + server.RegisterRoutes(pluralChannels, func(apiV1Router *mux.Router, svc server.ServicesInterface) { envServices := svc.(*environments.Services) descriptor := registry.MustGet(kindChannel) h := handlers.NewResourceHandler( @@ -45,7 +40,5 @@ func init() { r.HandleFunc("/{id}", h.Get).Methods(http.MethodGet) r.HandleFunc("/{id}", h.Patch).Methods(http.MethodPatch) r.HandleFunc("/{id}", h.Delete).Methods(http.MethodDelete) - - r.Use(authMiddleware.AuthenticateAccountJWT) }) } diff --git a/plugins/clusters/plugin.go b/plugins/clusters/plugin.go index 2375d048..08bdf411 100644 --- a/plugins/clusters/plugin.go +++ b/plugins/clusters/plugin.go @@ -9,7 +9,6 @@ import ( "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/server" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api/presenters" - "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/dao" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/handlers" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/services" @@ -52,7 +51,7 @@ func init() { }) // Routes registration - server.RegisterRoutes("clusters", func(apiV1Router *mux.Router, services server.ServicesInterface, authMiddleware auth.JWTMiddleware) { + server.RegisterRoutes("clusters", func(apiV1Router *mux.Router, services server.ServicesInterface) { envServices := services.(*environments.Services) clusterHandler := handlers.NewClusterHandler(Service(envServices), generic.Service(envServices)) @@ -85,8 +84,6 @@ func init() { nodepoolStatusHandler := handlers.NewNodePoolStatusHandler(adapterStatus.Service(envServices), nodePools.Service(envServices)) clustersRouter.HandleFunc("/{id}/nodepools/{nodepool_id}/statuses", nodepoolStatusHandler.List).Methods(http.MethodGet) clustersRouter.HandleFunc("/{id}/nodepools/{nodepool_id}/statuses", nodepoolStatusHandler.Create).Methods(http.MethodPut) - - clustersRouter.Use(authMiddleware.AuthenticateAccountJWT) }) // REMOVED: Controller registration - Sentinel handles orchestration diff --git a/plugins/nodePools/plugin.go b/plugins/nodePools/plugin.go index 1d73af97..ca122375 100644 --- a/plugins/nodePools/plugin.go +++ b/plugins/nodePools/plugin.go @@ -9,7 +9,6 @@ import ( "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/server" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api/presenters" - "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/dao" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/handlers" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/services" @@ -50,7 +49,7 @@ func init() { }) // Routes registration - server.RegisterRoutes("nodePools", func(apiV1Router *mux.Router, services server.ServicesInterface, authMiddleware auth.JWTMiddleware) { + server.RegisterRoutes("nodePools", func(apiV1Router *mux.Router, services server.ServicesInterface) { envServices := services.(*environments.Services) nodePoolHandler := handlers.NewNodePoolHandler(Service(envServices), generic.Service(envServices)) @@ -58,8 +57,6 @@ func init() { // GET /api/hyperfleet/v1/nodepools - List all nodepools nodePoolsRouter := apiV1Router.PathPrefix("/nodepools").Subrouter() nodePoolsRouter.HandleFunc("", nodePoolHandler.List).Methods(http.MethodGet) - - nodePoolsRouter.Use(authMiddleware.AuthenticateAccountJWT) }) // REMOVED: Controller registration - Sentinel handles orchestration diff --git a/plugins/versions/plugin.go b/plugins/versions/plugin.go index 7edfd8ed..c3852ee7 100644 --- a/plugins/versions/plugin.go +++ b/plugins/versions/plugin.go @@ -7,7 +7,6 @@ import ( "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/environments" "github.com/openshift-hyperfleet/hyperfleet-api/cmd/hyperfleet-api/server" - "github.com/openshift-hyperfleet/hyperfleet-api/pkg/auth" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/handlers" "github.com/openshift-hyperfleet/hyperfleet-api/pkg/registry" "github.com/openshift-hyperfleet/hyperfleet-api/plugins/resources" @@ -29,11 +28,7 @@ func init() { SearchDisallowedFields: []string{"spec"}, }) - server.RegisterRoutes(pluralVersions, func( - apiV1Router *mux.Router, - svc server.ServicesInterface, - authMiddleware auth.JWTMiddleware, - ) { + server.RegisterRoutes(pluralVersions, func(apiV1Router *mux.Router, svc server.ServicesInterface) { envServices := svc.(*environments.Services) channelDescriptor := registry.MustGet("Channel") descriptor := registry.MustGet(kindVersion) @@ -48,7 +43,5 @@ func init() { r.HandleFunc("/{id}", h.GetByOwner).Methods(http.MethodGet) r.HandleFunc("/{id}", h.PatchByOwner).Methods(http.MethodPatch) r.HandleFunc("/{id}", h.DeleteByOwner).Methods(http.MethodDelete) - - r.Use(authMiddleware.AuthenticateAccountJWT) }) } diff --git a/test/helper.go b/test/helper.go index 5aabea99..f78c67a1 100755 --- a/test/helper.go +++ b/test/helper.go @@ -369,6 +369,21 @@ func WithAuthToken(ctx context.Context) openapi.RequestEditorFn { } } +// WithIdentityHeader returns a RequestEditorFn that sets the caller identity header. +func WithIdentityHeader(headerName, headerValue string) openapi.RequestEditorFn { + return func(_ context.Context, req *http.Request) error { + if headerName != "" && headerValue != "" { + req.Header.Set(headerName, headerValue) + } + return nil + } +} + +// IdentityHeaderName returns the configured identity header name for integration tests. +func IdentityHeaderName() string { + return environments.Environment().Config.Server.IdentityHeader +} + func (helper *Helper) StartJWKCertServerMock() (teardown func() error) { jwkURL, teardown = mocks.NewJWKCertServerMock(helper.T, helper.JWTCA, jwkKID, jwkAlg) helper.Env().Config.Server.JWK.CertURL = jwkURL diff --git a/test/integration/caller_identity_test.go b/test/integration/caller_identity_test.go new file mode 100644 index 00000000..c1905e40 --- /dev/null +++ b/test/integration/caller_identity_test.go @@ -0,0 +1,311 @@ +package integration + +import ( + "fmt" + "net/http" + "strings" + "testing" + + . "github.com/onsi/gomega" + + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/api/openapi" + "github.com/openshift-hyperfleet/hyperfleet-api/pkg/util" + "github.com/openshift-hyperfleet/hyperfleet-api/test" +) + +func shortClusterName(prefix, id string) string { + suffix := strings.ReplaceAll(id, "-", "") + if len(suffix) > 12 { + suffix = suffix[:12] + } + return fmt.Sprintf("%s-%s", prefix, suffix) +} + +func TestCallerIdentityCreate(t *testing.T) { + cases := []struct { + name string + namePrefix string + email string + headerActor string + setHeader bool + }{ + { + name: "header present overrides JWT", + namePrefix: "ci-hdr", + email: "jwt-user@example.com", + setHeader: true, + headerActor: "gateway-user@example.com", + }, + { + name: "header absent uses JWT claim", + namePrefix: "ci-jwt", + email: "jwt-only-user@example.com", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h, client := test.RegisterIntegration(t) + + account := h.NewAccount(h.NewID(), "Test User", tc.email) + ctx := h.NewAuthenticatedContext(account) + + wantCreatedBy := account.Email + if tc.setHeader { + wantCreatedBy = tc.headerActor + } + + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName(tc.namePrefix, h.NewID()), + Spec: map[string]interface{}{"test": "spec"}, + } + + opts := []openapi.RequestEditorFn{test.WithAuthToken(ctx)} + if tc.setHeader { + opts = append(opts, test.WithIdentityHeader(test.IdentityHeaderName(), tc.headerActor)) + } + + resp, err := client.PostClusterWithResponse( + ctx, + openapi.PostClusterJSONRequestBody(clusterInput), + opts..., + ) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusCreated)) + Expect(resp.JSON201).NotTo(BeNil()) + Expect(string(resp.JSON201.CreatedBy)).To(Equal(wantCreatedBy)) + }) + } +} + +func TestCallerIdentityPatch(t *testing.T) { + cases := []struct { + name string + namePrefix string + createEmail string + patchEmail string + headerActor string + wantUpdatedBy string + setHeader bool + }{ + { + name: "patch with header updates updated_by", + namePrefix: "ci-patch-hdr", + createEmail: "creator@example.com", + patchEmail: "creator@example.com", + setHeader: true, + headerActor: "updater@example.com", + wantUpdatedBy: "updater@example.com", + }, + { + name: "patch without header uses JWT identity", + namePrefix: "ci-patch-jwt", + createEmail: "creator@example.com", + patchEmail: "patcher@example.com", + wantUpdatedBy: "patcher@example.com", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + h, client := test.RegisterIntegration(t) + + // Create cluster with one identity. + createAccount := h.NewAccount(h.NewID(), "Creator", tc.createEmail) + createCtx := h.NewAuthenticatedContext(createAccount) + + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName(tc.namePrefix, h.NewID()), + Spec: map[string]interface{}{"test": "spec"}, + } + + createResp, err := client.PostClusterWithResponse( + createCtx, + openapi.PostClusterJSONRequestBody(clusterInput), + test.WithAuthToken(createCtx), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(createResp.StatusCode()).To(Equal(http.StatusCreated)) + clusterID := createResp.JSON201.Id + + // Patch cluster with a different identity. + patchAccount := h.NewAccount(h.NewID(), "Patcher", tc.patchEmail) + patchCtx := h.NewAuthenticatedContext(patchAccount) + + newSpec := openapi.ClusterSpec{"test": "patched"} + patchBody := openapi.PatchClusterByIdJSONRequestBody{Spec: &newSpec} + + opts := []openapi.RequestEditorFn{test.WithAuthToken(patchCtx)} + if tc.setHeader { + opts = append(opts, test.WithIdentityHeader(test.IdentityHeaderName(), tc.headerActor)) + } + + patchResp, err := client.PatchClusterByIdWithResponse( + patchCtx, + *clusterID, + patchBody, + opts..., + ) + Expect(err).NotTo(HaveOccurred()) + Expect(patchResp.StatusCode()).To(Equal(http.StatusOK)) + Expect(patchResp.JSON200).NotTo(BeNil()) + Expect(string(patchResp.JSON200.UpdatedBy)).To(Equal(tc.wantUpdatedBy)) + // created_by must remain unchanged. + Expect(string(patchResp.JSON200.CreatedBy)).To(Equal(tc.createEmail)) + }) + } +} + +func TestCallerIdentityMultiplePatches(t *testing.T) { + RegisterTestingT(t) + + h, client := test.RegisterIntegration(t) + + // Create cluster as user-a. + accountA := h.NewAccount(h.NewID(), "User A", "user-a@example.com") + ctxA := h.NewAuthenticatedContext(accountA) + + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName("ci-multi", h.NewID()), + Spec: map[string]interface{}{"version": "1"}, + } + createResp, err := client.PostClusterWithResponse( + ctxA, openapi.PostClusterJSONRequestBody(clusterInput), test.WithAuthToken(ctxA), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(createResp.StatusCode()).To(Equal(http.StatusCreated)) + clusterID := *createResp.JSON201.Id + + Expect(string(createResp.JSON201.CreatedBy)).To(Equal("user-a@example.com")) + Expect(string(createResp.JSON201.UpdatedBy)).To(Equal("user-a@example.com")) + Expect(createResp.JSON201.Generation).To(Equal(int32(1))) + + // Patch 1: user-b via JWT. + accountB := h.NewAccount(h.NewID(), "User B", "user-b@example.com") + ctxB := h.NewAuthenticatedContext(accountB) + + spec2 := openapi.ClusterSpec{"version": "2"} + patch1, err := client.PatchClusterByIdWithResponse( + ctxB, clusterID, + openapi.PatchClusterByIdJSONRequestBody{Spec: &spec2}, + test.WithAuthToken(ctxB), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(patch1.StatusCode()).To(Equal(http.StatusOK)) + + Expect(string(patch1.JSON200.CreatedBy)).To(Equal("user-a@example.com"), "created_by must never change") + Expect(string(patch1.JSON200.UpdatedBy)).To(Equal("user-b@example.com")) + Expect(patch1.JSON200.Generation).To(Equal(int32(2))) + + // Patch 2: user-c via identity header. + spec3 := openapi.ClusterSpec{"version": "3"} + patch2, err := client.PatchClusterByIdWithResponse( + ctxA, clusterID, + openapi.PatchClusterByIdJSONRequestBody{Spec: &spec3}, + test.WithAuthToken(ctxA), + test.WithIdentityHeader(test.IdentityHeaderName(), "user-c@gateway.com"), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(patch2.StatusCode()).To(Equal(http.StatusOK)) + + Expect(string(patch2.JSON200.CreatedBy)).To(Equal("user-a@example.com"), "created_by must never change") + Expect(string(patch2.JSON200.UpdatedBy)).To(Equal("user-c@gateway.com")) + Expect(patch2.JSON200.Generation).To(Equal(int32(3))) + + // GET confirms persisted state. + getResp, err := client.GetClusterByIdWithResponse(ctxA, clusterID, nil, test.WithAuthToken(ctxA)) + Expect(err).NotTo(HaveOccurred()) + Expect(getResp.StatusCode()).To(Equal(http.StatusOK)) + Expect(string(getResp.JSON200.CreatedBy)).To(Equal("user-a@example.com")) + Expect(string(getResp.JSON200.UpdatedBy)).To(Equal("user-c@gateway.com")) + Expect(getResp.JSON200.Generation).To(Equal(int32(3))) +} + +func TestCallerIdentityDelete(t *testing.T) { + RegisterTestingT(t) + + h, client := test.RegisterIntegration(t) + + // Create cluster with identity header. + account := h.NewAccount(h.NewID(), "Creator", "creator@example.com") + ctx := h.NewAuthenticatedContext(account) + + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName("ci-del", h.NewID()), + Spec: map[string]interface{}{"test": "spec"}, + } + createResp, err := client.PostClusterWithResponse( + ctx, openapi.PostClusterJSONRequestBody(clusterInput), + test.WithAuthToken(ctx), + test.WithIdentityHeader(test.IdentityHeaderName(), "header-creator@corp.com"), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(createResp.StatusCode()).To(Equal(http.StatusCreated)) + clusterID := *createResp.JSON201.Id + Expect(string(createResp.JSON201.CreatedBy)).To(Equal("header-creator@corp.com")) + + // Soft-delete the cluster. + delResp, err := client.DeleteClusterByIdWithResponse(ctx, clusterID, test.WithAuthToken(ctx)) + Expect(err).NotTo(HaveOccurred()) + Expect(delResp.StatusCode()).To(Equal(http.StatusAccepted)) + Expect(delResp.JSON202).NotTo(BeNil()) + Expect(delResp.JSON202.DeletedTime).NotTo(BeNil()) + Expect(delResp.JSON202.DeletedBy).NotTo(BeNil()) + // deleted_by reflects the caller identity. + Expect(string(*delResp.JSON202.DeletedBy)).To(Equal("creator@example.com")) + // created_by is preserved. + Expect(string(delResp.JSON202.CreatedBy)).To(Equal("header-creator@corp.com")) +} + +func TestCallerIdentityEmptyHeaderFallback(t *testing.T) { + RegisterTestingT(t) + + h, client := test.RegisterIntegration(t) + + account := h.NewAccount(h.NewID(), "JWT User", "jwt-fallback@example.com") + ctx := h.NewAuthenticatedContext(account) + + // Create with an empty identity header — should fall back to JWT claim. + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName("ci-empty", h.NewID()), + Spec: map[string]interface{}{"test": "spec"}, + } + resp, err := client.PostClusterWithResponse( + ctx, openapi.PostClusterJSONRequestBody(clusterInput), + test.WithAuthToken(ctx), + test.WithIdentityHeader(test.IdentityHeaderName(), ""), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusCreated)) + Expect(string(resp.JSON201.CreatedBy)).To(Equal("jwt-fallback@example.com")) + Expect(string(resp.JSON201.UpdatedBy)).To(Equal("jwt-fallback@example.com")) +} + +func TestCallerIdentityOversizedHeader(t *testing.T) { + RegisterTestingT(t) + + h, client := test.RegisterIntegration(t) + + account := h.NewAccount(h.NewID(), "Test User", "user@example.com") + ctx := h.NewAuthenticatedContext(account) + + // Create with an oversized identity header (>256 chars) — should be rejected. + clusterInput := openapi.ClusterCreateRequest{ + Kind: util.PtrString("Cluster"), + Name: shortClusterName("ci-long", h.NewID()), + Spec: map[string]interface{}{"test": "spec"}, + } + oversized := strings.Repeat("a", 257) + resp, err := client.PostClusterWithResponse( + ctx, openapi.PostClusterJSONRequestBody(clusterInput), + test.WithAuthToken(ctx), + test.WithIdentityHeader(test.IdentityHeaderName(), oversized), + ) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode()).To(Equal(http.StatusUnauthorized)) +} diff --git a/test/integration/clusters_test.go b/test/integration/clusters_test.go index 30f6b59a..e6b7c114 100644 --- a/test/integration/clusters_test.go +++ b/test/integration/clusters_test.go @@ -958,7 +958,7 @@ func TestClusterSoftDelete(t *testing.T) { Expect(*resp.JSON202.Id).To(Equal(cluster.ID)) Expect(resp.JSON202.DeletedTime).NotTo(BeNil()) Expect(resp.JSON202.DeletedBy).NotTo(BeNil()) - Expect(string(*resp.JSON202.DeletedBy)).To(Equal("system@hyperfleet.local")) + Expect(string(*resp.JSON202.DeletedBy)).To(Equal(account.Email)) }) t.Run("given a cluster with child nodepools, when deleted, then nodepools are cascade soft-deleted in DB", func(t *testing.T) { //nolint:lll @@ -986,7 +986,7 @@ func TestClusterSoftDelete(t *testing.T) { Expect(err).NotTo(HaveOccurred()) Expect(delResp.StatusCode()).To(Equal(http.StatusAccepted)) Expect(delResp.JSON202.DeletedTime).NotTo(BeNil()) - Expect(string(*delResp.JSON202.DeletedBy)).To(Equal("system@hyperfleet.local")) + Expect(string(*delResp.JSON202.DeletedBy)).To(Equal(account.Email)) // Verify cascade via direct DB query dbSession := h.DBFactory.New(ctx) var nodePool api.NodePool diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index d14578a6..f392d089 100755 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -34,6 +34,9 @@ func TestMain(m *testing.M) { if os.Getenv("HYPERFLEET_SERVER_JWK_CERT_URL") == "" { _ = os.Setenv("HYPERFLEET_SERVER_JWK_CERT_URL", "https://test-idp.example.com/certs") } + if os.Getenv("HYPERFLEET_SERVER_IDENTITY_HEADER") == "" { + _ = os.Setenv("HYPERFLEET_SERVER_IDENTITY_HEADER", "X-HyperFleet-Identity") + } // Set OpenAPI schema path for integration tests if not already set // This enables schema validation middleware during tests diff --git a/test/integration/node_pools_test.go b/test/integration/node_pools_test.go index 81a87514..48271ace 100644 --- a/test/integration/node_pools_test.go +++ b/test/integration/node_pools_test.go @@ -601,7 +601,7 @@ func TestNodePoolSoftDelete(t *testing.T) { Expect(*resp.JSON202.Id).To(Equal(nodePoolID)) Expect(resp.JSON202.DeletedTime).NotTo(BeNil()) Expect(resp.JSON202.DeletedBy).NotTo(BeNil()) - Expect(string(*resp.JSON202.DeletedBy)).To(Equal("system@hyperfleet.local")) + Expect(string(*resp.JSON202.DeletedBy)).To(Equal(account.Email)) }) t.Run("given a nodepool with Ready=True, when deleted, then generation increments and Ready becomes False", func(t *testing.T) { //nolint:lll