diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index a2c6a58f..268a692b 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -32,7 +32,7 @@ use openshell_core::proto::{ WatchSandboxRequest, open_shell_server::OpenShell, }; use openshell_core::proto::{ - Sandbox, SandboxPhase, SandboxPolicy as ProtoSandboxPolicy, SandboxTemplate, + InferenceRoute, Sandbox, SandboxPhase, SandboxPolicy as ProtoSandboxPolicy, SandboxTemplate, }; use prost::Message; use sha2::{Digest, Sha256}; @@ -3290,6 +3290,73 @@ async fn delete_provider_record( return Err(Status::invalid_argument("name is required")); } + // Early-out: if the provider doesn't exist, nothing to delete. + let exists = store + .get_by_name(Provider::object_type(), name) + .await + .map_err(|e| Status::internal(format!("check provider failed: {e}")))?; + if exists.is_none() { + return Ok(false); + } + + // Check if any sandbox references this provider. + let sandbox_records = store + .list(Sandbox::object_type(), u32::MAX, 0) + .await + .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; + + let mut referencing_sandboxes: Vec = Vec::new(); + for record in sandbox_records { + let sandbox = Sandbox::decode(record.payload.as_slice()) + .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; + // Skip sandboxes that are already being deleted. + if sandbox.phase == SandboxPhase::Deleting as i32 { + continue; + } + if let Some(spec) = &sandbox.spec { + if spec.providers.iter().any(|p| p == name) { + referencing_sandboxes.push(sandbox.name.clone()); + } + } + } + + // Check if any inference route references this provider. + let mut referencing_routes: Vec = Vec::new(); + for route_name in ["inference.local", "sandbox-system"] { + if let Some(route) = store + .get_message_by_name::(route_name) + .await + .map_err(|e| Status::internal(format!("fetch inference route failed: {e}")))? + { + if route + .config + .as_ref() + .is_some_and(|c| c.provider_name == name) + { + referencing_routes.push(route_name.to_string()); + } + } + } + + if !referencing_sandboxes.is_empty() || !referencing_routes.is_empty() { + let mut details = Vec::new(); + if !referencing_sandboxes.is_empty() { + details.push(format!("sandbox(es): {}", referencing_sandboxes.join(", "))); + } + if !referencing_routes.is_empty() { + details.push(format!( + "inference route(s): {}", + referencing_routes.join(", ") + )); + } + return Err(Status::failed_precondition(format!( + "cannot delete provider '{}': still referenced by {}. \ + Remove or update these references before deleting the provider.", + name, + details.join(" and ") + ))); + } + store .delete_by_name(Provider::object_type(), name) .await @@ -3326,7 +3393,10 @@ mod tests { validate_provider_fields, validate_sandbox_spec, }; use crate::persistence::Store; - use openshell_core::proto::{Provider, SandboxSpec, SandboxTemplate}; + use openshell_core::proto::{ + ClusterInferenceConfig, InferenceRoute, Provider, Sandbox, SandboxPhase, SandboxSpec, + SandboxTemplate, + }; use std::collections::HashMap; use tonic::Code; @@ -4624,4 +4694,129 @@ mod tests { assert_eq!(err.code(), Code::InvalidArgument); assert!(err.message().contains("value")); } + + #[tokio::test] + async fn delete_provider_blocked_by_sandbox_reference() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let provider = + create_provider_record(&store, provider_with_values("my-provider", "nvidia")) + .await + .unwrap(); + + // Create a sandbox that references the provider. + let sandbox = Sandbox { + id: uuid::Uuid::new_v4().to_string(), + name: "test-sandbox".to_string(), + spec: Some(SandboxSpec { + providers: vec!["my-provider".to_string()], + ..SandboxSpec::default() + }), + ..Sandbox::default() + }; + store.put_message(&sandbox).await.unwrap(); + + // Deleting the referenced provider should fail. + let err = delete_provider_record(&store, &provider.name) + .await + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("test-sandbox")); + assert!(err.message().contains("my-provider")); + } + + #[tokio::test] + async fn delete_provider_blocked_by_inference_route() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let provider = + create_provider_record(&store, provider_with_values("inference-provider", "nvidia")) + .await + .unwrap(); + + // Create an inference route that references the provider. + let route = InferenceRoute { + id: uuid::Uuid::new_v4().to_string(), + name: "inference.local".to_string(), + config: Some(ClusterInferenceConfig { + provider_name: "inference-provider".to_string(), + ..ClusterInferenceConfig::default() + }), + version: 1, + }; + store.put_message(&route).await.unwrap(); + + // Deleting the referenced provider should fail. + let err = delete_provider_record(&store, &provider.name) + .await + .unwrap_err(); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("inference.local")); + assert!(err.message().contains("inference-provider")); + } + + #[tokio::test] + async fn delete_provider_succeeds_when_unreferenced() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let provider = + create_provider_record(&store, provider_with_values("lonely-provider", "nvidia")) + .await + .unwrap(); + + // Create a sandbox that references a DIFFERENT provider name. + let sandbox = Sandbox { + id: uuid::Uuid::new_v4().to_string(), + name: "other-sandbox".to_string(), + spec: Some(SandboxSpec { + providers: vec!["other-provider".to_string()], + ..SandboxSpec::default() + }), + ..Sandbox::default() + }; + store.put_message(&sandbox).await.unwrap(); + + // Deleting the unreferenced provider should succeed. + let deleted = delete_provider_record(&store, &provider.name) + .await + .unwrap(); + assert!(deleted); + } + + #[tokio::test] + async fn delete_provider_ignores_deleting_sandboxes() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let provider = + create_provider_record(&store, provider_with_values("my-provider", "nvidia")) + .await + .unwrap(); + + // Create a sandbox that references the provider but is already being deleted. + let sandbox = Sandbox { + id: uuid::Uuid::new_v4().to_string(), + name: "dying-sandbox".to_string(), + phase: SandboxPhase::Deleting as i32, + spec: Some(SandboxSpec { + providers: vec!["my-provider".to_string()], + ..SandboxSpec::default() + }), + ..Sandbox::default() + }; + store.put_message(&sandbox).await.unwrap(); + + // Deleting sandboxes should not block provider deletion. + let deleted = delete_provider_record(&store, &provider.name) + .await + .unwrap(); + assert!(deleted); + } }