Skip to content

Commit 9909535

Browse files
authored
feat: add debug assertions before unsafe code blocks (#3655)
Add debug_assert! statements between SAFETY comments and unsafe blocks to catch precondition violations during development and testing.

Assertions cover:
- Null pointer checks before raw pointer dereference
- Index bounds checks before array/bitset access
- Initialization checks before accessing global singletons
- Alignment checks before aligned pointer writes
- Non-negative size/length checks before slice construction
1 parent 6e2a125 commit 9909535

6 files changed

Lines changed: 108 additions & 0 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,13 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
851851
tracing_enabled != JNI_FALSE,
852852
|| {
853853
// SAFETY: JVM unsafe memory allocation is aligned with long.
854+
debug_assert!(address != 0, "sortRowPartitionsNative: null address");
855+
debug_assert!(size >= 0, "sortRowPartitionsNative: negative size {size}");
856+
debug_assert_eq!(
857+
(address as usize) % std::mem::align_of::<i64>(),
858+
0,
859+
"sortRowPartitionsNative: address not aligned to i64"
860+
);
854861
let array =
855862
unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) };
856863
array.rdxsort();

native/core/src/execution/shuffle/spark_unsafe/list.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ impl SparkUnsafeArray {
5353
pub fn new(addr: i64) -> Self {
5454
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
5555
// The first 8 bytes contain the element count as a little-endian i64.
56+
debug_assert!(addr != 0, "SparkUnsafeArray::new: null address");
5657
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
5758
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
5859

@@ -87,6 +88,11 @@ impl SparkUnsafeArray {
8788
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
8889
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
8990
// index < num_elements, so word_offset is within the bitset region.
91+
debug_assert!(
92+
index < self.num_elements,
93+
"is_null_at: index {index} >= num_elements {}",
94+
self.num_elements
95+
);
9096
unsafe {
9197
let mask: i64 = 1i64 << (index & 0x3f);
9298
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;

native/core/src/execution/shuffle/spark_unsafe/map.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ impl SparkUnsafeMap {
3232
pub(crate) fn new(addr: i64, size: i32) -> Self {
3333
// SAFETY: addr points to valid Spark UnsafeMap data from the JVM.
3434
// The first 8 bytes contain the key array size as a little-endian i64.
35+
debug_assert!(addr != 0, "SparkUnsafeMap::new: null address");
36+
debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}");
3537
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
3638
let key_array_size = i64::from_le_bytes(slice.try_into().unwrap());
3739

native/core/src/execution/shuffle/spark_unsafe/row.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,18 @@ pub trait SparkUnsafeObject {
9292
let addr = self.get_element_offset(index, 1);
9393
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
9494
// The caller ensures index is within bounds.
95+
debug_assert!(
96+
!addr.is_null(),
97+
"get_boolean: null pointer at index {index}"
98+
);
9599
unsafe { *addr != 0 }
96100
}
97101

98102
/// Returns byte value at the given index of the object.
99103
fn get_byte(&self, index: usize) -> i8 {
100104
let addr = self.get_element_offset(index, 1);
101105
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
106+
debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}");
102107
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
103108
i8::from_le_bytes(slice.try_into().unwrap())
104109
}
@@ -107,6 +112,7 @@ pub trait SparkUnsafeObject {
107112
fn get_short(&self, index: usize) -> i16 {
108113
let addr = self.get_element_offset(index, 2);
109114
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
115+
debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}");
110116
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
111117
i16::from_le_bytes(slice.try_into().unwrap())
112118
}
@@ -115,6 +121,7 @@ pub trait SparkUnsafeObject {
115121
fn get_int(&self, index: usize) -> i32 {
116122
let addr = self.get_element_offset(index, 4);
117123
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
124+
debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}");
118125
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
119126
i32::from_le_bytes(slice.try_into().unwrap())
120127
}
@@ -123,6 +130,7 @@ pub trait SparkUnsafeObject {
123130
fn get_long(&self, index: usize) -> i64 {
124131
let addr = self.get_element_offset(index, 8);
125132
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
133+
debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}");
126134
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
127135
i64::from_le_bytes(slice.try_into().unwrap())
128136
}
@@ -131,6 +139,7 @@ pub trait SparkUnsafeObject {
131139
fn get_float(&self, index: usize) -> f32 {
132140
let addr = self.get_element_offset(index, 4);
133141
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
142+
debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}");
134143
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
135144
f32::from_le_bytes(slice.try_into().unwrap())
136145
}
@@ -139,6 +148,7 @@ pub trait SparkUnsafeObject {
139148
fn get_double(&self, index: usize) -> f64 {
140149
let addr = self.get_element_offset(index, 8);
141150
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
151+
debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}");
142152
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
143153
f64::from_le_bytes(slice.try_into().unwrap())
144154
}
@@ -149,6 +159,11 @@ pub trait SparkUnsafeObject {
149159
let addr = self.get_row_addr() + offset as i64;
150160
// SAFETY: addr points to valid UTF-8 string data within the variable-length region.
151161
// Offset and length are read from the fixed-length portion of the row/array.
162+
debug_assert!(addr != 0, "get_string: null address at index {index}");
163+
debug_assert!(
164+
len >= 0,
165+
"get_string: negative length {len} at index {index}"
166+
);
152167
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };
153168

154169
from_utf8(slice).unwrap()
@@ -160,13 +175,19 @@ pub trait SparkUnsafeObject {
160175
let addr = self.get_row_addr() + offset as i64;
161176
// SAFETY: addr points to valid binary data within the variable-length region.
162177
// Offset and length are read from the fixed-length portion of the row/array.
178+
debug_assert!(addr != 0, "get_binary: null address at index {index}");
179+
debug_assert!(
180+
len >= 0,
181+
"get_binary: negative length {len} at index {index}"
182+
);
163183
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
164184
}
165185

166186
/// Returns date value at the given index of the object.
167187
fn get_date(&self, index: usize) -> i32 {
168188
let addr = self.get_element_offset(index, 4);
169189
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
190+
debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}");
170191
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
171192
i32::from_le_bytes(slice.try_into().unwrap())
172193
}
@@ -175,6 +196,10 @@ pub trait SparkUnsafeObject {
175196
fn get_timestamp(&self, index: usize) -> i64 {
176197
let addr = self.get_element_offset(index, 8);
177198
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
199+
debug_assert!(
200+
!addr.is_null(),
201+
"get_timestamp: null pointer at index {index}"
202+
);
178203
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
179204
i64::from_le_bytes(slice.try_into().unwrap())
180205
}
@@ -287,6 +312,7 @@ impl SparkUnsafeRow {
287312
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
288313
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
289314
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
315+
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
290316
unsafe {
291317
let mask: i64 = 1i64 << (index & 0x3f);
292318
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
@@ -300,6 +326,7 @@ impl SparkUnsafeRow {
300326
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
301327
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
302328
// Writing is safe because we have mutable access and the memory is owned by the JVM.
329+
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
303330
unsafe {
304331
let mask: i64 = 1i64 << (index & 0x3f);
305332
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
@@ -498,6 +525,18 @@ fn append_columns(
498525
for i in row_start..row_end {
499526
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
500527
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
528+
debug_assert!(
529+
!row_addresses_ptr.is_null(),
530+
"append_columns: null row_addresses_ptr"
531+
);
532+
debug_assert!(
533+
!row_sizes_ptr.is_null(),
534+
"append_columns: null row_sizes_ptr"
535+
);
536+
debug_assert!(
537+
i < row_end,
538+
"append_columns: index {i} out of bounds (row_end={row_end})"
539+
);
501540
let row_addr = unsafe { *row_addresses_ptr.add(i) };
502541
let row_size = unsafe { *row_sizes_ptr.add(i) };
503542
row.point_to(row_addr, row_size);
@@ -630,6 +669,18 @@ fn append_columns(
630669
for i in row_start..row_end {
631670
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
632671
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
672+
debug_assert!(
673+
!row_addresses_ptr.is_null(),
674+
"append_columns: null row_addresses_ptr"
675+
);
676+
debug_assert!(
677+
!row_sizes_ptr.is_null(),
678+
"append_columns: null row_sizes_ptr"
679+
);
680+
debug_assert!(
681+
i < row_end,
682+
"append_columns: index {i} out of bounds (row_end={row_end})"
683+
);
633684
let row_addr = unsafe { *row_addresses_ptr.add(i) };
634685
let row_size = unsafe { *row_sizes_ptr.add(i) };
635686
row.point_to(row_addr, row_size);
@@ -652,6 +703,18 @@ fn append_columns(
652703
for i in row_start..row_end {
653704
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
654705
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
706+
debug_assert!(
707+
!row_addresses_ptr.is_null(),
708+
"append_columns: null row_addresses_ptr"
709+
);
710+
debug_assert!(
711+
!row_sizes_ptr.is_null(),
712+
"append_columns: null row_sizes_ptr"
713+
);
714+
debug_assert!(
715+
i < row_end,
716+
"append_columns: index {i} out of bounds (row_end={row_end})"
717+
);
655718
let row_addr = unsafe { *row_addresses_ptr.add(i) };
656719
let row_size = unsafe { *row_sizes_ptr.add(i) };
657720
row.point_to(row_addr, row_size);
@@ -681,6 +744,18 @@ fn append_columns(
681744
for i in row_start..row_end {
682745
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
683746
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
747+
debug_assert!(
748+
!row_addresses_ptr.is_null(),
749+
"append_columns: null row_addresses_ptr"
750+
);
751+
debug_assert!(
752+
!row_sizes_ptr.is_null(),
753+
"append_columns: null row_sizes_ptr"
754+
);
755+
debug_assert!(
756+
i < row_end,
757+
"append_columns: index {i} out of bounds (row_end={row_end})"
758+
);
684759
let row_addr = unsafe { *row_addresses_ptr.add(i) };
685760
let row_size = unsafe { *row_sizes_ptr.add(i) };
686761
row.point_to(row_addr, row_size);

native/core/src/execution/utils.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,16 @@ impl SparkArrowConvert for ArrayData {
9797
}
9898
} else {
9999
// SAFETY: `array_ptr` and `schema_ptr` are aligned correctly.
100+
debug_assert_eq!(
101+
array_ptr.align_offset(array_align),
102+
0,
103+
"move_to_spark: array_ptr not aligned"
104+
);
105+
debug_assert_eq!(
106+
schema_ptr.align_offset(schema_align),
107+
0,
108+
"move_to_spark: schema_ptr not aligned"
109+
);
100110
unsafe {
101111
std::ptr::write(array_ptr, FFI_ArrowArray::new(self));
102112
std::ptr::write(schema_ptr, FFI_ArrowSchema::try_from(self.data_type())?);

native/core/src/jvm_bridge/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,19 @@ impl JVMClasses<'_> {
263263
}
264264

265265
pub fn get() -> &'static JVMClasses<'static> {
266+
debug_assert!(
267+
JVM_CLASSES.get().is_some(),
268+
"JVMClasses::get: not initialized"
269+
);
266270
unsafe { JVM_CLASSES.get_unchecked() }
267271
}
268272

269273
/// Gets the JNIEnv for the current thread.
270274
pub fn get_env() -> CometResult<AttachGuard<'static>> {
275+
debug_assert!(
276+
JAVA_VM.get().is_some(),
277+
"JVMClasses::get_env: JAVA_VM not initialized"
278+
);
271279
unsafe {
272280
let java_vm = JAVA_VM.get_unchecked();
273281
java_vm.attach_current_thread().map_err(|e| {

0 commit comments

Comments (0)