From 6b5aa7b9b60adc18c8f2857eaf4cdcd247de3b33 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Wed, 6 May 2026 17:22:38 -0700 Subject: [PATCH] chore(rivetkit): rewrite work registry + fix waituntil not preventing sleep --- .agent/notes/driver-test-progress.md | 11 + .../rust/envoy-client/src/async_counter.rs | 14 + .../rivetkit-core/src/actor/context.rs | 270 +++++++----------- .../packages/rivetkit-core/src/actor/mod.rs | 3 +- .../packages/rivetkit-core/src/actor/sleep.rs | 201 ++++++++++++- .../rivetkit-core/src/actor/work_registry.rs | 128 ++++++++- .../packages/rivetkit-core/src/lib.rs | 3 +- .../packages/rivetkit-core/tests/sleep.rs | 89 +++++- .../packages/rivetkit-napi/index.d.ts | 1 + .../rivetkit-napi/src/actor_context.rs | 11 +- .../packages/rivetkit-wasm/src/lib.rs | 10 +- .../driver-test-suite/registry-static.ts | 4 + .../fixtures/driver-test-suite/sleep.ts | 112 ++++++++ .../rivetkit/src/registry/napi-runtime.ts | 6 + .../packages/rivetkit/src/registry/native.ts | 55 +--- .../packages/rivetkit/src/registry/runtime.ts | 1 + .../rivetkit/src/registry/wasm-runtime.ts | 9 + .../rivetkit/tests/driver/actor-sleep.test.ts | 103 +++++++ 18 files changed, 789 insertions(+), 242 deletions(-) diff --git a/.agent/notes/driver-test-progress.md b/.agent/notes/driver-test-progress.md index 3108b239af..fe770e598f 100644 --- a/.agent/notes/driver-test-progress.md +++ b/.agent/notes/driver-test-progress.md @@ -57,3 +57,14 @@ Scope: DB driver tests only - 2026-05-03 18:28 PDT actor-db-stress rerun [native]: PASS (5 passed, 40 skipped, 25.2s). - 2026-05-03 18:29 PDT actor-db-init-order rerun [native]: PASS (6 passed, 48 skipped, 6.6s). - 2026-05-03 18:29 PDT DB TESTS RERUN COMPLETE [native only] - 6/6 DB file groups passed. +- 2026-05-06 21:19 PDT actor-sleep [native/local/bare]: PASS (25 passed, 200 skipped, 53.4s). Relevant PR validation for sleep/work-registry changes. +- 2026-05-06 21:19 PDT actor-sleep [wasm/remote/bare]: FAIL (23 passed, 2 failed, 200 skipped, 49.1s). Failing tests: `waitUntil can broadcast before sleep disconnect` expected waitUntilMessageCount 1 but got 0; `waitUntil in onSleep keeps c.vars available during grace` logged `Cannot set properties of undefined`. +- 2026-05-06 21:19 PDT actor-sleep [wasm/remote/bare] narrowed: FAIL `waitUntil can broadcast before sleep disconnect` expected waitUntilMessageCount 1 but got 0. +- 2026-05-06 21:19 PDT actor-sleep [wasm/remote/bare] narrowed: FAIL `waitUntil in onSleep keeps c.vars available during grace` logged `actor wait_until promise rejected ... Cannot set properties of undefined (setting 'dirty')`. +- 2026-05-06 21:29 PDT actor-sleep rerun after wasm-safe shutdown-work wait [wasm/remote/bare]: FAIL (24 passed, 1 failed, 200 skipped, 49.5s). Fixed `waitUntil in onSleep keeps c.vars available during grace`; remaining failure is `waitUntil can broadcast before sleep disconnect` expected waitUntilMessageCount 1 but got 0. +- 2026-05-06 21:29 PDT actor-sleep rerun after NAPI rebuild [native/local/bare]: FAIL (24 passed, 1 failed, 200 skipped, 51.0s). Same remaining `waitUntil can broadcast before sleep disconnect` failure reproduced standalone on native. +- 2026-05-08 16:52 PDT NAPI-focused rerun after removing tracked-shutdown synthetic timeout. Wasm issue intentionally ignored per request. +- 2026-05-08 16:52 PDT actor-sleep [native/local/bare]: PASS (25 passed, 55.2s). Covers waitUntil shutdown grace, raw websocket callback tracking, and `c.vars` after grace deadline. +- 2026-05-08 16:52 PDT raw-websocket [native/local/bare]: PASS (16 passed, 14.0s). Covers raw websocket callback tracking. +- 2026-05-08 16:52 PDT actor-conn-state [native/local/bare]: PASS (11 passed, 10.9s). Covers connection lifecycle and disconnect accounting through the work registry. +- 2026-05-08 16:52 PDT actor-sleep-db [native/local/bare]: PASS (26 passed, 70.9s). Covers DB close timing during sleep shutdown and waitUntil state persistence. diff --git a/engine/sdks/rust/envoy-client/src/async_counter.rs b/engine/sdks/rust/envoy-client/src/async_counter.rs index e31b697af7..19708bda07 100644 --- a/engine/sdks/rust/envoy-client/src/async_counter.rs +++ b/engine/sdks/rust/envoy-client/src/async_counter.rs @@ -102,6 +102,20 @@ impl AsyncCounter { } } } + + pub async fn wait_zero_unbounded(&self) { + loop { + let notified = self.zero_notify.notified(); + tokio::pin!(notified); + notified.as_mut().enable(); + + if self.value.load(Ordering::Acquire) == 0 { + return; + } + + notified.await; + } + } } impl Default for AsyncCounter { diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs index f734ffd33e..935f65eb4b 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/context.rs @@ -36,7 +36,7 @@ use crate::actor::task::LifecycleEvent; #[cfg(not(target_arch = "wasm32"))] use crate::actor::task::{LIFECYCLE_EVENT_INBOX_CHANNEL, actor_channel_overloaded_error}; use crate::actor::task_types::UserTaskKind; -use crate::actor::work_registry::RegionGuard; +use crate::actor::work_registry::{ActorWorkKind, CountGuard, RegionGuard}; use crate::error::{ActorLifecycle as ActorLifecycleError, ActorRuntime}; use crate::inspector::{Inspector, InspectorSnapshot}; use crate::kv::Kv; @@ -128,7 +128,6 @@ pub(crate) struct ActorContextInner { pub(super) connection_disconnect_state: Mutex<()>, pub(super) sleep: SleepState, activity: ActivityState, - pending_disconnect_count: AtomicUsize, sleep_requested: AtomicBool, destroy_requested: AtomicBool, destroy_completed: AtomicBool, @@ -296,7 +295,6 @@ impl ActorContext { connection_disconnect_state: Mutex::new(()), sleep, activity: ActivityState::default(), - pending_disconnect_count: AtomicUsize::new(0), sleep_requested: AtomicBool::new(false), destroy_requested: AtomicBool::new(false), destroy_completed: AtomicBool::new(false), @@ -543,83 +541,34 @@ impl ActorContext { #[cfg(not(feature = "wasm-runtime"))] pub fn wait_until(&self, future: impl Future + Send + 'static) { - if Handle::try_current().is_err() { - tracing::warn!("skipping wait_until without a tokio runtime"); - return; - } - - let ctx = self.clone(); - // Intentionally detached but tracked by the actor sleep state: waitUntil work - // is a public side task that shutdown drains/aborts through - // `shutdown_tasks`, not an ActorTask dispatch child. - self.track_shutdown_task(async move { - ctx.record_user_task_started(UserTaskKind::WaitUntil); - let started_at = Instant::now(); - future.await; - ctx.record_user_task_finished(UserTaskKind::WaitUntil, started_at.elapsed()); - ctx.reset_sleep_timer(); - }); + self.spawn_work(ActorWorkKind::WaitUntil, future); } #[cfg(not(feature = "wasm-runtime"))] pub fn register_task(&self, future: impl Future + Send + 'static) { - let ctx = self.clone(); - self.track_shutdown_task(async move { - Self::run_registered_task(ctx, future).await; - }); + self.spawn_work(ActorWorkKind::RegisteredTask, future); } #[cfg(feature = "wasm-runtime")] pub fn wait_until(&self, future: impl Future + 'static) { - let ctx = self.clone(); - self.track_shutdown_task(async move { - ctx.record_user_task_started(UserTaskKind::WaitUntil); - let started_at = Instant::now(); - future.await; - ctx.record_user_task_finished(UserTaskKind::WaitUntil, started_at.elapsed()); - ctx.reset_sleep_timer(); - }); + self.spawn_work(ActorWorkKind::WaitUntil, future); } #[cfg(feature = "wasm-runtime")] pub fn register_task(&self, future: impl Future + 'static) { - let ctx = self.clone(); - self.track_shutdown_task(async move { - Self::run_registered_task(ctx, future).await; - }); - } - - async fn run_registered_task(ctx: ActorContext, future: F) - where - F: Future, - { - let shutdown_deadline = ctx.shutdown_deadline_token(); - ctx.record_user_task_started(UserTaskKind::RegisteredTask); - let started_at = Instant::now(); - tokio::select! { - _ = future => {} - _ = shutdown_deadline.cancelled() => { - tracing::warn!( - actor_id = %ctx.actor_id(), - reason = "shutdown_deadline_elapsed", - "registered task cancelled by shutdown deadline" - ); - } - } - ctx.record_user_task_finished(UserTaskKind::RegisteredTask, started_at.elapsed()); + self.spawn_work(ActorWorkKind::RegisteredTask, future); } pub async fn keep_awake(&self, future: F) -> F::Output where F: Future, { - let _guard = self.keep_awake_guard(); - future.await + self.track_work(ActorWorkKind::KeepAwake, future).await } pub fn keep_awake_region(&self) -> KeepAwakeRegion { KeepAwakeRegion { - guard: Some(self.keep_awake_guard()), + region: Some(self.begin_work_region(ActorWorkKind::KeepAwake)), } } @@ -627,8 +576,8 @@ impl ActorContext { where F: Future, { - let _guard = self.internal_keep_awake_guard(); - future.await + self.track_work(ActorWorkKind::InternalKeepAwake, future) + .await } pub fn keep_awake_count(&self) -> usize { @@ -639,6 +588,36 @@ impl ActorContext { self.sleep_internal_keep_awake_count() } + pub async fn track_work(&self, kind: ActorWorkKind, future: F) -> F::Output + where + F: Future, + { + let _region = self.begin_work_region(kind); + future.await + } + + #[cfg(not(feature = "wasm-runtime"))] + pub fn spawn_work(&self, kind: ActorWorkKind, future: F) + where + F: Future + Send + 'static, + { + self.spawn_work_inner(kind, future); + } + + #[cfg(feature = "wasm-runtime")] + pub fn spawn_work(&self, kind: ActorWorkKind, future: F) + where + F: Future + 'static, + { + self.spawn_work_inner(kind, future); + } + + pub fn begin_work_region(&self, kind: ActorWorkKind) -> ActorWorkRegion { + ActorWorkRegion { + guard: Some(ActorWorkGuard::new(self.clone(), kind)), + } + } + pub fn actor_id(&self) -> &str { &self.0.actor_id } @@ -1252,7 +1231,7 @@ impl ActorContext { } pub(crate) fn pending_disconnect_count(&self) -> usize { - self.0.pending_disconnect_count.load(Ordering::SeqCst) + self.0.sleep.work.disconnect_callback.load() } pub async fn with_disconnect_callback(&self, run: F) -> T @@ -1260,8 +1239,8 @@ impl ActorContext { F: FnOnce() -> Fut, Fut: Future, { - let _guard = DisconnectCallbackGuard::new(self.clone()); - run().await + self.track_work(ActorWorkKind::DisconnectCallback, run()) + .await } pub(crate) fn configure_lifecycle_events(&self, sender: Option>) { @@ -1329,24 +1308,6 @@ impl ActorContext { self.0.sleep_requested.store(false, Ordering::SeqCst); } - fn keep_awake_guard(&self) -> KeepAwakeGuard { - let region = self - .keep_awake_region_state() - .with_log_fields("keep_awake", Some(self.actor_id().to_owned())); - let guard = KeepAwakeGuard::new(self.clone(), region); - self.reset_sleep_timer(); - guard - } - - fn internal_keep_awake_guard(&self) -> KeepAwakeGuard { - let region = self - .internal_keep_awake_region() - .with_log_fields("internal_keep_awake", Some(self.actor_id().to_owned())); - let guard = KeepAwakeGuard::new(self.clone(), region); - self.reset_sleep_timer(); - guard - } - pub(crate) async fn internal_keep_awake_task( &self, future: BoxFuture<'static, Result<()>>, @@ -1356,7 +1317,7 @@ impl ActorContext { pub fn websocket_callback_region(&self) -> WebSocketCallbackRegion { WebSocketCallbackRegion { - guard: Some(self.websocket_callback_guard(UserTaskKind::WebSocketCallback)), + region: Some(self.begin_work_region(ActorWorkKind::WebSocketCallback)), } } @@ -1369,11 +1330,25 @@ impl ActorContext { run().await } - fn websocket_callback_guard(&self, kind: UserTaskKind) -> WebSocketCallbackGuard { - let region = self.websocket_callback_region_state(); - self.record_user_task_started(kind); - self.reset_sleep_timer(); - WebSocketCallbackGuard::new(self.clone(), kind, region) + fn idle_work_region(&self, kind: ActorWorkKind) -> Option { + if !kind.policy().blocks_idle_sleep { + return None; + } + let region = match kind { + ActorWorkKind::KeepAwake => self.keep_awake_region_state(), + ActorWorkKind::InternalKeepAwake => self.internal_keep_awake_region(), + ActorWorkKind::WaitUntil => return None, + ActorWorkKind::RegisteredTask => return None, + ActorWorkKind::WebSocketCallback => self.websocket_callback_region_state(), + ActorWorkKind::DisconnectCallback => self.disconnect_callback_region_state(), + }; + Some(region.with_log_fields(kind.label(), Some(self.actor_id().to_owned()))) + } + + fn shutdown_work_region(&self) -> CountGuard { + let counter = self.0.sleep.work.shutdown_counter.clone(); + counter.increment(); + CountGuard::from_incremented(counter) } fn configure_sleep_hooks(&self) { @@ -1435,10 +1410,10 @@ impl ActorContext { self.cancel_scheduled_event(event_id); let ctx = self.clone(); let event_id = event_id.to_owned(); - let keep_awake_guard = self.internal_keep_awake_guard(); + let internal_keep_awake_region = self.begin_work_region(ActorWorkKind::InternalKeepAwake); self.track_shutdown_task(async move { - let _keep_awake_guard = keep_awake_guard; + let _internal_keep_awake_region = internal_keep_awake_region; ctx.record_user_task_started(UserTaskKind::ScheduledAction); let started_at = Instant::now(); let action_name = action.clone(); @@ -1524,49 +1499,65 @@ fn now_timestamp_ms() -> i64 { i64::try_from(duration.as_millis()).unwrap_or(i64::MAX) } -struct KeepAwakeGuard { +struct ActorWorkGuard { ctx: ActorContext, - region: Option, + kind: ActorWorkKind, + started_at: Option, + idle_region: Option, + shutdown_region: Option, } #[must_use] -struct DisconnectCallbackGuard { - ctx: ActorContext, - started_at: Instant, +pub struct ActorWorkRegion { + guard: Option, } -impl DisconnectCallbackGuard { - fn new(ctx: ActorContext) -> Self { - ctx.0 - .pending_disconnect_count - .fetch_add(1, Ordering::SeqCst); - ctx.record_user_task_started(UserTaskKind::DisconnectCallback); +impl ActorWorkGuard { + fn new(ctx: ActorContext, kind: ActorWorkKind) -> Self { + let policy = kind.policy(); + let idle_region = ctx.idle_work_region(kind); + let shutdown_region = if policy.drains_shutdown_grace { + Some(ctx.shutdown_work_region()) + } else { + None + }; + let started_at = if let Some(user_task_kind) = policy.user_task_kind { + ctx.record_user_task_started(user_task_kind); + Some(Instant::now()) + } else { + None + }; ctx.reset_sleep_timer(); Self { ctx, - started_at: Instant::now(), + kind, + started_at, + idle_region, + shutdown_region, } } } -impl Drop for DisconnectCallbackGuard { +impl Drop for ActorWorkGuard { fn drop(&mut self) { - let Ok(previous) = self.ctx.0.pending_disconnect_count.fetch_update( - Ordering::SeqCst, - Ordering::SeqCst, - |current| current.checked_sub(1), - ) else { - return; - }; - if previous == 0 { - return; + if let Some(started_at) = self.started_at.take() + && let Some(user_task_kind) = self.kind.policy().user_task_kind + { + self.ctx + .record_user_task_finished(user_task_kind, started_at.elapsed()); } - self.ctx - .record_user_task_finished(UserTaskKind::DisconnectCallback, self.started_at.elapsed()); + self.idle_region.take(); + self.shutdown_region.take(); self.ctx.reset_sleep_timer(); } } +impl Drop for ActorWorkRegion { + fn drop(&mut self) { + self.guard.take(); + } +} + #[must_use] #[derive(Debug)] pub(crate) struct InspectorAttachGuard { @@ -1616,67 +1607,24 @@ impl Drop for InspectorAttachGuard { } } -impl KeepAwakeGuard { - fn new(ctx: ActorContext, region: RegionGuard) -> Self { - Self { - ctx, - region: Some(region), - } - } -} - -impl Drop for KeepAwakeGuard { - fn drop(&mut self) { - self.region.take(); - self.ctx.reset_sleep_timer(); - } -} - -struct WebSocketCallbackGuard { - ctx: ActorContext, - kind: UserTaskKind, - started_at: Instant, - region: Option, -} - pub struct WebSocketCallbackRegion { - guard: Option, + region: Option, } pub struct KeepAwakeRegion { - guard: Option, -} - -impl WebSocketCallbackGuard { - fn new(ctx: ActorContext, kind: UserTaskKind, region: RegionGuard) -> Self { - Self { - ctx, - kind, - started_at: Instant::now(), - region: Some(region), - } - } -} - -impl Drop for WebSocketCallbackGuard { - fn drop(&mut self) { - self.ctx - .record_user_task_finished(self.kind, self.started_at.elapsed()); - self.region.take(); - self.ctx.reset_sleep_timer(); - } + region: Option, } impl Drop for WebSocketCallbackRegion { fn drop(&mut self) { - self.guard.take(); + self.region.take(); } } impl Drop for KeepAwakeRegion { fn drop(&mut self) { - // Take the guard explicitly to mirror WebSocketCallbackRegion. - self.guard.take(); + // Take the region explicitly to mirror WebSocketCallbackRegion. + self.region.take(); } } diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs index 82bcac5cc0..04855d127a 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/mod.rs @@ -23,7 +23,7 @@ pub(crate) mod work_registry; pub use action::ActionDispatchError; pub use config::{ActionDefinition, ActorConfig, ActorConfigOverrides, CanHibernateWebSocket}; pub use connection::ConnHandle; -pub use context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion}; +pub use context::{ActorContext, ActorWorkRegion, KeepAwakeRegion, WebSocketCallbackRegion}; pub use factory::{ActorEntryFn, ActorFactory}; pub use kv::Kv; pub use lifecycle_hooks::{ActorEvents, ActorStart, Reply}; @@ -41,3 +41,4 @@ pub use task::{ LifecycleEvent, LifecycleState, }; pub use task_types::{ActorChildOutcome, ShutdownKind, StateMutationReason, UserTaskKind}; +pub use work_registry::{ActorWorkKind, ActorWorkPolicy}; diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs index cc59acfd52..9144fc1648 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/sleep.rs @@ -13,16 +13,22 @@ use tokio::task::JoinHandle; use tracing::Instrument; use crate::actor::config::ActorConfig; +#[cfg(not(feature = "wasm-runtime"))] +use crate::actor::context::ActorWorkRegion; use crate::actor::context::ActorContext; use crate::actor::task_types::ShutdownKind; #[cfg(feature = "wasm-runtime")] use crate::actor::work_registry::LocalShutdownTask; -use crate::actor::work_registry::{CountGuard, RegionGuard, WorkRegistry}; +#[cfg(not(feature = "wasm-runtime"))] +use crate::actor::work_registry::ActorWorkPolicy; +use crate::actor::work_registry::{ + ActorWorkKind, CountGuard, RegionGuard, WorkRegistry, +}; #[cfg(feature = "wasm-runtime")] use crate::runtime::RuntimeSpawner; +use crate::time::{Instant, sleep}; #[cfg(test)] use crate::time::sleep_until; -use crate::time::{Instant, sleep}; #[cfg(test)] use crate::types::ActorKey; #[cfg(feature = "wasm-runtime")] @@ -113,6 +119,10 @@ impl std::fmt::Debug for SleepState { "websocket_callback_count", &self.work.websocket_callback.load(), ) + .field( + "disconnect_callback_count", + &self.work.disconnect_callback.load(), + ) .finish() } } @@ -412,6 +422,29 @@ impl ActorContext { } } + pub async fn wait_for_tracked_shutdown_work(&self) -> bool { + let shutdown_deadline = self.shutdown_deadline_token(); + tokio::select! { + _ = self.wait_for_tracked_shutdown_work_drained() => true, + _ = shutdown_deadline.cancelled() => false, + } + } + + async fn wait_for_tracked_shutdown_work_drained(&self) { + loop { + let shutdown_count = self.shutdown_task_count(); + let websocket_count = self.websocket_callback_count(); + if shutdown_count == 0 && websocket_count == 0 { + return; + } + + tokio::select! { + _ = self.0.sleep.work.shutdown_counter.wait_zero_unbounded(), if shutdown_count > 0 => {} + _ = self.0.sleep.work.websocket_callback.wait_zero_unbounded(), if websocket_count > 0 => {} + } + } + } + pub(crate) async fn wait_for_http_requests_drained(&self, deadline: Instant) -> bool { let Some(counter) = self.http_request_counter() else { return true; @@ -461,6 +494,135 @@ impl ActorContext { self.0.sleep.work.websocket_callback.load() } + pub(crate) fn disconnect_callback_region_state(&self) -> RegionGuard { + self.0.sleep.work.disconnect_callback_guard() + } + + #[cfg(not(feature = "wasm-runtime"))] + pub(crate) fn spawn_work_inner(&self, kind: ActorWorkKind, fut: F) -> bool + where + F: Future + Send + 'static, + { + if Handle::try_current().is_err() { + tracing::warn!(kind = kind.label(), "actor work spawned without tokio runtime"); + return false; + } + + let policy = kind.policy(); + if policy.aborts_at_shutdown_deadline { + let mut shutdown_tasks = self.0.sleep.work.shutdown_tasks.lock(); + if self.0.sleep.work.teardown_started.load(Ordering::Acquire) { + tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately"); + return false; + } + let region = self.begin_work_region(kind); + shutdown_tasks.spawn(self.build_spawned_work_task(kind, policy, region, fut)); + } else { + let mut unabortable_shutdown_tasks = + self.0.sleep.work.unabortable_shutdown_tasks.lock(); + if self.0.sleep.work.teardown_started.load(Ordering::Acquire) { + tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately"); + return false; + } + let region = self.begin_work_region(kind); + unabortable_shutdown_tasks + .spawn(self.build_spawned_work_task(kind, policy, region, fut)); + } + self.reset_sleep_timer(); + true + } + + #[cfg(not(feature = "wasm-runtime"))] + fn build_spawned_work_task( + &self, + kind: ActorWorkKind, + policy: ActorWorkPolicy, + region: ActorWorkRegion, + fut: F, + ) -> impl Future + Send + 'static + where + F: Future + Send + 'static, + { + let ctx = self.clone(); + async move { + let _region = region; + if policy.aborts_at_shutdown_deadline { + let shutdown_deadline = ctx.shutdown_deadline_token(); + tokio::select! { + _ = fut => {} + _ = shutdown_deadline.cancelled() => { + tracing::warn!( + actor_id = %ctx.actor_id(), + kind = kind.label(), + reason = "shutdown_deadline_elapsed", + "actor work cancelled by shutdown deadline" + ); + } + } + } else { + fut.await; + } + ctx.reset_sleep_timer(); + } + .in_current_span() + } + + #[cfg(feature = "wasm-runtime")] + pub(crate) fn spawn_work_inner(&self, kind: ActorWorkKind, fut: F) -> bool + where + F: Future + 'static, + { + let mut local_shutdown_tasks = self.0.sleep.work.local_shutdown_tasks.lock(); + if self.0.sleep.work.teardown_started.load(Ordering::Acquire) { + tracing::warn!(kind = kind.label(), "actor work spawned after teardown; aborting immediately"); + return false; + } + + let policy = kind.policy(); + let region = self.begin_work_region(kind); + let ctx = self.clone(); + let (complete_tx, complete_rx) = futures_oneshot::channel(); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + local_shutdown_tasks.push(LocalShutdownTask { + abort_handle, + complete_rx, + aborts_at_shutdown_deadline: policy.aborts_at_shutdown_deadline, + }); + drop(local_shutdown_tasks); + let ctx_for_task = ctx.clone(); + wasm_bindgen_futures::spawn_local( + async move { + let task = async move { + let _region = region; + if policy.aborts_at_shutdown_deadline { + let shutdown_deadline = ctx_for_task.shutdown_deadline_token(); + tokio::select! { + _ = fut => {} + _ = shutdown_deadline.cancelled() => { + tracing::warn!( + actor_id = %ctx_for_task.actor_id(), + kind = kind.label(), + reason = "shutdown_deadline_elapsed", + "actor work cancelled by shutdown deadline" + ); + } + } + } else { + fut.await; + } + let _ = complete_tx.send(()); + ctx_for_task.reset_sleep_timer(); + }; + if Abortable::new(task, abort_registration).await.is_err() { + ctx.reset_sleep_timer(); + } + } + .in_current_span(), + ); + self.reset_sleep_timer(); + true + } + #[cfg(not(feature = "wasm-runtime"))] pub(crate) fn track_shutdown_task(&self, fut: F) -> bool where @@ -519,6 +681,7 @@ impl ActorContext { local_shutdown_tasks.push(LocalShutdownTask { abort_handle, complete_rx, + aborts_at_shutdown_deadline: true, }); drop(local_shutdown_tasks); let ctx_for_task = ctx.clone(); @@ -605,7 +768,9 @@ impl ActorContext { if abort_remaining { for task in local_shutdown_tasks { - task.abort_handle.abort(); + if task.aborts_at_shutdown_deadline { + task.abort_handle.abort(); + } if task.complete_rx.await.is_err() { tracing::debug!("aborted shutdown task during teardown"); } @@ -628,10 +793,12 @@ impl ActorContext { #[cfg(not(feature = "wasm-runtime"))] loop { - let mut shutdown_tasks = { + let mut abortable_shutdown_tasks = { let mut guard = self.0.sleep.work.shutdown_tasks.lock(); let taken = std::mem::take(&mut *guard); - if taken.is_empty() { + let mut unabortable_guard = self.0.sleep.work.unabortable_shutdown_tasks.lock(); + let unabortable_taken = std::mem::take(&mut *unabortable_guard); + if taken.is_empty() && unabortable_taken.is_empty() { self.0 .sleep .work @@ -639,18 +806,22 @@ impl ActorContext { .store(true, Ordering::Release); return; } - taken + (taken, unabortable_taken) }; - if abort_remaining { - shutdown_tasks.shutdown().await; - } else { - while let Some(result) = shutdown_tasks.join_next().await { - if let Err(error) = result - && !error.is_cancelled() - { - tracing::error!(?error, "shutdown task join failed during teardown"); - } + abortable_shutdown_tasks.0.shutdown().await; + while let Some(result) = abortable_shutdown_tasks.0.join_next().await { + if let Err(error) = result + && !error.is_cancelled() + { + tracing::error!(?error, "shutdown task join failed during teardown"); + } + } + while let Some(result) = abortable_shutdown_tasks.1.join_next().await { + if let Err(error) = result + && !error.is_cancelled() + { + tracing::error!(?error, "shutdown task join failed during teardown"); } } } diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs index bf281c1254..14bf3fa017 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/work_registry.rs @@ -10,25 +10,127 @@ use rivet_envoy_client::async_counter::AsyncCounter; use tokio::sync::Notify; use tokio::task::JoinSet; +use crate::actor::task_types::UserTaskKind; + +/// Classifies actor work so sleep can apply one policy model across different APIs. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ActorWorkKind { + /// User work that keeps the actor out of idle sleep while it runs. + KeepAwake, + /// Runtime-owned work that should behave like keep-awake without exposing a user API. + InternalKeepAwake, + /// User work that may continue into sleep grace but should not block idle sleep. + WaitUntil, + /// Detached runtime task that drains during shutdown. + RegisteredTask, + /// Async WebSocket callback work that must hold sleep while a callback is running. + WebSocketCallback, + /// Disconnect callback work that must finish before sleep or destroy finalizes. + DisconnectCallback, +} + +/// Defines how a work kind participates in idle sleep and shutdown grace. +#[derive(Debug, Clone, Copy)] +pub struct ActorWorkPolicy { + /// True when active work should prevent the actor from entering idle sleep. + pub blocks_idle_sleep: bool, + /// True when active work should delay sleep-grace runtime cleanup. + pub drains_shutdown_grace: bool, + /// True when detached work should be cancelled after the shutdown deadline. + pub aborts_at_shutdown_deadline: bool, + /// User-facing task kind used for metrics and lifecycle diagnostics. + pub user_task_kind: Option, +} + +impl ActorWorkKind { + /// Returns the lifecycle policy owned by this work kind. + pub fn policy(self) -> ActorWorkPolicy { + match self { + ActorWorkKind::KeepAwake => ActorWorkPolicy { + blocks_idle_sleep: true, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: true, + user_task_kind: None, + }, + ActorWorkKind::InternalKeepAwake => ActorWorkPolicy { + blocks_idle_sleep: true, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: false, + user_task_kind: None, + }, + ActorWorkKind::WaitUntil => ActorWorkPolicy { + blocks_idle_sleep: false, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: true, + user_task_kind: Some(UserTaskKind::WaitUntil), + }, + ActorWorkKind::RegisteredTask => ActorWorkPolicy { + blocks_idle_sleep: false, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: true, + user_task_kind: Some(UserTaskKind::RegisteredTask), + }, + ActorWorkKind::WebSocketCallback => ActorWorkPolicy { + blocks_idle_sleep: true, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: true, + user_task_kind: Some(UserTaskKind::WebSocketCallback), + }, + ActorWorkKind::DisconnectCallback => ActorWorkPolicy { + blocks_idle_sleep: true, + drains_shutdown_grace: true, + aborts_at_shutdown_deadline: true, + user_task_kind: Some(UserTaskKind::DisconnectCallback), + }, + } + } + + /// Returns a stable label for logs and metric fields. + pub(crate) fn label(self) -> &'static str { + match self { + ActorWorkKind::KeepAwake => "keep_awake", + ActorWorkKind::InternalKeepAwake => "internal_keep_awake", + ActorWorkKind::WaitUntil => "wait_until", + ActorWorkKind::RegisteredTask => "registered_task", + ActorWorkKind::WebSocketCallback => "websocket_callback", + ActorWorkKind::DisconnectCallback => "disconnect_callback", + } + } +} + +/// Holds per-kind counters and task sets used by actor sleep and shutdown. pub(crate) struct WorkRegistry { + /// Counts user keep-awake regions that block idle sleep. pub(crate) keep_awake: Arc, + /// Counts runtime-owned keep-awake regions that block idle sleep. pub(crate) internal_keep_awake: Arc, + /// Counts async WebSocket callbacks currently running. pub(crate) websocket_callback: Arc, + /// Counts disconnect callbacks currently running. + pub(crate) disconnect_callback: Arc, + /// Counts work that must drain before sleep-grace runtime cleanup. pub(crate) shutdown_counter: Arc, + /// Counts lifecycle hooks dispatched by core into the actor runtime. pub(crate) core_dispatched_hooks: Arc, // Forced-sync: shutdown tasks are inserted from sync paths and moved out // before awaiting shutdown. + /// Detached shutdown work that can be aborted during final teardown. pub(crate) shutdown_tasks: Mutex>, + /// Detached shutdown work that must be joined even after the grace deadline. + pub(crate) unabortable_shutdown_tasks: Mutex>, #[cfg(feature = "wasm-runtime")] + /// Wasm-local shutdown tasks tracked by completion channel and abort handle. pub(crate) local_shutdown_tasks: Mutex>, + /// Wakes sleep waiters when core-owned idle blockers reach zero. pub(crate) idle_notify: Arc, /// Woken on every transition of a sleep-affecting counter that is not - /// otherwise guarded by `KeepAwakeGuard` / `WebSocketCallbackGuard` / - /// `DisconnectCallbackGuard`. In practice this covers externally-owned - /// counters like the envoy HTTP request counter whose increments happen - /// outside rivetkit-core. + /// otherwise guarded by `ActorWorkRegion`. In practice this covers + /// externally-owned counters like the envoy HTTP request counter whose + /// increments happen outside rivetkit-core. pub(crate) activity_notify: Arc, + /// Set once final teardown starts so new detached work is refused. pub(crate) teardown_started: AtomicBool, + /// Set when the grace deadline has elapsed and abortable work should be cancelled. pub(crate) shutdown_deadline_reached: AtomicBool, } @@ -36,9 +138,11 @@ pub(crate) struct WorkRegistry { pub(crate) struct LocalShutdownTask { pub(crate) abort_handle: AbortHandle, pub(crate) complete_rx: futures_oneshot::Receiver<()>, + pub(crate) aborts_at_shutdown_deadline: bool, } impl WorkRegistry { + /// Creates an empty registry and wires idle notifications for idle-blocking counters. pub(crate) fn new() -> Self { let idle_notify = Arc::new(Notify::new()); let keep_awake = Arc::new(AsyncCounter::new()); @@ -47,14 +151,18 @@ impl WorkRegistry { internal_keep_awake.register_zero_notify(&idle_notify); let websocket_callback = Arc::new(AsyncCounter::new()); websocket_callback.register_zero_notify(&idle_notify); + let disconnect_callback = Arc::new(AsyncCounter::new()); + disconnect_callback.register_zero_notify(&idle_notify); Self { keep_awake, internal_keep_awake, websocket_callback, + disconnect_callback, shutdown_counter: Arc::new(AsyncCounter::new()), core_dispatched_hooks: Arc::new(AsyncCounter::new()), shutdown_tasks: Mutex::new(JoinSet::new()), + unabortable_shutdown_tasks: Mutex::new(JoinSet::new()), #[cfg(feature = "wasm-runtime")] local_shutdown_tasks: Mutex::new(Vec::new()), idle_notify, @@ -64,17 +172,25 @@ impl WorkRegistry { } } + /// Starts a user keep-awake region. pub(crate) fn keep_awake_guard(&self) -> RegionGuard { RegionGuard::new(self.keep_awake.clone()) } + /// Starts a runtime-owned keep-awake region. pub(crate) fn internal_keep_awake_guard(&self) -> RegionGuard { RegionGuard::new(self.internal_keep_awake.clone()) } + /// Starts an async WebSocket callback region. pub(crate) fn websocket_callback_guard(&self) -> RegionGuard { RegionGuard::new(self.websocket_callback.clone()) } + + /// Starts a disconnect callback region. + pub(crate) fn disconnect_callback_guard(&self) -> RegionGuard { + RegionGuard::new(self.disconnect_callback.clone()) + } } impl Default for WorkRegistry { @@ -83,6 +199,7 @@ impl Default for WorkRegistry { } } +/// RAII guard that decrements an actor work counter when dropped. pub(crate) struct RegionGuard { counter: Arc, log_kind: Option<&'static str>, @@ -90,6 +207,7 @@ pub(crate) struct RegionGuard { } impl RegionGuard { + /// Increments a counter and returns a guard that will decrement it. fn new(counter: Arc) -> Self { counter.increment(); Self { @@ -99,6 +217,7 @@ impl RegionGuard { } } + /// Wraps a counter that has already been incremented. pub(crate) fn from_incremented(counter: Arc) -> Self { Self { counter, @@ -107,6 +226,7 @@ impl RegionGuard { } } + /// Enables paired debug logs for the lifetime of this guard. pub(crate) fn with_log_fields(mut self, kind: &'static str, actor_id: Option) -> Self { let count = self.counter.load(); match actor_id.as_deref() { diff --git a/rivetkit-rust/packages/rivetkit-core/src/lib.rs b/rivetkit-rust/packages/rivetkit-core/src/lib.rs index f8392f6137..76e4bce759 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/lib.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/lib.rs @@ -118,7 +118,7 @@ pub use actor::config::{ ActionDefinition, ActorConfig, ActorConfigInput, ActorConfigOverrides, CanHibernateWebSocket, }; pub use actor::connection::ConnHandle; -pub use actor::context::{ActorContext, KeepAwakeRegion, WebSocketCallbackRegion}; +pub use actor::context::{ActorContext, ActorWorkRegion, KeepAwakeRegion, WebSocketCallbackRegion}; pub use actor::factory::{ActorEntryFn, ActorFactory}; pub use actor::kv::Kv; pub use actor::lifecycle_hooks::{ActorEvents, ActorStart, Reply}; @@ -139,6 +139,7 @@ pub use actor::task::{ LifecycleEvent, LifecycleState, }; pub use actor::task_types::ShutdownKind; +pub use actor::work_registry::{ActorWorkKind, ActorWorkPolicy}; pub use error::ActorLifecycle; pub use inspector::{Inspector, InspectorSnapshot}; pub use registry::{CoreRegistry, ServeConfig}; diff --git a/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs b/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs index dea65ec9df..326bf8eef2 100644 --- a/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs +++ b/rivetkit-rust/packages/rivetkit-core/tests/sleep.rs @@ -3,6 +3,7 @@ mod moved_tests { use std::sync::atomic::{AtomicUsize, Ordering}; use crate::actor::context::ActorContext; + use crate::actor::work_registry::ActorWorkKind; use parking_lot::Mutex as DropMutex; use rivet_envoy_client::async_counter::AsyncCounter; use std::time::{Duration, Instant}; @@ -20,6 +21,7 @@ mod moved_tests { struct MessageVisitor { message: Option, actor_id: Option, + kind: Option, reason: Option, } @@ -28,6 +30,7 @@ mod moved_tests { match field.name() { "message" => self.message = Some(value.to_owned()), "actor_id" => self.actor_id = Some(value.to_owned()), + "kind" => self.kind = Some(value.to_owned()), "reason" => self.reason = Some(value.to_owned()), _ => {} } @@ -38,6 +41,7 @@ mod moved_tests { match field.name() { "message" => self.message = Some(value), "actor_id" => self.actor_id = Some(value), + "kind" => self.kind = Some(value), "reason" => self.reason = Some(value), _ => {} } @@ -84,8 +88,9 @@ mod moved_tests { let mut visitor = MessageVisitor::default(); event.record(&mut visitor); - if visitor.message.as_deref() == Some("registered task cancelled by shutdown deadline") + if visitor.message.as_deref() == Some("actor work cancelled by shutdown deadline") && visitor.actor_id.as_deref() == Some("actor-register-task-deadline") + && visitor.kind.as_deref() == Some("registered_task") && visitor.reason.as_deref() == Some("shutdown_deadline_elapsed") { self.count.fetch_add(1, Ordering::SeqCst); @@ -223,6 +228,88 @@ mod moved_tests { assert_eq!(warning_count.load(Ordering::SeqCst), 1); } + #[tokio::test(start_paused = true)] + async fn tracked_shutdown_work_drain_wakes_on_shutdown_counter_zero() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-drain-counter"); + ctx.notify_activity_dirty(); + let (release_tx, release_rx) = oneshot::channel(); + + ctx.track_shutdown_task(async move { + let _ = release_rx.await; + }); + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { ctx.wait_for_tracked_shutdown_work().await } + }); + + yield_now().await; + assert!( + !waiter.is_finished(), + "shutdown drain should wait while the counter is non-zero" + ); + + release_tx + .send(()) + .expect("release signal should send to tracked shutdown task"); + yield_now().await; + yield_now().await; + + assert!( + waiter.is_finished(), + "shutdown drain should wake from the counter zero notification" + ); + assert!(waiter.await.expect("shutdown drain waiter should join")); + } + + #[tokio::test(start_paused = true)] + async fn tracked_shutdown_work_drain_wakes_on_websocket_callback_zero() { + let ctx = ActorContext::new_for_sleep_tests("actor-shutdown-drain-websocket"); + ctx.notify_activity_dirty(); + let guard = ctx.websocket_callback_region(); + let waiter = tokio::spawn({ + let ctx = ctx.clone(); + async move { ctx.wait_for_tracked_shutdown_work().await } + }); + + yield_now().await; + assert!( + !waiter.is_finished(), + "shutdown drain should wait while the websocket callback is active" + ); + + drop(guard); + yield_now().await; + + assert!( + waiter.is_finished(), + "shutdown drain should wake from the websocket zero notification" + ); + assert!(waiter.await.expect("shutdown drain waiter should join")); + } + + #[tokio::test(start_paused = true)] + async fn keep_awake_spawned_work_exits_when_shutdown_deadline_cancels() { + let ctx = ActorContext::new_for_sleep_tests("actor-keep-awake-deadline"); + + ctx.spawn_work(ActorWorkKind::KeepAwake, futures::future::pending::<()>()); + assert_eq!(ctx.shutdown_task_count(), 1); + assert_eq!(ctx.sleep_keep_awake_count(), 1); + + ctx.cancel_shutdown_deadline(); + + assert!( + ctx.0 + .sleep + .work + .shutdown_counter + .wait_zero(Instant::now() + Duration::from_millis(1)) + .await, + "keepAwake work should stop waiting after the shutdown deadline" + ); + assert_eq!(ctx.shutdown_task_count(), 0); + assert_eq!(ctx.sleep_keep_awake_count(), 0); + } + #[tokio::test(start_paused = true)] async fn sleep_then_destroy_signal_tasks_do_not_leak_after_teardown() { let ctx = ActorContext::new_for_sleep_tests("actor-sleep-destroy"); diff --git a/rivetkit-typescript/packages/rivetkit-napi/index.d.ts b/rivetkit-typescript/packages/rivetkit-napi/index.d.ts index e89dc595d3..2cba08a1c7 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/index.d.ts +++ b/rivetkit-typescript/packages/rivetkit-napi/index.d.ts @@ -240,6 +240,7 @@ export declare class ActorContext { disconnectConns(predicate: (...args: any[]) => any): Promise broadcast(name: string, args: Buffer): void waitUntil(promise: Promise): void + waitForTrackedShutdownWork(): Promise registerTask(promise: Promise): void runtimeState(): object clearRuntimeState(): void diff --git a/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs b/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs index ac78a60867..f17746f989 100644 --- a/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs +++ b/rivetkit-typescript/packages/rivetkit-napi/src/actor_context.rs @@ -18,7 +18,7 @@ use napi_derive::napi; use parking_lot::Mutex; use rivetkit_core::types::ActorKeySegment; use rivetkit_core::{ - ActorContext as CoreActorContext, ConnHandle as CoreConnHandle, KeepAwakeRegion, + ActorContext as CoreActorContext, ActorWorkKind, ConnHandle as CoreConnHandle, KeepAwakeRegion, Request as CoreRequest, RequestSaveOpts, StateDelta, WebSocketCallbackRegion, }; use scc::HashMap as SccHashMap; @@ -480,9 +480,7 @@ impl ActorContext { #[napi] pub fn keep_awake(&self, promise: Promise) -> napi::Result<()> { - let region = self.inner.keep_awake_region(); - self.inner.wait_until(async move { - let _region = region; + self.inner.spawn_work(ActorWorkKind::KeepAwake, async move { if let Err(error) = promise.await { tracing::warn!(?error, "actor keep_awake promise rejected"); } @@ -609,6 +607,11 @@ impl ActorContext { Ok(()) } + #[napi] + pub async fn wait_for_tracked_shutdown_work(&self) -> bool { + self.inner.wait_for_tracked_shutdown_work().await + } + #[napi] pub fn register_task(&self, promise: Promise) -> napi::Result<()> { self.shared diff --git a/rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs b/rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs index 6411f75c47..66e8a166bd 100644 --- a/rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs +++ b/rivetkit-typescript/packages/rivetkit-wasm/src/lib.rs @@ -12,6 +12,7 @@ use rivetkit_core::error::public_error_status_code; use rivetkit_core::inspector::InspectorAuth; use rivetkit_core::{ ActorConfig, ActorConfigInput, ActorEvent, ActorFactory as CoreActorFactory, ActorStart, + ActorWorkKind, BindParam, ColumnValue, CoreRegistry as NativeCoreRegistry, CoreServerlessRuntime, EnqueueAndWaitOpts, KeepAwakeRegion, ListOpts, QueueMessage, QueueNextBatchOpts, QueueSendResult, QueueSendStatus, QueueTryNextBatchOpts, QueueWaitOpts, Request, @@ -1405,13 +1406,16 @@ impl WasmActorContext { }); } + #[wasm_bindgen(js_name = waitForTrackedShutdownWork)] + pub async fn wait_for_tracked_shutdown_work(&self) -> bool { + self.inner.wait_for_tracked_shutdown_work().await + } + #[wasm_bindgen(js_name = keepAwake)] pub fn keep_awake(&self, promise: Promise) { console_error("keepAwake binding is deprecated; use beginKeepAwake/endKeepAwake"); - let region = self.inner.keep_awake_region(); let actor_id = self.inner.actor_id().to_owned(); - self.inner.register_task(async move { - let _region = region; + self.inner.spawn_work(ActorWorkKind::KeepAwake, async move { if let Err(error) = JsFuture::from(promise).await { console_error(&format!( "actor keepAwake promise rejected for actor {actor_id}: {}", diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts index 832e1065ee..59fcadf4a2 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry-static.ts @@ -113,6 +113,8 @@ import { sleepRawWsDelayedSendOnSleep, sleepWithWaitUntilInOnWake, sleepAbortListenerVarsActor, + sleepWaitUntilVarsDuringGrace, + sleepRawWsVarsExceedsGrace, } from "./sleep"; import { sleepWithDb, @@ -210,6 +212,8 @@ export const registry = setup({ sleepRawWsDelayedSendOnSleep, sleepWithWaitUntilInOnWake, sleepAbortListenerVarsActor, + sleepWaitUntilVarsDuringGrace, + sleepRawWsVarsExceedsGrace, counterWaitUntilProbe, // From sleep-db.ts sleepWithDb, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts index 56fb3aa865..ab48d53829 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts @@ -618,3 +618,115 @@ export const sleepWithNoSleepOption = actor({ noSleep: true, }, }); + +export const WAIT_UNTIL_GRACE_DELAY = 150; +export const WAIT_UNTIL_GRACE_PERIOD = 1000; +export const WAIT_UNTIL_GRACE_SLEEP_TIMEOUT = 100; + +export const sleepWaitUntilVarsDuringGrace = actor({ + state: { + startCount: 0, + sleepCount: 0, + waitUntilStarted: 0, + }, + createVars: () => ({ + dirty: false, + }), + onWake: (c) => { + c.state.startCount += 1; + }, + onSleep: (c) => { + c.state.sleepCount += 1; + c.waitUntil( + (async () => { + c.state.waitUntilStarted += 1; + await new Promise((resolve) => + setTimeout(resolve, WAIT_UNTIL_GRACE_DELAY), + ); + c.vars.dirty = true; + })(), + ); + }, + actions: { + triggerSleep: (c) => { + c.sleep(); + }, + getStatus: (c) => ({ + startCount: c.state.startCount, + sleepCount: c.state.sleepCount, + waitUntilStarted: c.state.waitUntilStarted, + }), + }, + options: { + sleepTimeout: WAIT_UNTIL_GRACE_SLEEP_TIMEOUT, + sleepGracePeriod: WAIT_UNTIL_GRACE_PERIOD, + }, +}); + +// Reproduces a production crash where c.vars becomes undefined after the +// grace deadline expires and clearNativeRuntimeState unrefs the NAPI +// runtime state object. An async message handler accesses c.vars after an +// await that outlasts the grace period. +// +// The close-handler variant cannot reproduce the bug because the tracked +// websocket callback region blocks can_arm_sleep_timer. Instead we use a +// message handler that starts slow async work, then the actor is told to +// sleep programmatically while the handler is still running. +export const VARS_EXCEEDS_GRACE_DELAY = 2000; +export const VARS_EXCEEDS_GRACE_PERIOD = 200; +export const VARS_EXCEEDS_GRACE_SLEEP_TIMEOUT = 100; + +export const sleepRawWsVarsExceedsGrace = actor({ + state: { + startCount: 0, + sleepCount: 0, + handlerStarted: 0, + handlerFinished: 0, + }, + createVars: () => ({ + dirty: false, + }), + onWake: (c) => { + c.state.startCount += 1; + }, + onSleep: (c) => { + c.state.sleepCount += 1; + }, + onWebSocket: (c, websocket: UniversalWebSocket) => { + websocket.addEventListener("message", async (event: any) => { + if (event.data !== "slow-vars-work") return; + + c.state.handlerStarted += 1; + websocket.send(JSON.stringify({ type: "started" })); + + // Wait longer than the grace period so the runtime state + // gets cleared while this handler is still running. + await new Promise((resolve) => + setTimeout(resolve, VARS_EXCEEDS_GRACE_DELAY), + ); + // This c.vars access crashes with TypeError in prod because + // the NAPI runtime state reference has been unreffed. + // Do NOT wrap in try/catch: c.state also breaks after cleanup, + // so the error needs to propagate to the process level. + c.vars.dirty = true; + c.state.handlerFinished += 1; + }); + + websocket.send(JSON.stringify({ type: "connected" })); + }, + actions: { + triggerSleep: (c) => { + c.sleep(); + }, + getStatus: (c) => ({ + startCount: c.state.startCount, + sleepCount: c.state.sleepCount, + handlerStarted: c.state.handlerStarted, + handlerFinished: c.state.handlerFinished, + }), + }, + options: { + sleepTimeout: VARS_EXCEEDS_GRACE_SLEEP_TIMEOUT, + sleepGracePeriod: VARS_EXCEEDS_GRACE_PERIOD, + }, +}); diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/napi-runtime.ts b/rivetkit-typescript/packages/rivetkit/src/registry/napi-runtime.ts index 06b3748342..f1b8890a34 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/napi-runtime.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/napi-runtime.ts @@ -414,6 +414,12 @@ export class NapiCoreRuntime implements CoreRuntime { asNativeActorContext(ctx).waitUntil(promise); } + async actorWaitForTrackedShutdownWork( + ctx: ActorContextHandle, + ): Promise { + return await asNativeActorContext(ctx).waitForTrackedShutdownWork(); + } + actorKeepAwake(ctx: ActorContextHandle, promise: Promise): void { asNativeActorContext(ctx).keepAwake(promise); } diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts index 633bddd9df..d00963729f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/native.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/native.ts @@ -269,9 +269,6 @@ type NativeDatabaseClientState = { type NativeActorRuntimeState = { sql?: ReturnType; databaseClient?: NativeDatabaseClientState; - keepAwakeCount?: number; - deferSleepCleanupUntilKeepAwakeIdle?: boolean; - deferredSleepCleanupActorCtx?: ActorContextHandleAdapter; varsInitialized?: boolean; vars?: unknown; destroyGate?: NativeDestroyGate; @@ -424,26 +421,12 @@ async function cleanupNativeSleepRuntimeState( runtime: CoreRuntime, ctx: ActorContextHandle, ): Promise { + await runtime.actorWaitForTrackedShutdownWork(ctx); await closeNativeDatabaseClient(runtime, ctx); await closeNativeSqlDatabase(runtime, ctx); clearNativeRuntimeState(runtime, ctx); } -async function cleanupDeferredNativeSleepRuntimeState( - runtime: CoreRuntime, - ctx: ActorContextHandle, - runtimeState: NativeActorRuntimeState, -): Promise { - runtimeState.deferSleepCleanupUntilKeepAwakeIdle = false; - const actorCtx = runtimeState.deferredSleepCleanupActorCtx; - runtimeState.deferredSleepCleanupActorCtx = undefined; - try { - await cleanupNativeSleepRuntimeState(runtime, ctx); - } finally { - await actorCtx?.dispose(); - } -} - function closeNativeSqlDatabase( runtime: CoreRuntime, ctx: ActorContextHandle, @@ -2807,11 +2790,6 @@ export class ActorContextHandleAdapter { } keepAwake(promise: Promise): Promise { - const runtimeState = getNativeRuntimeState(this.#runtime, this.#ctx); - // Increment before native registration so sleep cleanup observes JS work - // even if the promise settles immediately. - runtimeState.keepAwakeCount = (runtimeState.keepAwakeCount ?? 0) + 1; - let registered = false; const trackedPromise = Promise.resolve(promise) .catch((error) => { logger().warn({ @@ -2819,36 +2797,12 @@ export class ActorContextHandleAdapter { error: stringifyError(error), }); }) - .finally(async () => { - if (!registered) { - return; - } - runtimeState.keepAwakeCount = Math.max( - (runtimeState.keepAwakeCount ?? 1) - 1, - 0, - ); - if ( - runtimeState.keepAwakeCount === 0 && - runtimeState.deferSleepCleanupUntilKeepAwakeIdle - ) { - await cleanupDeferredNativeSleepRuntimeState( - this.#runtime, - this.#ctx, - runtimeState, - ); - } - }) .then(() => null); try { callNativeSync(() => this.#runtime.actorKeepAwake(this.#ctx, trackedPromise), ); - registered = true; } catch (error) { - runtimeState.keepAwakeCount = Math.max( - (runtimeState.keepAwakeCount ?? 1) - 1, - 0, - ); if (!isClosedTaskRegistrationError(error)) { throw error; } @@ -4013,12 +3967,9 @@ export function buildNativeFactory( } } } finally { - const runtimeState = getNativeRuntimeState(runtime, ctx); - if ((runtimeState.keepAwakeCount ?? 0) > 0) { - runtimeState.deferSleepCleanupUntilKeepAwakeIdle = true; - runtimeState.deferredSleepCleanupActorCtx = actorCtx; - } else { + try { await cleanupNativeSleepRuntimeState(runtime, ctx); + } finally { await actorCtx.dispose(); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/runtime.ts b/rivetkit-typescript/packages/rivetkit/src/registry/runtime.ts index 46cfa4bbfe..0e4875e966 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/runtime.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/runtime.ts @@ -388,6 +388,7 @@ export interface CoreRuntime { args: RuntimeBytes, ): void; actorWaitUntil(ctx: ActorContextHandle, promise: Promise): void; + actorWaitForTrackedShutdownWork(ctx: ActorContextHandle): Promise; actorKeepAwake(ctx: ActorContextHandle, promise: Promise): void; actorBeginKeepAwake(ctx: ActorContextHandle): number; actorEndKeepAwake(ctx: ActorContextHandle, regionId: number): void; diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/wasm-runtime.ts b/rivetkit-typescript/packages/rivetkit/src/registry/wasm-runtime.ts index b970ee22ee..dbb9028dbf 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/wasm-runtime.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/wasm-runtime.ts @@ -499,6 +499,15 @@ export class WasmCoreRuntime implements CoreRuntime { callHandle(asWasmActorContext(ctx), "waitUntil", promise); } + async actorWaitForTrackedShutdownWork( + ctx: ActorContextHandle, + ): Promise { + return await callHandle>( + asWasmActorContext(ctx), + "waitForTrackedShutdownWork", + ); + } + actorKeepAwake(ctx: ActorContextHandle, promise: Promise): void { const wasmCtx = asWasmActorContext(ctx); const regionId = callHandle(wasmCtx, "beginKeepAwake"); diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver/actor-sleep.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver/actor-sleep.test.ts index 28fc0a76d6..49aa57bab8 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/driver/actor-sleep.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/driver/actor-sleep.test.ts @@ -5,6 +5,10 @@ import { RAW_WS_HANDLER_DELAY, RAW_WS_HANDLER_SLEEP_TIMEOUT, SLEEP_TIMEOUT, + VARS_EXCEEDS_GRACE_DELAY, + VARS_EXCEEDS_GRACE_SLEEP_TIMEOUT, + WAIT_UNTIL_GRACE_DELAY, + WAIT_UNTIL_GRACE_SLEEP_TIMEOUT, } from "../../fixtures/driver-test-suite/sleep"; import { describeDriverMatrix } from "./shared-matrix"; import { setupDriverTest, waitFor } from "./shared-utils"; @@ -962,5 +966,104 @@ describeDriverMatrix("Actor Sleep", (driverTestConfig) => { expect(startCount).toBe(2); } }); + + test( + "waitUntil in onSleep keeps c.vars available during grace", + async (c) => { + const { client, getRuntimeOutput } = await setupDriverTest( + c, + driverTestConfig, + ); + + const actor = + client.sleepWaitUntilVarsDuringGrace.getOrCreate([ + "waituntil-vars-during-grace", + ]); + + await actor.triggerSleep(); + await waitFor( + driverTestConfig, + WAIT_UNTIL_GRACE_DELAY + + WAIT_UNTIL_GRACE_SLEEP_TIMEOUT + + 500, + ); + + const status = await actor.getStatus(); + expect(status.sleepCount).toBe(1); + expect(status.waitUntilStarted).toBe(1); + const output = getRuntimeOutput(); + expect(output).not.toContain( + "Cannot set properties of undefined", + ); + expect(output).not.toContain( + "Cannot read properties of undefined", + ); + }, + { timeout: 10_000 }, + ); + + test( + "c.vars access in ws handler should not crash after grace deadline", + async (c) => { + const { client, getRuntimeOutput } = await setupDriverTest( + c, + driverTestConfig, + ); + + const actor = + client.sleepRawWsVarsExceedsGrace.getOrCreate([ + "ws-vars-exceeds-grace", + ]); + const ws = await connectRawWebSocket(actor); + + // Send a message that starts slow async work (2000ms delay + // before accessing c.vars). + ws.send("slow-vars-work"); + + // Wait for the handler to confirm it started. + await new Promise((resolve) => { + const onMessage = (event: MessageEvent) => { + const data = JSON.parse(String(event.data)); + if (data.type === "started") { + ws.removeEventListener("message", onMessage); + resolve(); + } + }; + ws.addEventListener("message", onMessage); + }); + + // Trigger sleep while the handler is still doing slow work. + // The grace period (200ms) is much shorter than the handler + // delay (2000ms), so onSleep will clear the runtime state + // while the handler is still running. + await actor.triggerSleep(); + + // Wait for the handler to finish and the actor to complete + // its sleep cycle. + await waitFor( + driverTestConfig, + VARS_EXCEEDS_GRACE_DELAY + + VARS_EXCEEDS_GRACE_SLEEP_TIMEOUT + + 500, + ); + + // Wake the actor and check what happened. + const status = await actor.getStatus(); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); + expect(status.handlerStarted).toBe(1); + // The runtime must not crash with TypeError when the + // handler accesses c.vars after the grace deadline. + // Core-tracked shutdown work keeps the runtime state alive + // until the websocket callback work drains. + const output = getRuntimeOutput(); + expect(output).not.toContain( + "Cannot set properties of undefined", + ); + expect(output).not.toContain( + "Cannot read properties of undefined", + ); + }, + { timeout: 15_000 }, + ); }); });