Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions vortex-array/src/arrays/fixed_size_list/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::BitBufferMut;
use vortex_buffer::BufferMut;
use vortex_dtype::IntegerPType;
use vortex_dtype::Nullability;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
Expand All @@ -16,8 +16,6 @@ use crate::ToCanonical;
use crate::arrays::FixedSizeListArray;
use crate::arrays::FixedSizeListVTable;
use crate::arrays::PrimitiveArray;
use crate::builders::ArrayBuilder;
use crate::builders::PrimitiveBuilder;
use crate::compute::TakeKernel;
use crate::compute::TakeKernelAdapter;
use crate::compute::{self};
Expand Down Expand Up @@ -92,8 +90,7 @@ fn take_non_nullable_fsl<I: IntegerPType>(
let new_len = indices.len();

// Build the element indices directly without validity tracking.
let mut elements_indices =
PrimitiveBuilder::<I>::with_capacity(Nullability::NonNullable, new_len * list_size);
let mut elements_indices = BufferMut::<I>::with_capacity(new_len * list_size);

// Build the element indices for each list.
for data_idx in indices {
Expand All @@ -106,14 +103,17 @@ fn take_non_nullable_fsl<I: IntegerPType>(

// Expand the list into individual element indices.
for i in list_start..list_end {
elements_indices.append_value(I::from_usize(i).vortex_expect("i < list_end"))
unsafe {
elements_indices.push_unchecked(I::from_usize(i).vortex_expect("i < list_end"))
};
}
}

let elements_indices = elements_indices.finish();
let elements_indices = elements_indices.freeze();
debug_assert_eq!(elements_indices.len(), new_len * list_size);

let new_elements = compute::take(array.elements(), elements_indices.as_ref())?;
let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable);
let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?;
debug_assert_eq!(new_elements.len(), new_len * list_size);

// Both inputs are non-nullable, so the result is non-nullable.
Expand Down Expand Up @@ -142,8 +142,7 @@ fn take_nullable_fsl<I: IntegerPType>(

// We must use placeholder zeros for null lists to maintain the array length without
// propagating nullability to the element array's take operation.
let mut elements_indices =
PrimitiveBuilder::<I>::with_capacity(Nullability::NonNullable, new_len * list_size);
let mut elements_indices = BufferMut::<I>::with_capacity(new_len * list_size);
let mut new_validity_builder = BitBufferMut::with_capacity(new_len);

// Build the element indices while tracking which lists are null.
Expand All @@ -158,7 +157,7 @@ fn take_nullable_fsl<I: IntegerPType>(
if !is_index_valid || !array_validity.value(data_idx) {
// Append placeholder zeros for null lists. These will be masked by the validity array.
// We cannot use append_nulls here as explained above.
elements_indices.append_zeros(list_size);
unsafe { elements_indices.push_n_unchecked(I::zero(), list_size) };
new_validity_builder.append(false);
} else {
// Append the actual element indices for this list.
Expand All @@ -167,17 +166,20 @@ fn take_nullable_fsl<I: IntegerPType>(

// Expand the list into individual element indices.
for i in list_start..list_end {
elements_indices.append_value(I::from_usize(i).vortex_expect("i < list_end"))
unsafe {
elements_indices.push_unchecked(I::from_usize(i).vortex_expect("i < list_end"))
};
}

new_validity_builder.append(true);
}
}

let elements_indices = elements_indices.finish();
let elements_indices = elements_indices.freeze();
debug_assert_eq!(elements_indices.len(), new_len * list_size);

let new_elements = compute::take(array.elements(), elements_indices.as_ref())?;
let elements_indices_array = PrimitiveArray::new(elements_indices, Validity::NonNullable);
let new_elements = compute::take(array.elements(), elements_indices_array.as_ref())?;
debug_assert_eq!(new_elements.len(), new_len * list_size);

// At least one input was nullable, so the result is nullable.
Expand Down
Loading