diff --git a/datafusion/spark/src/function/bitwise/bitwise_not.rs b/datafusion/spark/src/function/bitwise/bitwise_not.rs index 5f8cf36911f43..e7285d4804950 100644 --- a/datafusion/spark/src/function/bitwise/bitwise_not.rs +++ b/datafusion/spark/src/function/bitwise/bitwise_not.rs @@ -73,25 +73,11 @@ impl ScalarUDFImpl for SparkBitwiseNot { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { - if args.arg_fields.len() != 1 { - return plan_err!("bitwise_not expects exactly 1 argument"); - } - - let input_field = &args.arg_fields[0]; - - let out_dt = input_field.data_type().clone(); - let mut out_nullable = input_field.is_nullable(); - - let scalar_null_present = args - .scalar_arguments - .iter() - .any(|opt_s| opt_s.is_some_and(|sv| sv.is_null())); - - if scalar_null_present { - out_nullable = true; - } - - Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable))) + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + args.arg_fields[0].is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { @@ -196,32 +182,4 @@ mod tests { assert!(out_i64_null.is_nullable()); assert_eq!(out_i64_null.data_type(), &DataType::Int64); } - - #[test] - fn test_bitwise_not_nullability_with_null_scalar() -> Result<()> { - use arrow::datatypes::{DataType, Field}; - use datafusion_common::ScalarValue; - use std::sync::Arc; - - let func = SparkBitwiseNot::new(); - - let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Int32, false)); - - let out = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[None], - })?; - assert!(!out.is_nullable()); - assert_eq!(out.data_type(), &DataType::Int32); - - let null_scalar = ScalarValue::Int32(None); - let out_with_null_scalar = func.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&non_nullable)], - scalar_arguments: &[Some(&null_scalar)], - })?; - 
assert!(out_with_null_scalar.is_nullable()); - assert_eq!(out_with_null_scalar.data_type(), &DataType::Int32); - - Ok(()) - } } diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs index b176f51ae6b32..99fba0e40c870 100644 --- a/datafusion/spark/src/function/datetime/date_add.rs +++ b/datafusion/spark/src/function/datetime/date_add.rs @@ -83,12 +83,7 @@ impl ScalarUDFImpl for SparkDateAdd { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -155,7 +150,6 @@ fn spark_date_add(args: &[ArrayRef]) -> Result<ArrayRef> { mod tests { use super::*; use arrow::datatypes::Field; - use datafusion_common::ScalarValue; #[test] fn test_date_add_non_nullable_inputs() { @@ -194,25 +188,4 @@ assert_eq!(ret_field.data_type(), &DataType::Date32); assert!(ret_field.is_nullable()); } - - #[test] - fn test_date_add_null_scalar() { - let func = SparkDateAdd::new(); - let args = &[ - Arc::new(Field::new("date", DataType::Date32, false)), - Arc::new(Field::new("num", DataType::Int32, false)), - ]; - - let null_scalar = ScalarValue::Int32(None); - - let ret_field = func - .return_field_from_args(ReturnFieldArgs { - arg_fields: args, - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert_eq!(ret_field.data_type(), &DataType::Date32); - assert!(ret_field.is_nullable()); - } } diff --git a/datafusion/spark/src/function/datetime/date_sub.rs b/datafusion/spark/src/function/datetime/date_sub.rs index 7e56670f17d22..49d2d97246d72 100644 --- a/datafusion/spark/src/function/datetime/date_sub.rs +++ b/datafusion/spark/src/function/datetime/date_sub.rs @@ -76,12 +76,7 @@ impl ScalarUDFImpl for SparkDateSub { } fn 
return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { - let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) - || args - .scalar_arguments - .iter() - .any(|arg| matches!(arg, Some(sv) if sv.is_null())); - + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); Ok(Arc::new(Field::new( self.name(), DataType::Date32, @@ -152,7 +147,6 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result<ArrayRef> { #[cfg(test)] mod tests { use super::*; - use datafusion_common::ScalarValue; #[test] fn test_date_sub_nullability_non_nullable_args() { @@ -187,22 +181,4 @@ assert!(result.is_nullable()); assert_eq!(result.data_type(), &DataType::Date32); } - - #[test] - fn test_date_sub_nullability_scalar_null_argument() { - let udf = SparkDateSub::new(); - let date_field = Arc::new(Field::new("d", DataType::Date32, false)); - let days_field = Arc::new(Field::new("n", DataType::Int32, false)); - let null_scalar = ScalarValue::Int32(None); - - let result = udf - .return_field_from_args(ReturnFieldArgs { - arg_fields: &[date_field, days_field], - scalar_arguments: &[None, Some(&null_scalar)], - }) - .unwrap(); - - assert!(result.is_nullable()); - assert_eq!(result.data_type(), &DataType::Date32); - } }