Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions engine/artifacts/errors/guard.invalid_header.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions engine/packages/guard/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ pub struct MissingHeader {
pub header: String,
}

#[derive(RivetError, Serialize)]
#[error(
"guard",
"invalid_header",
"Invalid header value.",
"Invalid {header} header: {detail}."
)]
pub struct InvalidHeader {
pub header: String,
pub detail: String,
}

#[derive(RivetError, Serialize)]
#[error(
"guard",
Expand Down
22 changes: 11 additions & 11 deletions engine/packages/guard/src/routing/actor_path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub enum QueryActorQuery {
namespace: String,
name: String,
key: Vec<String>,
bypass_connectable: bool,
skip_ready_wait: bool,
},
GetOrCreate {
namespace: String,
Expand All @@ -40,19 +40,19 @@ pub enum QueryActorQuery {
input: Option<Vec<u8>>,
region: Option<String>,
crash_policy: Option<CrashPolicy>,
bypass_connectable: bool,
skip_ready_wait: bool,
},
}

impl QueryActorQuery {
pub fn bypass_connectable(&self) -> bool {
pub fn skip_ready_wait(&self) -> bool {
match self {
QueryActorQuery::Get {
bypass_connectable, ..
skip_ready_wait, ..
}
| QueryActorQuery::GetOrCreate {
bypass_connectable, ..
} => *bypass_connectable,
skip_ready_wait, ..
} => *skip_ready_wait,
}
}
}
Expand Down Expand Up @@ -97,8 +97,8 @@ struct RvtParams {
crash_policy: Option<String>,
#[serde(default)]
token: Option<String>,
#[serde(default)]
bypass_connectable: bool,
#[serde(default, rename = "skip-ready-wait")]
skip_ready_wait: bool,
}

