diff --git a/src/pool.rs b/src/pool.rs
index 7a09a3e..20987b7 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -286,7 +286,7 @@ impl<Conn: Connection> PoolInner<Conn> {
                 request = self.rx.recv() => {
                     match request {
                         Some(Request::Claim { id, tx }) => {
-                            self.claim_or_enqueue(id, tx).await
+                            self.claim_or_enqueue(id, tx)
                         }
                         // The caller has explicitly asked us to terminate, and
                         // we should respond to them once we've stopped doing
@@ -326,28 +326,28 @@ impl<Conn: Connection> PoolInner<Conn> {
                 // Periodically rebalance the allocation of slots to backends
                 _ = rebalance_interval.tick() => {
                     event!(Level::INFO, "Rebalancing: timer tick");
-                    self.rebalance().await;
+                    self.rebalance();
                 }
                 // If any of the slots change state, update their allocations.
                 Some((name, status)) = &mut backend_status_stream.next(), if !backend_status_stream.is_empty() => {
                     event!(Level::INFO, name = ?name, status = ?status, "Rebalancing: Backend has new status");
                     rebalance_interval.reset();
-                    self.rebalance().await;
+                    self.rebalance();
                     if matches!(status, slot::SetState::Online { has_unclaimed_slots: true }) {
-                        self.try_claim_from_queue().await;
+                        self.try_claim_from_queue();
                     }
                 },
             }
         }
     }
 
-    async fn claim_or_enqueue(
+    fn claim_or_enqueue(
         &mut self,
         id: ClaimId,
         tx: oneshot::Sender<Result<claim::Handle<Conn>, Error>>,
     ) {
-        let result = self.claim(id).await;
+        let result = self.claim(id);
         if result.is_ok() {
             let _ = tx.send(result);
             return;
@@ -364,13 +364,13 @@ impl<Conn: Connection> PoolInner<Conn> {
         });
     }
 
-    async fn try_claim_from_queue(&mut self) {
+    fn try_claim_from_queue(&mut self) {
         loop {
             let Some(request) = self.request_queue.pop_front() else {
                 return;
             };
 
-            let result = self.claim(request.id).await;
+            let result = self.claim(request.id);
             if result.is_ok() {
                 let _ = request.tx.send(result);
             } else {
@@ -394,16 +394,16 @@ impl<Conn: Connection> PoolInner<Conn> {
     }
 
     #[instrument(skip(self), name = "PoolInner::rebalance")]
-    async fn rebalance(&mut self) {
+    fn rebalance(&mut self) {
         #[cfg(feature = "probes")]
         probes::rebalance__start!(|| self.name.as_str());
 
-        self.rebalance_inner().await;
+        self.rebalance_inner();
 
         #[cfg(feature = "probes")]
         probes::rebalance__done!(|| self.name.as_str());
     }
 
-    async fn rebalance_inner(&mut self) {
+    fn rebalance_inner(&mut self) {
         let mut questionable_backend_count = 0;
         let mut usable_backends = vec![];
 
@@ -412,7 +412,7 @@ impl<Conn: Connection> PoolInner<Conn> {
         for (name, slot_set) in iter {
             match slot_set.get_state() {
                 slot::SetState::Offline => {
-                    let _ = slot_set.set_wanted_count(1).await;
+                    let _ = slot_set.set_wanted_count(1);
                     questionable_backend_count += 1;
                 }
                 slot::SetState::Online { .. } => {
@@ -442,7 +442,7 @@ impl<Conn: Connection> PoolInner<Conn> {
             let Some(slot_set) = self.slots.get_mut(&name) else {
                 continue;
             };
-            let _ = slot_set.set_wanted_count(slots_wanted_per_backend).await;
+            let _ = slot_set.set_wanted_count(slots_wanted_per_backend);
         }
 
         let mut new_priority_list = PriorityList::new();
@@ -472,7 +472,7 @@ impl<Conn: Connection> PoolInner<Conn> {
         self.priority_list = new_priority_list;
     }
 
-    async fn claim(&mut self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
+    fn claim(&mut self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
         let mut attempted_backend = vec![];
         let mut result = Err(Error::NoBackends);
 
@@ -504,7 +504,7 @@ impl<Conn: Connection> PoolInner<Conn> {
             //
             // Either way, put this backend back in the priority list after
             // we're done with it.
-            let Ok(claim) = set.claim(id).await else {
+            let Ok(claim) = set.claim(id) else {
                 event!(Level::DEBUG, "Failed to actually get claim for backend");
                 rebalancer::claimed_err(&mut weighted_backend);
                 attempted_backend.push(weighted_backend);
@@ -525,7 +525,7 @@ impl<Conn: Connection> PoolInner<Conn> {
             Err(_) => probes::pool__claim__failed!(|| (self.name.as_str(), id.0)),
         }
 
-        self.priority_list.extend(attempted_backend.into_iter());
+        self.priority_list.extend(attempted_backend);
         result
     }
 }
diff --git a/src/slot.rs b/src/slot.rs
index 6e33f87..9ae845d 100644
--- a/src/slot.rs
+++ b/src/slot.rs
@@ -14,7 +14,7 @@ use derive_where::derive_where;
 use std::collections::BTreeMap;
 use std::sync::{Arc, Mutex};
 use thiserror::Error;
-use tokio::sync::{mpsc, oneshot, watch, Notify};
+use tokio::sync::{mpsc, watch, Notify};
 use tokio::task::JoinHandle;
 use tokio::time::{interval, Duration};
 use tracing::{event, instrument, span, Instrument, Level};
@@ -304,7 +304,13 @@ impl<Conn: Connection> Slot<Conn> {
                 return true;
             }
             Err(err) => {
-                event!(Level::WARN, pool_name = self.inner.pool_name.as_str(), ?err, ?backend, "Failed to connect");
+                event!(
+                    Level::WARN,
+                    pool_name = self.inner.pool_name.as_str(),
+                    ?err,
+                    ?backend,
+                    "Failed to connect"
+                );
                 self.inner.failure_window.add(1);
                 retry_duration =
                     retry_duration.exponential_backoff(config.max_connection_backoff);
@@ -642,102 +648,238 @@ impl Stats {
     }
 }
 
-enum SetRequest<Conn: Connection> {
-    Claim {
-        id: ClaimId,
-        tx: oneshot::Sender<Result<claim::Handle<Conn>, Error>>,
-    },
-    SetWantedCount {
-        count: usize,
-    },
-}
-
-// Owns and runs work on behalf of a [Set].
-struct SetWorker<Conn: Connection> {
+// Provides direct access to all underlying slots
+//
+// Shared by both a [`SetWorker`] and [`Set`]
+struct Slots<Conn: Connection> {
     pool_name: pool::Name,
     name: backend::Name,
     backend: Backend,
-    config: SetConfig,
-
-    wanted_count: usize,
-
     // Interface for actually connecting to backends
     backend_connector: SharedConnector<Conn>,
+    config: SetConfig,
 
-    // Interface for receiving client requests
-    rx: mpsc::Receiver<SetRequest<Conn>>,
+    // The actual underlying slots, by ID.
+    slots: BTreeMap<SlotId, Slot<Conn>>,
+    // The desired number of slots
+    wanted_count: usize,
+    next_slot_id: SlotId,
 
-    // Identifies that the set worker should terminate immediately
-    terminate_rx: tokio::sync::oneshot::Receiver<()>,
+    // If "true", new requests are rejected
+    terminating: bool,
 
     // Interface for communicating backend status
     status_tx: watch::Sender<SetState>,
-
-    // Sender and receiver for returning old handles.
-    //
-    // This is to guarantee a size, and to vend out permits to claim::Handles so they can be sure
-    // that their connections can return to the set without error.
+    // Sender for returning old handles.
     slot_tx: mpsc::Sender<BorrowedConnection<Conn>>,
-    slot_rx: mpsc::Receiver<BorrowedConnection<Conn>>,
-
-    // The actual slots themselves.
-    slots: BTreeMap<SlotId, Slot<Conn>>,
 
     // Summary information about the health of all slots.
     //
     // Must be kept in lockstep with "Self::slots"
     stats: Arc<Mutex<Stats>>,
     failure_window: Arc<WindowedCounter>,
-
-    next_slot_id: SlotId,
 }
 
-impl<Conn: Connection> SetWorker<Conn> {
-    #[allow(clippy::too_many_arguments)]
-    fn new(
-        pool_name: pool::Name,
-        set_id: u16,
-        name: backend::Name,
-        rx: mpsc::Receiver<SetRequest<Conn>>,
-        terminate_rx: tokio::sync::oneshot::Receiver<()>,
-        status_tx: watch::Sender<SetState>,
-        config: SetConfig,
-        wanted_count: usize,
-        backend: Backend,
-        backend_connector: SharedConnector<Conn>,
-        stats: Arc<Mutex<Stats>>,
-        failure_window: Arc<WindowedCounter>,
-    ) -> Self {
-        let (slot_tx, slot_rx) = mpsc::channel(config.max_count);
-        let mut set = Self {
-            pool_name,
-            name,
-            backend,
-            config,
-            wanted_count,
-            backend_connector,
-            stats,
-            failure_window,
-            rx,
-            terminate_rx,
-            status_tx,
-            slot_tx,
-            slot_rx,
-            slots: BTreeMap::new(),
-            next_slot_id: SlotId::first(set_id),
+impl<Conn: Connection> Slots<Conn> {
+    #[instrument(
+        level = "trace",
+        skip(self),
+        err,
+        name = "Slots::claim",
+        fields(name = ?self.name),
+    )]
+    fn claim(&self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
+        #[cfg(feature = "probes")]
+        probes::slot__set__claim__start!(|| (
+            self.pool_name.as_str(),
+            id.0,
+            self.backend.address.to_string()
+        ));
+
+        // Before we vend out the slot's connection to a client, make sure that
+        // we have space to take it back once they're done with it.
+        let Ok(permit) = self.slot_tx.clone().try_reserve_owned() else {
+            event!(Level::TRACE, "Could not reserve slot_tx permit");
+
+            #[cfg(feature = "probes")]
+            probes::slot__set__claim__failed!(|| (
+                self.pool_name.as_str(),
+                id.0,
+                "Could not reserve slot_tx permit; all slots used"
+            ));
+
+            // This is more of an "all slots in-use" error,
+            // but it should look the same to clients.
+            return Err(Error::NoSlotsReady);
         };
-        set.set_wanted_count(wanted_count);
-        set
+
+        let Some(handle) = self.take_connected_unclaimed_slot(permit, id) else {
+            event!(Level::TRACE, "Failed to take unclaimed slot");
+
+            #[cfg(feature = "probes")]
+            probes::slot__set__claim__failed!(|| (
+                self.pool_name.as_str(),
+                id.0,
+                "No unclaimed slots"
+            ));
+            return Err(Error::NoSlotsReady);
+        };
+
+        #[cfg(feature = "probes")]
+        probes::slot__set__claim__done!(|| (
+            self.pool_name.as_str(),
+            id.0,
+            handle.slot_id().as_u64()
+        ));
+
+        return Ok(handle);
     }
 
-    // Creates a new Slot, which always starts as "Connecting", and spawn a task
-    // to actually connect to the backend and monitor slot health.
+    // Borrows a connection out of the first unclaimed slot.
+    //
+    // Returns a Handle which has enough context to put the claim back,
+    // once it's dropped by the client.
     #[instrument(
-        skip(self),
+        skip(self, permit),
         fields(pool_name = %self.pool_name),
-        name = "SetWorker::create_slot"
+        name = "Slots::take_connected_unclaimed_slot"
     )]
+    fn take_connected_unclaimed_slot(
+        &self,
+        permit: mpsc::OwnedPermit<BorrowedConnection<Conn>>,
+        claim_id: ClaimId,
+    ) -> Option<claim::Handle<Conn>> {
+        for (id, slot) in &self.slots {
+            let guarded = slot.inner.guarded.lock().unwrap();
+            event!(Level::TRACE, slot_id = id.as_u64(), state = ?guarded.state, "Considering slot");
+            if matches!(guarded.state, State::ConnectedUnclaimed(_)) {
+                event!(Level::TRACE, slot_id = id.as_u64(), "Found unclaimed slot");
+                // We intentionally "take the connection out" of the slot and
+                // "place it into a claim::Handle" in the same method.
+                //
+                // This makes it difficult to leak a connection, unless the drop
+                // method of the claim::Handle is skipped.
+                let State::ConnectedUnclaimed(DebugIgnore(conn)) = slot.inner.state_transition(
+                    *id,
+                    &self.backend,
+                    guarded,
+                    State::ConnectedClaimed,
+                ) else {
+                    panic!(
+                        "We just matched this type before replacing it; this should be impossible"
+                    );
+                };
+
+                let borrowed_conn = BorrowedConnection::new(conn, *id);
+
+                #[cfg(feature = "probes")]
+                crate::probes::handle__claimed!(|| {
+                    (
+                        self.pool_name.as_str(),
+                        claim_id.0,
+                        borrowed_conn.id.as_u64(),
+                        self.backend.address.to_string(),
+                    )
+                });
+
+                // The "drop" method of the claim::Handle will return it to
+                // the slot set, through the permit (which is connected to
+                // slot_rx).
+                return Some(claim::Handle::new(borrowed_conn, permit));
+            }
+        }
+        None
+    }
+
+    fn set_wanted_count(&mut self, count: usize) {
+        self.wanted_count = std::cmp::min(count, self.config.max_count);
+        self.conform_slot_count();
+    }
+
+    fn conform_slot_count(&mut self) {
+        let desired = self.wanted_count;
+
+        use std::cmp::Ordering::*;
+        match desired.cmp(&self.slots.len()) {
+            Less => {
+                // Fewer slots wanted. Remove as many as we can.
+                event!(
+                    Level::TRACE,
+                    current = self.slots.len(),
+                    "Reducing slot count"
+                );
+
+                // Gather all the keys we are trying to remove.
+                //
+                // If there are many non-removable slots, it's possible
+                // we don't immediately quiesce to this smaller requested count.
+                let count_to_remove = self.slots.len() - desired;
+                let mut to_remove = Vec::with_capacity(count_to_remove);
+
+                // We iterate through all slots twice:
+                // - First, we try to remove unconnected slots
+                // - Then we remove any removable slots remaining
+                //
+                // This is a minor optimization that avoids tearing down a
+                // connected spare while also trying to create a new one.
+                let filters = [
+                    |slot: &SlotInnerGuarded<Conn>| {
+                        !slot.state.connected() && slot.state.removable()
+                    },
+                    |slot: &SlotInnerGuarded<Conn>| slot.state.removable(),
+                ];
+                for filter in filters {
+                    for (slot_id, slot) in &self.slots {
+                        if to_remove.len() >= count_to_remove {
+                            break;
+                        }
+                        let guarded = slot.inner.guarded.lock().unwrap();
+                        if filter(&*guarded) {
+                            to_remove.push(*slot_id);
+
+                            // It's important that we terminate the slot
+                            // immediately, so the task which manages the slot
+                            // will not continue modifying the state.
+                            slot.inner.state_transition(
+                                *slot_id,
+                                &self.backend,
+                                guarded,
+                                State::Terminated,
+                            );
+                        }
+                    }
+                }
+
+                for slot_id in to_remove {
+                    event!(Level::TRACE, slot_id = slot_id.as_u64(), "Removing slot");
+                    let Some(slot) = self.slots.remove(&slot_id) else {
+                        continue;
+                    };
+                    let Some(handle) = slot.inner.guarded.lock().unwrap().handle.take() else {
+                        continue;
+                    };
+                    event!(Level::TRACE, slot_id = slot_id.as_u64(), "Aborting task");
+                    handle.abort();
+                }
+            }
+            Greater => {
+                // More slots wanted. This case is easy, we can always fill
+                // in "connecting" slots immediately.
+                event!(
+                    Level::TRACE,
+                    current = self.slots.len(),
+                    "Increasing slot count"
+                );
+                let new_slots = (desired - self.slots.len()) as u64;
+                for slot_id in self.next_slot_id.0..self.next_slot_id.0 + new_slots {
+                    self.create_slot(SlotId(slot_id));
+                }
+                self.next_slot_id.0 += new_slots;
+            }
+            Equal => {}
+        }
+    }
+
     fn create_slot(&mut self, slot_id: SlotId) {
         let (terminate_tx, mut terminate_rx) = tokio::sync::oneshot::channel();
         let slot = Slot {
@@ -841,7 +983,11 @@ impl<Conn: Connection> SetWorker<Conn> {
                 _ = &mut terminate_rx => {
                     // If we've been instructed to bail out,
                     // do that immediately.
-                    event!(Level::TRACE, slot_id = slot_id.as_u64(), "Terminating while monitoring");
+                    event!(
+                        Level::TRACE,
+                        slot_id = slot_id.as_u64(),
+                        "Terminating while monitoring"
+                    );
                     return;
                 },
                 _ = interval.tick() => {
@@ -870,61 +1016,68 @@ impl<Conn: Connection> SetWorker<Conn> {
             }
         }));
     }
+}
 
-    // Borrows a connection out of the first unclaimed slot.
-    //
-    // Returns a Handle which has enough context to put the claim back,
-    // once it's dropped by the client.
-    #[instrument(
-        skip(self, permit),
-        fields(pool_name = %self.pool_name),
-        name = "SetWorker::take_connected_unclaimed_slot"
-    )]
-    fn take_connected_unclaimed_slot(
-        &mut self,
-        permit: mpsc::OwnedPermit<BorrowedConnection<Conn>>,
-        claim_id: ClaimId,
-    ) -> Option<claim::Handle<Conn>> {
-        for (id, slot) in &mut self.slots {
-            let guarded = slot.inner.guarded.lock().unwrap();
-            event!(Level::TRACE, slot_id = id.as_u64(), state = ?guarded.state, "Considering slot");
-            if matches!(guarded.state, State::ConnectedUnclaimed(_)) {
-                event!(Level::TRACE, slot_id = id.as_u64(), "Found unclaimed slot");
-                // We intentionally "take the connection out" of the slot and
-                // "place it into a claim::Handle" in the same method.
-                //
-                // This makes it difficult to leak a connection, unless the drop
-                // method of the claim::Handle is skipped.
-                let State::ConnectedUnclaimed(DebugIgnore(conn)) = slot.inner.state_transition(
-                    *id,
-                    &self.backend,
-                    guarded,
-                    State::ConnectedClaimed,
-                ) else {
-                    panic!(
-                        "We just matched this type before replacing it; this should be impossible"
-                    );
-                };
-
-                let borrowed_conn = BorrowedConnection::new(conn, *id);
+// Owns and runs work on behalf of a [Set].
+struct SetWorker<Conn: Connection> {
+    pool_name: pool::Name,
+    name: backend::Name,
+    backend: Backend,
 
-                #[cfg(feature = "probes")]
-                crate::probes::handle__claimed!(|| {
-                    (
-                        self.pool_name.as_str(),
-                        claim_id.0,
-                        borrowed_conn.id.as_u64(),
-                        self.backend.address.to_string(),
-                    )
-                });
+    // Identifies that the set worker should terminate immediately
+    terminate_rx: tokio::sync::oneshot::Receiver<()>,
 
-                // The "drop" method of the claim::Handle will return it to
-                // the slot set, through the permit (which is connected to
-                // slot_rx).
-                return Some(claim::Handle::new(borrowed_conn, permit));
-            }
-        }
-        None
+    // Sender and receiver for returning old handles.
+    //
+    // This is to guarantee a size, and to vend out permits to claim::Handles so
+    // they can be sure that their connections can return to the set without
+    // error.
+    slot_rx: mpsc::Receiver<BorrowedConnection<Conn>>,
+
+    // The actual slots themselves.
+    slots: Arc<Mutex<Slots<Conn>>>,
+}
+
+impl<Conn: Connection> SetWorker<Conn> {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        pool_name: pool::Name,
+        set_id: u16,
+        name: backend::Name,
+        terminate_rx: tokio::sync::oneshot::Receiver<()>,
+        status_tx: watch::Sender<SetState>,
+        config: SetConfig,
+        wanted_count: usize,
+        backend: Backend,
+        backend_connector: SharedConnector<Conn>,
+        stats: Arc<Mutex<Stats>>,
+        failure_window: Arc<WindowedCounter>,
+    ) -> Self {
+        let (slot_tx, slot_rx) = mpsc::channel(config.max_count);
+        let set = Self {
+            pool_name: pool_name.clone(),
+            name: name.clone(),
+            backend: backend.clone(),
+            terminate_rx,
+            slot_rx,
+            slots: Arc::new(Mutex::new(Slots {
+                pool_name,
+                name,
+                backend,
+                backend_connector,
+                config,
+                slots: BTreeMap::new(),
+                wanted_count,
+                next_slot_id: SlotId::first(set_id),
+                terminating: false,
+                status_tx,
+                slot_tx,
+                stats,
+                failure_window,
+            })),
+        };
+        set.set_wanted_count(wanted_count);
+        set
     }
 
     // Takes back borrowed slots from clients who dropped their claim handles.
@@ -937,11 +1090,13 @@ impl<Conn: Connection> SetWorker<Conn> {
         ),
         name = "SetWorker::recycle_connection"
     )]
-    fn recycle_connection(&mut self, borrowed_conn: BorrowedConnection<Conn>) {
+    fn recycle_connection(&self, borrowed_conn: BorrowedConnection<Conn>) {
         let slot_id = borrowed_conn.id;
         #[cfg(feature = "probes")]
         crate::probes::handle__returned!(|| (self.pool_name.as_str(), slot_id.as_u64()));
-        let inner = self
+
+        let mut slots = self.slots.lock().unwrap();
+        let inner = slots
             .slots
             .get_mut(&slot_id)
             .expect(
@@ -969,178 +1124,14 @@ impl<Conn: Connection> SetWorker<Conn> {
         // If we tried to shrink the slot count while too many connections were
         // in-use, it's possible there's more work to do. Try to conform the
         // slot count after recycling each connection.
-        self.conform_slot_count();
+        slots.conform_slot_count();
 
         inner.recycling_needed.notify_one();
     }
 
-    fn set_wanted_count(&mut self, count: usize) {
-        self.wanted_count = std::cmp::min(count, self.config.max_count);
-        self.conform_slot_count();
-    }
-
-    // Makes the number of slots as close to "desired_count" as we can get.
-    #[instrument(
-        level = "trace",
-        skip(self),
-        fields(
-            wanted_count = self.wanted_count,
-            name = ?self.name,
-        ),
-        name = "SetWorker::conform_slot_count"
-    )]
-    fn conform_slot_count(&mut self) {
-        let desired = self.wanted_count;
-
-        use std::cmp::Ordering::*;
-        match desired.cmp(&self.slots.len()) {
-            Less => {
-                // Fewer slots wanted. Remove as many as we can.
-                event!(
-                    Level::TRACE,
-                    current = self.slots.len(),
-                    "Reducing slot count"
-                );
-
-                // Gather all the keys we are trying to remove.
-                //
-                // If there are many non-removable slots, it's possible
-                // we don't immediately quiesce to this smaller requested count.
-                let count_to_remove = self.slots.len() - desired;
-                let mut to_remove = Vec::with_capacity(count_to_remove);
-
-                // We iterate through all slots twice:
-                // - First, we try to remove unconnected slots
-                // - Then we remove any removable slots remaining
-                //
-                // This is a minor optimization that avoids tearing down a
-                // connected spare while also trying to create a new one.
-                let filters = [
-                    |slot: &SlotInnerGuarded<Conn>| {
-                        !slot.state.connected() && slot.state.removable()
-                    },
-                    |slot: &SlotInnerGuarded<Conn>| slot.state.removable(),
-                ];
-                for filter in filters {
-                    for (slot_id, slot) in &self.slots {
-                        if to_remove.len() >= count_to_remove {
-                            break;
-                        }
-                        let guarded = slot.inner.guarded.lock().unwrap();
-                        if filter(&*guarded) {
-                            to_remove.push(*slot_id);
-
-                            // It's important that we terminate the slot
-                            // immediately, so the task which manages the slot
-                            // will not continue modifying the state.
-                            slot.inner.state_transition(
-                                *slot_id,
-                                &self.backend,
-                                guarded,
-                                State::Terminated,
-                            );
-                        }
-                    }
-                }
-
-                for slot_id in to_remove {
-                    event!(Level::TRACE, slot_id = slot_id.as_u64(), "Removing slot");
-                    let Some(slot) = self.slots.remove(&slot_id) else {
-                        continue;
-                    };
-                    let Some(handle) = slot.inner.guarded.lock().unwrap().handle.take() else {
-                        continue;
-                    };
-                    event!(Level::TRACE, slot_id = slot_id.as_u64(), "Aborting task");
-                    handle.abort();
-                }
-            }
-            Greater => {
-                // More slots wanted. This case is easy, we can always fill
-                // in "connecting" slots immediately.
-                event!(
-                    Level::TRACE,
-                    current = self.slots.len(),
-                    "Increasing slot count"
-                );
-                let new_slots = (desired - self.slots.len()) as u64;
-                for slot_id in self.next_slot_id.0..self.next_slot_id.0 + new_slots {
-                    self.create_slot(SlotId(slot_id));
-                }
-                self.next_slot_id.0 += new_slots;
-            }
-            Equal => {}
-        }
-    }
-
-    #[instrument(
-        level = "trace",
-        skip(self),
-        err,
-        name = "SetWorker::claim",
-        fields(name = ?self.name),
-    )]
-    fn claim(&mut self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
-        #[cfg(feature = "probes")]
-        probes::slot__set__claim__start!(|| (
-            self.pool_name.as_str(),
-            id.0,
-            self.backend.address.to_string()
-        ));
-
-        // Before we vend out the slot's connection to a client, make sure that
-        // we have space to take it back once they're done with it.
-        let Ok(permit) = self.slot_tx.clone().try_reserve_owned() else {
-            event!(Level::TRACE, "Could not reserve slot_tx permit");
-
-            #[cfg(feature = "probes")]
-            probes::slot__set__claim__failed!(|| (
-                self.pool_name.as_str(),
-                id.0,
-                "Could not reserve slot_tx permit; all slots used"
-            ));
-
-            // This is more of an "all slots in-use" error,
-            // but it should look the same to clients.
-            return Err(Error::NoSlotsReady);
-        };
-
-        let Some(handle) = self.take_connected_unclaimed_slot(permit, id) else {
-            event!(Level::TRACE, "Failed to take unclaimed slot");
-
-            #[cfg(feature = "probes")]
-            probes::slot__set__claim__failed!(|| (
-                self.pool_name.as_str(),
-                id.0,
-                "No unclaimed slots"
-            ));
-            return Err(Error::NoSlotsReady);
-        };
-
-        #[cfg(feature = "probes")]
-        probes::slot__set__claim__done!(|| (
-            self.pool_name.as_str(),
-            id.0,
-            handle.slot_id().as_u64()
-        ));
-
-        return Ok(handle);
-    }
-
-    // Note that this function is not asynchronous.
-    //
-    // This is intentional: We should not be await-ing in the SetWorker
-    // task when servicing client requests.
-    fn handle_client_request(&mut self, request: SetRequest<Conn>) {
-        match request {
-            SetRequest::Claim { id, tx } => {
-                let result = self.claim(id);
-                let _ = tx.send(result);
-            }
-            SetRequest::SetWantedCount { count } => {
-                self.set_wanted_count(count);
-            }
-        }
+    fn set_wanted_count(&self, count: usize) {
+        let mut slots = self.slots.lock().unwrap();
+        slots.set_wanted_count(count);
     }
 
     #[instrument(
@@ -1168,22 +1159,16 @@ impl<Conn: Connection> SetWorker<Conn> {
                     },
                 }
             },
-            // Handle requests from clients
-            request = self.rx.recv() => {
-                if let Some(request) = request {
-                    self.handle_client_request(request);
-                } else {
-                    // All clients have gone away, so terminate the set.
-                    // Break out of the loop rather than return, so that the
-                    // termination code runs.
-                    break;
-                }
-            }
            }
        }
 
        // If we have exited from the run loop, tear down the background tasks
-        while let Some((_id, slot)) = self.slots.pop_first() {
+        let mut slots = {
+            let mut slots_lock = self.slots.lock().unwrap();
+            slots_lock.terminating = true;
+            std::mem::take(&mut slots_lock.slots)
+        };
+        while let Some((_id, slot)) = slots.pop_first() {
            let handle = {
                let mut lock = slot.inner.guarded.lock().unwrap();
@@ -1221,7 +1206,7 @@ pub(crate) enum SetState {
 
 /// A set of slots for a particular backend.
 pub(crate) struct Set<Conn: Connection> {
-    tx: mpsc::Sender<SetRequest<Conn>>,
+    slots: Arc<Mutex<Slots<Conn>>>,
 
     status_rx: watch::Receiver<SetState>,
 
@@ -1249,37 +1234,35 @@ impl<Conn: Connection> Set<Conn> {
         backend: Backend,
         backend_connector: SharedConnector<Conn>,
     ) -> Self {
-        let (tx, rx) = mpsc::channel(1);
         let (terminate_tx, terminate_rx) = tokio::sync::oneshot::channel();
         let (status_tx, status_rx) = watch::channel(SetState::Offline);
         let failure_duration = config.max_connection_backoff * 2;
         let stats = Arc::new(Mutex::new(Stats::default()));
         let failure_window = Arc::new(WindowedCounter::new(failure_duration));
+
+        let mut worker = SetWorker::new(
+            pool_name,
+            set_id,
+            name.clone(),
+            terminate_rx,
+            status_tx,
+            config,
+            wanted_count,
+            backend,
+            backend_connector,
+            stats.clone(),
+            failure_window.clone(),
+        );
+        let slots = worker.slots.clone();
+
         let handle = tokio::task::spawn({
-            let stats = stats.clone();
-            let failure_window = failure_window.clone();
-            let name = name.clone();
             async move {
-                let mut worker = SetWorker::new(
-                    pool_name,
-                    set_id,
-                    name,
-                    rx,
-                    terminate_rx,
-                    status_tx,
-                    config,
-                    wanted_count,
-                    backend,
-                    backend_connector,
-                    stats,
-                    failure_window,
-                );
                 worker.run().await;
             }
         });
 
         Self {
-            tx,
+            slots,
             status_rx,
             name,
             stats,
@@ -1321,15 +1304,12 @@ impl<Conn: Connection> Set<Conn> {
         name = "Set::claim",
         fields(name = ?self.name),
     )]
-    pub(crate) async fn claim(&mut self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
-        let (tx, rx) = oneshot::channel();
-
-        self.tx
-            .send(SetRequest::Claim { id, tx })
-            .await
-            .map_err(|_| Error::SlotWorkerTerminated)?;
-
-        rx.await.map_err(|_| Error::SlotWorkerTerminated)?
+    pub(crate) fn claim(&self, id: ClaimId) -> Result<claim::Handle<Conn>, Error> {
+        let slots = self.slots.lock().unwrap();
+        if slots.terminating {
+            return Err(Error::SlotWorkerTerminated);
+        }
+        slots.claim(id)
     }
 
     /// Updates the number of "wanted" slots within the slot set.
@@ -1342,11 +1322,12 @@ impl<Conn: Connection> Set<Conn> {
         name = "Set::set_wanted_count",
         fields(name = ?self.name),
     )]
-    pub(crate) async fn set_wanted_count(&mut self, count: usize) -> Result<(), Error> {
-        self.tx
-            .send(SetRequest::SetWantedCount { count })
-            .await
-            .map_err(|_| Error::SlotWorkerTerminated)?;
+    pub(crate) fn set_wanted_count(&self, count: usize) -> Result<(), Error> {
+        let mut slots = self.slots.lock().unwrap();
+        if slots.terminating {
+            return Err(Error::SlotWorkerTerminated);
+        }
+        slots.set_wanted_count(count);
         Ok(())
     }
 
@@ -1531,7 +1512,7 @@ mod test {
     #[tokio::test]
     async fn test_one_claim() {
         setup_tracing_subscriber();
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig::default(),
@@ -1547,13 +1528,13 @@ mod test {
             .await
             .unwrap();
 
-        let _conn = set.claim(ClaimId::new()).await.unwrap();
+        let _conn = set.claim(ClaimId::new()).unwrap();
     }
 
     #[tokio::test]
     async fn test_drain_slots() {
         setup_tracing_subscriber();
-        let mut set = Set::new(
+        let set = Set::new(
@@ -1570,8 +1551,8 @@ mod test {
             .unwrap();
 
         // Grab a connection, then set the "Wanted" count to zero.
-        let conn = set.claim(ClaimId::new()).await.unwrap();
-        set.set_wanted_count(0).await.unwrap();
+        let conn = set.claim(ClaimId::new()).unwrap();
+        set.set_wanted_count(0).unwrap();
 
         // Let the connections drain
         loop {
@@ -1598,7 +1579,7 @@ mod test {
     #[tokio::test]
     async fn test_no_slots_add_some_later() {
         setup_tracing_subscriber();
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig::default(),
@@ -1610,13 +1591,12 @@ mod test {
 
         // We start with nothing available
         set.claim(ClaimId::new())
-            .await
             .map(|_| ())
             .expect_err("Should not be able to get claims yet");
         assert_eq!(set.get_state(), SetState::Offline);
 
         // We can later adjust the count of desired slots
-        set.set_wanted_count(3).await.unwrap();
+        set.set_wanted_count(3).unwrap();
 
         // Let the connections fill up
         set.monitor()
@@ -1625,13 +1605,13 @@ mod test {
             .unwrap();
 
         // When this completes, the connections may be claimed
-        let _conn = set.claim(ClaimId::new()).await.unwrap();
+        let _conn = set.claim(ClaimId::new()).unwrap();
     }
 
     #[tokio::test]
     async fn test_all_claims() {
         setup_tracing_subscriber();
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig::default(),
@@ -1652,12 +1632,11 @@ mod test {
             }
         }
 
-        let _conn1 = set.claim(ClaimId::new()).await.unwrap();
-        let _conn2 = set.claim(ClaimId::new()).await.unwrap();
-        let conn3 = set.claim(ClaimId::new()).await.unwrap();
+        let _conn1 = set.claim(ClaimId::new()).unwrap();
+        let _conn2 = set.claim(ClaimId::new()).unwrap();
+        let conn3 = set.claim(ClaimId::new()).unwrap();
 
         set.claim(ClaimId::new())
-            .await
             .map(|_| ())
             .expect_err("We should fail to acquire a 4th claim from 3 slot set");
 
@@ -1675,7 +1654,7 @@ mod test {
             }
         }
 
-        let _conn4 = set.claim(ClaimId::new()).await.unwrap();
+        let _conn4 = set.claim(ClaimId::new()).unwrap();
     }
 
     #[tokio::test]
@@ -1684,7 +1663,7 @@ mod test {
 
         let connector = Arc::new(TestConnector::new());
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig {
@@ -1728,14 +1707,13 @@ mod test {
         ));
 
         // Grab three connections
-        let _claim1 = set.claim(ClaimId::new()).await.expect("Failed to claim");
-        let _claim2 = set.claim(ClaimId::new()).await.expect("Failed to claim");
-        let claim3 = set.claim(ClaimId::new()).await.expect("Failed to claim");
+        let _claim1 = set.claim(ClaimId::new()).expect("Failed to claim");
+        let _claim2 = set.claim(ClaimId::new()).expect("Failed to claim");
+        let claim3 = set.claim(ClaimId::new()).expect("Failed to claim");
 
         // Cannot claim the fourth connection, this slot set is all used.
         assert!(matches!(
             set.claim(ClaimId::new())
-                .await
                 .map(|_| ())
                 .expect_err("Should have reached claim capacity"),
             Error::NoSlotsReady,
@@ -1842,7 +1820,7 @@ mod test {
         let wanted_count = 5;
         let connector = Arc::new(TestConnector::new());
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig {
@@ -1869,7 +1847,7 @@ mod test {
         }
 
         // Grab one of the slots. Inspect the state, validating it is connected.
-        let conn = set.claim(ClaimId::new()).await.unwrap();
+        let conn = set.claim(ClaimId::new()).unwrap();
         let raw_conn = conn.clone();
         assert_eq!(raw_conn.get_state(), TestConnectionState::Connected);
         drop(conn);
@@ -1891,7 +1869,7 @@ mod test {
         assert_eq!(raw_conn.get_state(), TestConnectionState::Recycled);
 
         connector.set_recyclable(false);
-        let conn = set.claim(ClaimId::new()).await.unwrap();
+        let conn = set.claim(ClaimId::new()).unwrap();
         let raw_conn = conn.clone();
         assert_eq!(raw_conn.get_state(), TestConnectionState::Recycled);
         drop(conn);
@@ -1923,7 +1901,7 @@ mod test {
         let wanted_count = 5;
         let connector = Arc::new(TestConnector::new());
         let health_interval = Duration::from_millis(1);
-        let mut set = Set::new(
+        let set = Set::new(
             0,
             pool::Name::new("my-pool"),
             SetConfig {
@@ -1946,7 +1924,7 @@ mod test {
         //
         // This means no new connections, and existing connections will die
         // when health checked.
-        let conn = set.claim(ClaimId::new()).await.unwrap();
+        let conn = set.claim(ClaimId::new()).unwrap();
         connector.set_connectable(false);
         connector.set_valid(false);
         let raw_conn = conn.clone();
@@ -1991,19 +1969,21 @@ mod test {
             .await
             .unwrap();
 
-        let conn = set.claim(ClaimId::new()).await.unwrap();
+        let conn = set.claim(ClaimId::new()).unwrap();
 
         // We should be able to terminate, even with a claim out.
         set.terminate().await;
 
-        assert!(matches!(
-            set.claim(ClaimId::new()).await.map(|_| ()).unwrap_err(),
-            Error::SlotWorkerTerminated,
-        ));
-        assert!(matches!(
-            set.set_wanted_count(1).await.unwrap_err(),
-            Error::SlotWorkerTerminated
-        ));
+        let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err();
+        assert!(
+            matches!(err, Error::SlotWorkerTerminated),
+            "Unexpected error: {err}"
+        );
+        let err = set.set_wanted_count(1).unwrap_err();
+        assert!(
+            matches!(err, Error::SlotWorkerTerminated),
+            "Unexpected error: {err}"
+        );
 
         drop(conn);
     }
@@ -2106,14 +2086,14 @@ mod test {
             })
             .await
             .unwrap();
-            handles.push(set.claim(ClaimId::new()).await.unwrap());
+            handles.push(set.claim(ClaimId::new()).unwrap());
         }
 
         // All future connections should be slow!
         connector.stall();
 
         // This should start making new connections...
-        set.set_wanted_count(config.new_wanted).await.unwrap();
+        set.set_wanted_count(config.new_wanted).unwrap();
 
         set.terminate().await;
 
@@ -2121,10 +2101,11 @@ mod test {
 
         drop(handles);
 
-        assert!(matches!(
-            set.claim(ClaimId::new()).await.map(|_| ()).unwrap_err(),
-            Error::SlotWorkerTerminated,
-        ));
+        let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err();
+        assert!(
+            matches!(err, Error::SlotWorkerTerminated),
+            "Unexpected error: {err}"
+        );
     }
 }
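The core of this change is an actor-to-shared-state refactor: instead of sending `SetRequest` messages to the `SetWorker` task and awaiting a reply, `Set` and `SetWorker` now share an `Arc<Mutex<Slots>>`, so claims become synchronous lock-and-mutate calls, and shutdown is signalled by a `terminating` flag checked under the same lock. Below is a minimal, self-contained sketch of that pattern; the `Set`/`Slots` names mirror the patch, but the fields and the `u64` stand-in connections are illustrative, not the crate's actual types.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

#[derive(Debug)]
enum Error {
    SlotWorkerTerminated,
    NoSlotsReady,
}

// State that was previously owned by a worker task and driven by channel
// messages; now shared directly between the public handle and the worker.
struct Slots {
    unclaimed: VecDeque<u64>, // stand-in for connected-but-unclaimed slots
    terminating: bool,        // once set, all new requests are rejected
}

#[derive(Clone)]
struct Set {
    slots: Arc<Mutex<Slots>>,
}

impl Set {
    // Synchronous claim: lock, check the terminating flag, then mutate.
    fn claim(&self) -> Result<u64, Error> {
        let mut slots = self.slots.lock().unwrap();
        if slots.terminating {
            return Err(Error::SlotWorkerTerminated);
        }
        slots.unclaimed.pop_front().ok_or(Error::NoSlotsReady)
    }

    // On shutdown, the worker flips the flag and drains state while holding
    // the lock, so no new claim can race with teardown.
    fn terminate(&self) -> VecDeque<u64> {
        let mut slots = self.slots.lock().unwrap();
        slots.terminating = true;
        std::mem::take(&mut slots.unclaimed)
    }
}

fn main() {
    let set = Set {
        slots: Arc::new(Mutex::new(Slots {
            unclaimed: VecDeque::from([1, 2]),
            terminating: false,
        })),
    };
    assert_eq!(set.claim().unwrap(), 1);
    set.terminate();
    assert!(matches!(set.claim(), Err(Error::SlotWorkerTerminated)));
}
```

The trade-off is the usual one for this refactor: callers no longer pay for a channel round trip (which is why `PoolInner::claim` and `rebalance` can drop their `.await`s entirely), but every critical section under the mutex must stay short and must never block on I/O — which is why connection establishment and health monitoring remain in the per-slot background tasks.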
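One detail the patch preserves from the old worker is worth calling out: `Slots::claim` reserves space on the return channel (`self.slot_tx.clone().try_reserve_owned()`) *before* vending a connection, so the claim handle's `Drop` can always hand the connection back without failing or blocking. Here is a hedged sketch of that reserve-before-vend idiom using tokio's `OwnedPermit`; the `Handle` type and `String` connection are hypothetical stand-ins, not qorb's `claim::Handle`.

```rust
use tokio::sync::mpsc;

// A vended connection that is guaranteed to be returnable: channel capacity
// was reserved *before* the handle was handed out, so Drop cannot fail.
struct Handle {
    conn: Option<String>,
    permit: Option<mpsc::OwnedPermit<String>>,
}

impl Drop for Handle {
    fn drop(&mut self) {
        if let (Some(conn), Some(permit)) = (self.conn.take(), self.permit.take()) {
            // Infallible: the capacity was reserved at claim time.
            permit.send(conn);
        }
    }
}

#[tokio::main]
async fn main() {
    // Capacity doubles as the claim limit: once every permit is reserved,
    // further claims fail immediately, like the patch's "all slots in-use"
    // mapping to Error::NoSlotsReady.
    let (tx, mut rx) = mpsc::channel::<String>(1);

    let permit = tx
        .clone()
        .try_reserve_owned()
        .expect("capacity should be available");
    let handle = Handle {
        conn: Some("conn-0".into()),
        permit: Some(permit),
    };

    // A second claim would fail right now: no spare capacity remains.
    assert!(tx.clone().try_reserve_owned().is_err());

    drop(handle); // the connection flows back through the reserved slot
    assert_eq!(rx.recv().await.as_deref(), Some("conn-0"));
}
```

This design keeps the return path allocation-free and non-blocking at drop time, which matters because `Drop` implementations cannot `.await`; the worker then recycles returned connections from `slot_rx` at its leisure.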