diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index 18bc2ff63b7..99b191e52b6 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -803,6 +803,39 @@ mod tests { Ok(()) } + #[crate::test] + fn test_dict_mismatched_ptypes_rejected() -> VortexResult<()> { + let dict_values: Vec = vec![100, 200, 300, 400]; + let len = 3000; + let codes: Vec = (0..len).map(|i| (i % dict_values.len()) as u8).collect(); + + let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable); + let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + // build_plan should fail because u8 codes != u32 values in byte width. + assert!(build_plan(&dict.into_array(), &cuda_ctx).is_err()); + + Ok(()) + } + + #[crate::test] + fn test_runend_mismatched_ptypes_rejected() -> VortexResult<()> { + let ends: Vec = vec![1000, 2000, 3000]; + let values: Vec = vec![10, 20, 30]; + + let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); + let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); + let re = RunEndArray::new(ends_arr, values_arr); + + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + // build_plan should fail because u64 ends != i32 values in byte width. + assert!(build_plan(&re.into_array(), &cuda_ctx).is_err()); + + Ok(()) + } + #[rstest] #[case(0, 1024)] #[case(0, 3000)] diff --git a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs index 81fb6497dbb..b25717476d6 100644 --- a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs +++ b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs @@ -178,10 +178,36 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { } return false; } + if id == Dict::ID { + if let Ok(a) = array.clone().try_into::() { + // As of now the dict dyn dispatch kernel requires + // codes and values to have the same byte width. + return match ( + PType::try_from(a.values().dtype()), + PType::try_from(a.codes().dtype()), + ) { + (Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(), + _ => false, + }; + } + return false; + } + if id == RunEnd::ID { + if let Ok(a) = array.clone().try_into::() { + // As of now the run-end dyn dispatch kernel requires + // ends and values to have the same byte width. + return match ( + PType::try_from(a.ends().dtype()), + PType::try_from(a.values().dtype()), + ) { + (Ok(e), Ok(v)) => e.byte_width() == v.byte_width(), + _ => false, + }; + } + return false; + } id == FoR::ID || id == ZigZag::ID - || id == Dict::ID - || id == RunEnd::ID || id == Primitive::ID || id == Slice::ID || id == Sequence::ID @@ -264,6 +290,13 @@ impl PlanBuilderState<'_> { return Ok(pipeline); } + if !is_dyn_dispatch_compatible(&array) { + vortex_bail!( + "Encoding {:?} is not compatible with the dynamic dispatch plan builder", + array.encoding_id() + ); + } + let id = array.encoding_id(); if id == BitPacked::ID {