/// Parse actor routing information from path.
Expand Down Expand Up @@ -244,7 +244,7 @@ fn extract_rvt_params(rvt_params: &[(String, String)]) -> Result<RvtParams> {
.build());
}
let value = match stripped {
"bypass_connectable" => parse_query_bool(value)
"skip-ready-wait" => parse_query_bool(value)
.map(serde_json::Value::Bool)
.unwrap_or_else(|| serde_json::Value::String(value.clone())),
_ => serde_json::Value::String(value.clone()),
Expand Down Expand Up @@ -294,7 +294,7 @@ fn build_actor_query(name: &str, rvt: RvtParams) -> Result<QueryActorQuery> {
namespace: rvt.namespace,
name: name.to_string(),
key,
bypass_connectable: rvt.bypass_connectable,
skip_ready_wait: rvt.skip_ready_wait,
})
}
"getOrCreate" => {
Expand All @@ -319,7 +319,7 @@ fn build_actor_query(name: &str, rvt: RvtParams) -> Result<QueryActorQuery> {
input,
region: rvt.region,
crash_policy,
bypass_connectable: rvt.bypass_connectable,
skip_ready_wait: rvt.skip_ready_wait,
})
}
other => Err(errors::QueryInvalidParams {
Expand Down
6 changes: 3 additions & 3 deletions engine/packages/guard/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ mod ws_health;

pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target");
pub(crate) const X_RIVET_TOKEN: HeaderName = HeaderName::from_static("x-rivet-token");
pub(crate) const X_RIVET_BYPASS_CONNECTABLE: HeaderName =
HeaderName::from_static("x-rivet-bypass-connectable");
pub(crate) const X_RIVET_SKIP_READY_WAIT: HeaderName =
HeaderName::from_static("x-rivet-skip-ready-wait");
pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
HeaderName::from_static("sec-websocket-protocol");
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";
pub(crate) const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
pub(crate) const WS_PROTOCOL_TOKEN: &str = "rivet_token.";
pub(crate) const WS_PROTOCOL_BYPASS_CONNECTABLE: &str = "rivet_bypass_connectable";
pub(crate) const WS_PROTOCOL_SKIP_READY_WAIT: &str = "rivet_skip_ready_wait";

/// Creates the main routing function that handles all incoming requests
#[tracing::instrument(skip_all)]
Expand Down
83 changes: 61 additions & 22 deletions engine/packages/guard/src/routing/pegboard_gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
use rivet_guard_core::{RouteConfig, RouteTarget, RoutingOutput, request_context::RequestContext};

use super::{
SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, WS_PROTOCOL_BYPASS_CONNECTABLE, WS_PROTOCOL_TOKEN,
X_RIVET_BYPASS_CONNECTABLE, X_RIVET_TOKEN, actor_path::ParsedActorPath,
SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, WS_PROTOCOL_SKIP_READY_WAIT, WS_PROTOCOL_TOKEN,
X_RIVET_SKIP_READY_WAIT, X_RIVET_TOKEN, actor_path::ParsedActorPath,
};
use crate::{
errors,
Expand Down Expand Up @@ -70,22 +70,21 @@

tracing::debug!(?actor_path, "routing using path-based actor routing");

let (actor_id, token, stripped_path, bypass_connectable) = match actor_path {
let (actor_id, token, stripped_path, skip_ready_wait) = match actor_path {
ParsedActorPath::Direct(path) => (
Id::parse(&path.actor_id).context("invalid actor id in path")?,
read_gateway_token_for_path_based(req_ctx, path.token.as_deref())?
.map(ToOwned::to_owned),
path.stripped_path.clone(),
// TODO:
false,
read_skip_ready_wait_for_path_based(req_ctx)?,
),
ParsedActorPath::Query(path) => match resolve_query(ctx, &path.query).await? {
ResolveQueryActorResult::Found { actor_id } => (
actor_id,
read_gateway_token_for_path_based(req_ctx, path.token.as_deref())?
.map(ToOwned::to_owned),
path.stripped_path.clone(),
path.query.bypass_connectable(),
path.query.skip_ready_wait(),
),
ResolveQueryActorResult::Forward { dc_label } => {
let peer_dc = ctx
Expand Down Expand Up @@ -116,7 +115,7 @@
actor_id,
&stripped_path,
token.as_deref(),
bypass_connectable,
skip_ready_wait,
)
.await
.map(Some)
Expand Down Expand Up @@ -148,7 +147,7 @@
set_non_preflight_cors(req_ctx);

// Extract actor ID and token from WebSocket protocol or HTTP headers
let (actor_id_str, token, bypass_connectable) = if req_ctx.is_websocket() {
let (actor_id_str, token, skip_ready_wait) = if req_ctx.is_websocket() {
// For WebSocket, parse the sec-websocket-protocol header
let protocols_header = req_ctx
.headers()
Expand Down Expand Up @@ -179,14 +178,14 @@

let token = protocols
.iter()
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))

Check warning on line 181 in engine/packages/guard/src/routing/pegboard_gateway/mod.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/rivet/rivet/engine/packages/guard/src/routing/pegboard_gateway/mod.rs
.map(ToOwned::to_owned);

let bypass_connectable = protocols
let skip_ready_wait = protocols
.iter()
.any(|p| p == &WS_PROTOCOL_BYPASS_CONNECTABLE);
.any(|p| p == &WS_PROTOCOL_SKIP_READY_WAIT);

(actor_id, token, bypass_connectable)
(actor_id, token, skip_ready_wait)
} else {
// For HTTP, use headers
let actor_id = req_ctx
Expand All @@ -210,9 +209,9 @@
.context("invalid x-rivet-token header")?
.map(ToOwned::to_owned);

let bypass_connectable = req_ctx.headers().contains_key(X_RIVET_BYPASS_CONNECTABLE);
let skip_ready_wait = read_skip_ready_wait_header(req_ctx)?;

(actor_id.to_string(), token, bypass_connectable)
(actor_id.to_string(), token, skip_ready_wait)
};

// Find actor to route to
Expand All @@ -226,7 +225,7 @@
actor_id,
&stripped_path,
token.as_deref(),
bypass_connectable,
skip_ready_wait,
)
.await
.map(Some)
Expand All @@ -247,7 +246,7 @@
actor_id: Id,
stripped_path: &str,
_token: Option<&str>,
bypass_connectable: bool,
skip_ready_wait: bool,
) -> Result<RoutingOutput> {
// NOTE: Token validation implemented in EE

Expand Down Expand Up @@ -323,7 +322,7 @@
actor_id,
actor,
stripped_path,
bypass_connectable,
skip_ready_wait,
ready_sub2,
stopped_sub2,
fail_sub2,
Expand All @@ -338,7 +337,7 @@
actor_id,
actor,
stripped_path,
bypass_connectable,
skip_ready_wait,
ready_sub,
stopped_sub,
fail_sub,
Expand All @@ -361,7 +360,7 @@
actor_id: Id,
actor: pegboard::ops::actor::get_for_gateway::Output,
stripped_path: &str,
bypass_connectable: bool,
skip_ready_wait: bool,
mut ready_sub: SubscriptionHandle<pegboard::workflows::actor2::Ready>,
mut stopped_sub: SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
mut fail_sub: SubscriptionHandle<pegboard::workflows::actor2::Failed>,
Expand All @@ -378,7 +377,7 @@
}

let envoy_key = if let (Some(envoy_key), true) =
(actor.envoy_key, actor.connectable || bypass_connectable)
(actor.envoy_key, actor.connectable || skip_ready_wait)
{
envoy_key
} else {
Expand Down Expand Up @@ -464,7 +463,7 @@
actor_id: Id,
actor: pegboard::ops::actor::get_for_gateway::Output,
stripped_path: &str,
bypass_connectable: bool,
skip_ready_wait: bool,
mut ready_sub: SubscriptionHandle<pegboard::workflows::actor::Ready>,
mut stopped_sub: SubscriptionHandle<pegboard::workflows::actor::Stopped>,
mut fail_sub: SubscriptionHandle<pegboard::workflows::actor::Failed>,
Expand All @@ -490,7 +489,7 @@
}

let runner_id = if let (Some(runner_id), true) =
(actor.runner_id, actor.connectable || bypass_connectable)
(actor.runner_id, actor.connectable || skip_ready_wait)
{
runner_id
} else {
Expand Down Expand Up @@ -555,7 +554,7 @@
actor_id,
actor,
stripped_path,
bypass_connectable,
skip_ready_wait,
ready_sub2,
stopped_sub2,
fail_sub2,
Expand Down Expand Up @@ -626,6 +625,46 @@
}
}

fn read_skip_ready_wait_for_path_based(req_ctx: &RequestContext) -> Result<bool> {
if req_ctx.is_websocket() {
Ok(req_ctx
.headers()
.get(SEC_WEBSOCKET_PROTOCOL)
.and_then(|protocols| protocols.to_str().ok())
.is_some_and(|protocols| {
protocols
.split(',')
.map(|p| p.trim())
.any(|p| p == WS_PROTOCOL_SKIP_READY_WAIT)
}))
} else {
read_skip_ready_wait_header(req_ctx)
}
}

fn read_skip_ready_wait_header(req_ctx: &RequestContext) -> Result<bool> {
let Some(value) = req_ctx.headers().get(X_RIVET_SKIP_READY_WAIT) else {
return Ok(false);

Check warning on line 647 in engine/packages/guard/src/routing/pegboard_gateway/mod.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

Diff in /home/runner/work/rivet/rivet/engine/packages/guard/src/routing/pegboard_gateway/mod.rs
};

let value = value.to_str().context("invalid x-rivet-skip-ready-wait header")?;
parse_skip_ready_wait_bool(value).ok_or_else(|| {
crate::errors::InvalidHeader {
header: X_RIVET_SKIP_READY_WAIT.to_string(),
detail: "expected true, false, 1, or 0".to_string(),
}
.build()
})
}

fn parse_skip_ready_wait_bool(value: &str) -> Option<bool> {
match value {
"true" | "1" => Some(true),
"false" | "0" => Some(false),
_ => None,
}
}

/// Waits for initial delay, then periodically checks for runner pool errors.
///
/// Returns `true` if the pool has an active error, `false` otherwise.
Expand Down
Loading
Loading