Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 197 additions & 2 deletions crates/openshell-server/src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use openshell_core::proto::{
WatchSandboxRequest, open_shell_server::OpenShell,
};
use openshell_core::proto::{
Sandbox, SandboxPhase, SandboxPolicy as ProtoSandboxPolicy, SandboxTemplate,
InferenceRoute, Sandbox, SandboxPhase, SandboxPolicy as ProtoSandboxPolicy, SandboxTemplate,
};
use prost::Message;
use sha2::{Digest, Sha256};
Expand Down Expand Up @@ -3290,6 +3290,73 @@ async fn delete_provider_record(
return Err(Status::invalid_argument("name is required"));
}

// Early-out: if the provider doesn't exist, nothing to delete.
let exists = store
.get_by_name(Provider::object_type(), name)
.await
.map_err(|e| Status::internal(format!("check provider failed: {e}")))?;
if exists.is_none() {
return Ok(false);
}

// Check if any sandbox references this provider.
let sandbox_records = store
.list(Sandbox::object_type(), u32::MAX, 0)
.await
.map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?;

let mut referencing_sandboxes: Vec<String> = Vec::new();
for record in sandbox_records {
let sandbox = Sandbox::decode(record.payload.as_slice())
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
// Skip sandboxes that are already being deleted.
if sandbox.phase == SandboxPhase::Deleting as i32 {
continue;
}
if let Some(spec) = &sandbox.spec {
if spec.providers.iter().any(|p| p == name) {
referencing_sandboxes.push(sandbox.name.clone());
}
}
}

// Check if any inference route references this provider.
let mut referencing_routes: Vec<String> = Vec::new();
for route_name in ["inference.local", "sandbox-system"] {
if let Some(route) = store
.get_message_by_name::<InferenceRoute>(route_name)
.await
.map_err(|e| Status::internal(format!("fetch inference route failed: {e}")))?
{
if route
.config
.as_ref()
.is_some_and(|c| c.provider_name == name)
{
referencing_routes.push(route_name.to_string());
}
}
}

if !referencing_sandboxes.is_empty() || !referencing_routes.is_empty() {
let mut details = Vec::new();
if !referencing_sandboxes.is_empty() {
details.push(format!("sandbox(es): {}", referencing_sandboxes.join(", ")));
}
if !referencing_routes.is_empty() {
details.push(format!(
"inference route(s): {}",
referencing_routes.join(", ")
));
}
return Err(Status::failed_precondition(format!(
"cannot delete provider '{}': still referenced by {}. \
Remove or update these references before deleting the provider.",
name,
details.join(" and ")
)));
}

store
.delete_by_name(Provider::object_type(), name)
.await
Expand Down Expand Up @@ -3326,7 +3393,10 @@ mod tests {
validate_provider_fields, validate_sandbox_spec,
};
use crate::persistence::Store;
use openshell_core::proto::{Provider, SandboxSpec, SandboxTemplate};
use openshell_core::proto::{
ClusterInferenceConfig, InferenceRoute, Provider, Sandbox, SandboxPhase, SandboxSpec,
SandboxTemplate,
};
use std::collections::HashMap;
use tonic::Code;

Expand Down Expand Up @@ -4624,4 +4694,129 @@ mod tests {
assert_eq!(err.code(), Code::InvalidArgument);
assert!(err.message().contains("value"));
}

#[tokio::test]
async fn delete_provider_blocked_by_sandbox_reference() {
let store = Store::connect("sqlite::memory:?cache=shared")
.await
.unwrap();

let provider =
create_provider_record(&store, provider_with_values("my-provider", "nvidia"))
.await
.unwrap();

// Create a sandbox that references the provider.
let sandbox = Sandbox {
id: uuid::Uuid::new_v4().to_string(),
name: "test-sandbox".to_string(),
spec: Some(SandboxSpec {
providers: vec!["my-provider".to_string()],
..SandboxSpec::default()
}),
..Sandbox::default()
};
store.put_message(&sandbox).await.unwrap();

// Deleting the referenced provider should fail.
let err = delete_provider_record(&store, &provider.name)
.await
.unwrap_err();
assert_eq!(err.code(), Code::FailedPrecondition);
assert!(err.message().contains("test-sandbox"));
assert!(err.message().contains("my-provider"));
}

#[tokio::test]
async fn delete_provider_blocked_by_inference_route() {
let store = Store::connect("sqlite::memory:?cache=shared")
.await
.unwrap();

let provider =
create_provider_record(&store, provider_with_values("inference-provider", "nvidia"))
.await
.unwrap();

// Create an inference route that references the provider.
let route = InferenceRoute {
id: uuid::Uuid::new_v4().to_string(),
name: "inference.local".to_string(),
config: Some(ClusterInferenceConfig {
provider_name: "inference-provider".to_string(),
..ClusterInferenceConfig::default()
}),
version: 1,
};
store.put_message(&route).await.unwrap();

// Deleting the referenced provider should fail.
let err = delete_provider_record(&store, &provider.name)
.await
.unwrap_err();
assert_eq!(err.code(), Code::FailedPrecondition);
assert!(err.message().contains("inference.local"));
assert!(err.message().contains("inference-provider"));
}

#[tokio::test]
async fn delete_provider_succeeds_when_unreferenced() {
let store = Store::connect("sqlite::memory:?cache=shared")
.await
.unwrap();

let provider =
create_provider_record(&store, provider_with_values("lonely-provider", "nvidia"))
.await
.unwrap();

// Create a sandbox that references a DIFFERENT provider name.
let sandbox = Sandbox {
id: uuid::Uuid::new_v4().to_string(),
name: "other-sandbox".to_string(),
spec: Some(SandboxSpec {
providers: vec!["other-provider".to_string()],
..SandboxSpec::default()
}),
..Sandbox::default()
};
store.put_message(&sandbox).await.unwrap();

// Deleting the unreferenced provider should succeed.
let deleted = delete_provider_record(&store, &provider.name)
.await
.unwrap();
assert!(deleted);
}

#[tokio::test]
async fn delete_provider_ignores_deleting_sandboxes() {
let store = Store::connect("sqlite::memory:?cache=shared")
.await
.unwrap();

let provider =
create_provider_record(&store, provider_with_values("my-provider", "nvidia"))
.await
.unwrap();

// Create a sandbox that references the provider but is already being deleted.
let sandbox = Sandbox {
id: uuid::Uuid::new_v4().to_string(),
name: "dying-sandbox".to_string(),
phase: SandboxPhase::Deleting as i32,
spec: Some(SandboxSpec {
providers: vec!["my-provider".to_string()],
..SandboxSpec::default()
}),
..Sandbox::default()
};
store.put_message(&sandbox).await.unwrap();

// Deleting sandboxes should not block provider deletion.
let deleted = delete_provider_record(&store, &provider.name)
.await
.unwrap();
assert!(deleted);
}
}
Loading