From cf13aad926dfaf3812c77c8acd1c794fd3cffe86 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Wed, 23 Apr 2025 22:24:41 +0900
Subject: [PATCH 01/13] [python-package] add_features_from with PyArrow Table
 incorrectly frees raw data despite free_raw_data=False (#6891)

---
 python-package/lightgbm/basic.py | 47 ++++++++++++++++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 2029ec7c1cff..4f1041bd952a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3293,6 +3293,8 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
                     self.data = self.data[self.used_indices, :]
                 elif isinstance(self.data, Sequence):
                     self.data = self.data[self.used_indices]
+                elif isinstance(self.data, pa_Table):
+                    self.data = self.data.take(self.used_indices)
                 elif _is_list_of_sequences(self.data) and len(self.data) > 0:
                     self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
                 else:
@@ -3523,6 +3525,51 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                     self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.to_numpy())))
                 else:
                     self.data = None
+            elif isinstance(self.data, pa_Table):
+                if not PYARROW_INSTALLED:
+                    raise LightGBMError(
+                        "Cannot add features to pyarrow.Table type of raw data "
+                        "without pyarrow installed. "
+                        "Install pyarrow and restart your session."
+                    )
+                if isinstance(other.data, np.ndarray):
+                    self.data = pa_Table.from_arrays(
+                        [
+                            *self.data.columns,
+                            *[pa_Array.from_numpy(other.data[:, i]) for i in range(other.data.shape[1])],
+                        ]
+                    )
+                elif isinstance(other.data, scipy.sparse.spmatrix):
+                    other_array = other.data.toarray()
+                    self.data = pa_Table.from_arrays(
+                        [
+                            *self.data.columns,
+                            *[pa_Array.from_numpy(other_array[:, i]) for i in range(other_array.shape[1])],
+                        ]
+                    )
+                elif isinstance(other.data, pd_DataFrame):
+                    self.data = pa_Table.from_arrays(
+                        [
+                            *self.data.columns,
+                            *[
+                                pa_Array.from_numpy(other.data.iloc[:, i].values)
+                                for i in range(len(other.data.columns))
+                            ],
+                        ]
+                    )
+                elif isinstance(other.data, dt_DataTable):
+                    _emit_datatable_deprecation_warning()
+                    other_array = other.data.to_numpy()
+                    self.data = pa_Table.from_arrays(
+                        [
+                            *self.data.columns,
+                            *[pa_Array.from_numpy(other_array[:, i]) for i in range(other_array.shape[1])],
+                        ]
+                    )
+                elif isinstance(other.data, pa_Table):
+                    self.data = pa_Table.from_arrays([*self.data.columns, *other.data.columns])
+                else:
+                    self.data = None
             else:
                 self.data = None
         if self.data is None:

From 99ee66f9bf393f8cfdd9dd9f730c4b8c2db59f19 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Wed, 23 Apr 2025 23:22:52 +0900
Subject: [PATCH 02/13] [python-package] add PyArrow Table case to
 test_add_features_from_different_sources (#6891)

---
 tests/python_package_test/test_basic.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index f34be5cc1574..b7deacede4d5 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -7,6 +7,7 @@
 from pathlib import Path
 
 import numpy as np
+import pyarrow as pa
 import pytest
 from scipy import sparse
 from sklearn.datasets import dump_svmlight_file, load_svmlight_file, make_blobs
@@ -345,7 +346,15 @@ def test_add_features_from_different_sources(rng):
     n_row = 100
     n_col = 5
     X = rng.uniform(size=(n_row, n_col))
-    xxs = [X, sparse.csr_matrix(X), pd.DataFrame(X)]
+    xxs = [
+        X,
+        sparse.csr_matrix(X),
+        pd.DataFrame(X),
+        pa.Table.from_arrays(
+            [pa.array(X[:, i]) for i in range(X.shape[1])], names=[f"col_{i}" for i in range(X.shape[1])]
+        ),
+    ]
+
     names = [f"col_{i}" for i in range(n_col)]
     seq = _create_sequence_from_ndarray(X, 1, 30)
     seq_ds = lgb.Dataset(seq, feature_name=names, free_raw_data=False).construct()

From a3256a35b6258011f15fe97105efe1066732e215 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Thu, 24 Apr 2025 09:21:31 +0900
Subject: [PATCH 03/13] [python-package] fix handling and tests for PyArrow
 Table input in add_features_from method (#6891)

---
 python-package/lightgbm/basic.py        | 109 +++++++++++++++++++++---
 tests/python_package_test/test_basic.py |  19 ++++-
 2 files changed, 112 insertions(+), 16 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 4f1041bd952a..d54e7bf545c7 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -46,6 +46,9 @@
     pd_Series,
 )
 
+if PYARROW_INSTALLED:
+    import pyarrow as pa
+
 if TYPE_CHECKING:
     from typing import Literal
 
@@ -3482,6 +3485,22 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
             elif isinstance(other.data, dt_DataTable):
                 _emit_datatable_deprecation_warning()
                 self.data = np.hstack((self.data, other.data.to_numpy()))
+            elif isinstance(other.data, pa_Table):
+                if not PYARROW_INSTALLED:
+                    raise LightGBMError(
+                        "Cannot add features to pyarrow.Table type of raw data "
+                        "without pyarrow installed. "
+                        "Install pyarrow and restart your session."
+                    )
+                else:
+                    self.data = np.hstack(
+                        (
+                            self.data,
+                            np.column_stack(
+                                [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
+                            ),
+                        )
+                    )
             else:
                 self.data = None
         elif isinstance(self.data, scipy.sparse.spmatrix):
@@ -3493,6 +3512,23 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
             elif isinstance(other.data, dt_DataTable):
                 _emit_datatable_deprecation_warning()
                 self.data = scipy.sparse.hstack((self.data, other.data.to_numpy()), format=sparse_format)
+            elif isinstance(other.data, pa_Table):
+                if not PYARROW_INSTALLED:
+                    raise LightGBMError(
+                        "Cannot add features to pyarrow.Table type of raw data "
+                        "without pyarrow installed. "
+                        "Install pyarrow and restart your session."
+                    )
+                else:
+                    self.data = scipy.sparse.hstack(
+                        (
+                            self.data,
+                            np.column_stack(
+                                [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
+                            ),
+                        ),
+                        format=sparse_format,
+                    )
             else:
                 self.data = None
         elif isinstance(self.data, pd_DataFrame):
@@ -3511,6 +3547,27 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
             elif isinstance(other.data, dt_DataTable):
                 _emit_datatable_deprecation_warning()
                 self.data = concat((self.data, pd_DataFrame(other.data.to_numpy())), axis=1, ignore_index=True)
+            elif isinstance(other.data, pa_Table):
+                if not PYARROW_INSTALLED:
+                    raise LightGBMError(
+                        "Cannot add features to pyarrow.Table type of raw data "
+                        "without pyarrow installed. "
+                        "Install pyarrow and restart your session."
+                    )
+                else:
+                    self.data = concat(
+                        (
+                            self.data,
+                            pd_DataFrame(
+                                {
+                                    other.data.column_names[i]: other.data.column(i).to_numpy()
+                                    for i in range(len(other.data.column_names))
+                                }
+                            ),
+                        ),
+                        axis=1,
+                        ignore_index=True,
+                    )
             else:
                 self.data = None
         elif isinstance(self.data, dt_DataTable):
@@ -3523,6 +3580,24 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                 self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.values)))
             elif isinstance(other.data, dt_DataTable):
                 self.data = dt_DataTable(np.hstack((self.data.to_numpy(), other.data.to_numpy())))
+            elif isinstance(other.data, pa_Table):
+                if not PYARROW_INSTALLED:
+                    raise LightGBMError(
+                        "Cannot add features to pyarrow.Table type of raw data "
+                        "without pyarrow installed. "
+                        "Install pyarrow and restart your session."
+                    )
+                else:
+                    self.data = dt_DataTable(
+                        np.hstack(
+                            (
+                                self.data.to_numpy(),
+                                np.column_stack(
+                                    [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
+                                ),
+                            )
+                        )
+                    )
             else:
                 self.data = None
         elif isinstance(self.data, pa_Table):
@@ -3536,26 +3611,32 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                 self.data = pa_Table.from_arrays(
                     [
                         *self.data.columns,
-                        *[pa_Array.from_numpy(other.data[:, i]) for i in range(other.data.shape[1])],
-                    ]
+                        *[pa.array(other.data[:, i]) for i in range(other.data.shape[1])],
+                    ],
+                    names=[
+                        *self.data.column_names,
+                        *[f"D{len(self.data.column_names) + i + 1}" for i in range(other.data.shape[1])],
+                    ],
                 )
             elif isinstance(other.data, scipy.sparse.spmatrix):
                 other_array = other.data.toarray()
                 self.data = pa_Table.from_arrays(
                     [
                         *self.data.columns,
-                        *[pa_Array.from_numpy(other_array[:, i]) for i in range(other_array.shape[1])],
-                    ]
+                        *[pa.array(other_array[:, i]) for i in range(other_array.shape[1])],
+                    ],
+                    names=[
+                        *self.data.column_names,
+                        *[f"D{len(self.data.column_names) + i + 1}" for i in range(other_array.shape[1])],
+                    ],
                 )
             elif isinstance(other.data, pd_DataFrame):
                 self.data = pa_Table.from_arrays(
                     [
                         *self.data.columns,
-                        *[
-                            pa_Array.from_numpy(other.data.iloc[:, i].values)
-                            for i in range(len(other.data.columns))
-                        ],
-                    ]
+                        *[pa.array(other.data.iloc[:, i].values) for i in range(len(other.data.columns))],
+                    ],
+                    names=[*self.data.column_names, *map(str, other.data.columns.tolist())],
                 )
             elif isinstance(other.data, dt_DataTable):
                 _emit_datatable_deprecation_warning()
@@ -3563,11 +3644,15 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                 other_array = other.data.to_numpy()
                 self.data = pa_Table.from_arrays(
                     [
                         *self.data.columns,
-                        *[pa_Array.from_numpy(other_array[:, i]) for i in range(other_array.shape[1])],
-                    ]
+                        *[pa.array(other_array[:, i]) for i in range(other_array.shape[1])],
+                    ],
+                    names=[*self.data.column_names, *other.data.names],
                 )
             elif isinstance(other.data, pa_Table):
-                self.data = pa_Table.from_arrays([*self.data.columns, *other.data.columns])
+                self.data = pa_Table.from_arrays(
+                    [*self.data.columns, *other.data.columns],
+                    names=[*self.data.column_names, *other.data.column_names],
+                )
             else:
                 self.data = None
         else:
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index b7deacede4d5..f6375a2e0cf2 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -7,7 +7,6 @@
 from pathlib import Path
 
 import numpy as np
-import pyarrow as pa
 import pytest
 from scipy import sparse
 from sklearn.datasets import dump_svmlight_file, load_svmlight_file, make_blobs
@@ -18,6 +17,15 @@
 from .utils import dummy_obj, load_breast_cancer, mse_obj, np_assert_array_equal
 
+if getenv("ALLOW_SKIP_ARROW_TESTS") == "1":
+    pa = pytest.importorskip("pyarrow")
+else:
+    import pyarrow as pa  # type: ignore
+
+    assert lgb.compat.PYARROW_INSTALLED is True, (
+        "'pyarrow' and its dependencies must be installed to run the arrow tests"
+    )
+
 
 def test_basic(tmp_path):
     X_train, X_test, y_train, y_test = train_test_split(
@@ -350,10 +358,13 @@ def test_add_features_from_different_sources(rng):
         X,
         sparse.csr_matrix(X),
         pd.DataFrame(X),
-        pa.Table.from_arrays(
-            [pa.array(X[:, i]) for i in range(X.shape[1])], names=[f"col_{i}" for i in range(X.shape[1])]
-        ),
     ]
+    if getenv("ALLOW_SKIP_ARROW_TESTS") != "1":
+        xxs.append(
+            pa.Table.from_arrays(
+                [pa.array(X[:, i]) for i in range(X.shape[1])], names=[f"D{i}" for i in range(X.shape[1])]
+            )
+        )
 
     names = [f"col_{i}" for i in range(n_col)]
     seq = _create_sequence_from_ndarray(X, 1, 30)

From 01bc66804088e51db74ca2632d55497988d4df71 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Mon, 28 Apr 2025 14:19:16 +0900
Subject: [PATCH 04/13] delete unnecessary else after raise

---
 python-package/lightgbm/basic.py | 74 +++++++++++++++-----------------
 1 file changed, 35 insertions(+), 39 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 3b9a46bebfdb..6074d3f1943b 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3492,15 +3492,14 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                         "without pyarrow installed. "
                         "Install pyarrow and restart your session."
                     )
-                else:
-                    self.data = np.hstack(
-                        (
-                            self.data,
-                            np.column_stack(
-                                [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
-                            ),
-                        )
+                self.data = np.hstack(
+                    (
+                        self.data,
+                        np.column_stack(
+                            [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
+                        ),
                     )
+                )
             else:
                 self.data = None
         elif isinstance(self.data, scipy.sparse.spmatrix):
@@ -3519,16 +3518,15 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                         "without pyarrow installed. "
                         "Install pyarrow and restart your session."
                     )
-                else:
-                    self.data = scipy.sparse.hstack(
-                        (
-                            self.data,
-                            np.column_stack(
-                                [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
-                            ),
+                self.data = scipy.sparse.hstack(
+                    (
+                        self.data,
+                        np.column_stack(
+                            [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
                         ),
-                        format=sparse_format,
-                    )
+                    ),
+                    format=sparse_format,
+                )
             else:
                 self.data = None
         elif isinstance(self.data, pd_DataFrame):
@@ -3554,20 +3552,19 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                         "without pyarrow installed. "
                         "Install pyarrow and restart your session."
                     )
-                else:
-                    self.data = concat(
-                        (
-                            self.data,
-                            pd_DataFrame(
-                                {
-                                    other.data.column_names[i]: other.data.column(i).to_numpy()
-                                    for i in range(len(other.data.column_names))
-                                }
-                            ),
+                self.data = concat(
+                    (
+                        self.data,
+                        pd_DataFrame(
+                            {
+                                other.data.column_names[i]: other.data.column(i).to_numpy()
+                                for i in range(len(other.data.column_names))
+                            }
                         ),
-                        axis=1,
-                        ignore_index=True,
-                    )
+                    ),
+                    axis=1,
+                    ignore_index=True,
+                )
             else:
                 self.data = None
         elif isinstance(self.data, dt_DataTable):
@@ -3587,17 +3584,16 @@ def add_features_from(self, other: "Dataset") -> "Dataset":
                         "without pyarrow installed. "
                         "Install pyarrow and restart your session."
                     )
-                else:
-                    self.data = dt_DataTable(
-                        np.hstack(
-                            (
-                                self.data.to_numpy(),
-                                np.column_stack(
-                                    [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
-                                ),
-                            )
+                self.data = dt_DataTable(
+                    np.hstack(
+                        (
+                            self.data.to_numpy(),
+                            np.column_stack(
+                                [other.data.column(i).to_numpy() for i in range(len(other.data.column_names))]
+                            ),
                         )
                     )
+                )
             else:
                 self.data = None
         elif isinstance(self.data, pa_Table):

From 219e61c47e174d83677ab61ed6da8931959c2151 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Mon, 28 Apr 2025 22:53:36 +0900
Subject: [PATCH 05/13] [python-package] add test for pyarrow table in
 Dataset.get_data()

---
 tests/python_package_test/test_arrow.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 93ab21021ba8..0cb9769a1805 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -456,6 +456,20 @@ def test_arrow_feature_name_manual():
     assert booster.feature_name() == ["c", "d"]
 
 
+def test_get_data_arrow_table():
+    original_table = generate_simple_arrow_table()
+    dataset = lgb.Dataset(original_table, free_raw_data=False)
+    dataset.construct()
+
+    returned_data = dataset.get_data()
+    assert isinstance(returned_data, pa.Table)
+    assert returned_data.schema == original_table.schema
+    assert returned_data.shape == original_table.shape
+
+    for column in original_table.column_names:
+        assert original_table[column].equals(returned_data[column])
+
+
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
     with pytest.raises(
         lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."

From d4de30974ceba204c353fbe41685c44825adb0d3 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Mon, 28 Apr 2025 22:59:02 +0900
Subject: [PATCH 06/13] [python-package] add test for add_features_from with
 pyarrow tables

---
 tests/python_package_test/test_arrow.py | 33 +++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 0cb9769a1805..db12218cbf0f 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -470,6 +470,39 @@ def test_get_data_arrow_table():
         assert original_table[column].equals(returned_data[column])
 
 
+def test_add_features_from_arrow_table():
+    table1 = pa.Table.from_arrays(
+        [pa.array([1, 2, 3, 4, 5], type=pa.int32()), pa.array([0.1, 0.2, 0.3, 0.4, 0.5], type=pa.float32())],
+        names=["feature1", "feature2"],
+    )
+
+    table2 = pa.Table.from_arrays(
+        [
+            pa.array([10, 20, 30, 40, 50], type=pa.int64()),
+            pa.array([1.1, 1.2, 1.3, 1.4, 1.5], type=pa.float64()),
+            pa.array([True, False, True, False, True], type=pa.bool_()),
+        ],
+        names=["feature3", "feature4", "feature5"],
+    )
+
+    dataset1 = lgb.Dataset(table1, free_raw_data=False)
+    dataset2 = lgb.Dataset(table2, free_raw_data=False)
+
+    dataset1.construct()
+    dataset2.construct()
+
+    dataset1.add_features_from(dataset2)
+    combined_data = dataset1.get_data()
+    assert isinstance(combined_data, pa.Table)
+    assert combined_data.num_columns == table1.num_columns + table2.num_columns
+    assert set(combined_data.column_names) == set(table1.column_names + table2.column_names)
+    assert combined_data.num_rows == table1.num_rows
+    for column in table1.column_names:
+        assert combined_data[column].equals(table1[column])
+    for column in table2.column_names:
+        assert combined_data[column].equals(table2[column])
+
+
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
     with pytest.raises(
         lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."

From 22ec3149ef4c7281242ab0dc17c272f319459cef Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Fri, 9 May 2025 23:34:53 +0900
Subject: [PATCH 07/13] [python-package] add PyArrow Table to get_data

---
 python-package/lightgbm/basic.py        |  2 ++
 tests/python_package_test/test_arrow.py | 25 +++++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index c1781eb2283c..0147a6065572 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3267,6 +3267,8 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
                     self.data = self.data.iloc[self.used_indices].copy()
                 elif isinstance(self.data, Sequence):
                     self.data = self.data[self.used_indices]
+                elif isinstance(self.data, pa_Table):
+                    self.data = self.data.take(self.used_indices)
                 elif _is_list_of_sequences(self.data) and len(self.data) > 0:
                     self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
                 else:
diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 93ab21021ba8..2fbfcfa2e540 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -456,6 +456,31 @@ def test_arrow_feature_name_manual():
     assert booster.feature_name() == ["c", "d"]
 
 
+def test_get_data_arrow_table():
+    original_table = generate_simple_arrow_table()
+    dataset = lgb.Dataset(original_table, free_raw_data=False)
+    dataset.construct()
+
+    returned_data = dataset.get_data()
+    assert isinstance(returned_data, pa.Table)
+    assert returned_data.schema == original_table.schema
+    assert returned_data.shape == original_table.shape
+
+    for column_name in original_table.column_names:
+        original_column = original_table[column_name]
+        returned_column = returned_data[column_name]
+
+        assert original_column.type == returned_column.type
+        assert original_column.num_chunks == returned_column.num_chunks
+
+        for i in range(original_column.num_chunks):
+            original_chunk = original_column.chunk(i)
+            returned_chunk = returned_column.chunk(i)
+            assert original_chunk.equals(returned_chunk)
+
+        assert original_column.equals(returned_column)
+
+
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
     with pytest.raises(
         lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."
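Aside: a minimal usage sketch (not taken from the patches themselves) of the behavior patches 01-07 target. It assumes a LightGBM build containing these patches, with pyarrow and cffi installed; the tables and column names ("a", "b") are illustrative. With free_raw_data=False, combining two Arrow-backed Datasets should keep the raw data alive as a pyarrow.Table instead of freeing it.

    import numpy as np
    import pyarrow as pa
    import lightgbm as lgb

    rng = np.random.default_rng(0)
    t1 = pa.Table.from_arrays([pa.array(rng.uniform(size=100))], names=["a"])
    t2 = pa.Table.from_arrays([pa.array(rng.uniform(size=100))], names=["b"])

    # Both Datasets must be constructed before their features can be merged.
    d1 = lgb.Dataset(t1, free_raw_data=False).construct()
    d2 = lgb.Dataset(t2, free_raw_data=False).construct()
    d1.add_features_from(d2)

    combined = d1.get_data()  # with the fix: still a pyarrow.Table, columns ["a", "b"]
    assert isinstance(combined, pa.Table)
    assert combined.column_names == ["a", "b"]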
From c7594c8969c3051bef8f94ab615f0772500403bd Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Sat, 10 May 2025 23:59:29 +0900
Subject: [PATCH 08/13] [python-package] add test for subset of PyArrow table
 dataset

---
 tests/python_package_test/test_arrow.py | 27 +++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 2fbfcfa2e540..4474c06431d5 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -481,6 +481,33 @@ def test_get_data_arrow_table():
         assert original_column.equals(returned_column)
 
 
+def test_get_data_arrow_table_subset(rng):
+    original_table = generate_simple_arrow_table()
+    dataset = lgb.Dataset(original_table, free_raw_data=False)
+    dataset.construct()
+
+    used_indices = rng.choice(a=original_table.shape[0], size=original_table.shape[0] // 3, replace=False)
+    subset_dataset = dataset.subset(used_indices).construct()
+    expected_subset = original_table.take(used_indices)
+    subset_data = subset_dataset.get_data()
+    assert subset_data.schema == expected_subset.schema
+    assert subset_data.shape == expected_subset.shape
+    assert len(subset_data) == len(used_indices)
+    for column_name in expected_subset.column_names:
+        original_column = expected_subset[column_name]
+        returned_column = subset_data[column_name]
+
+        assert original_column.type == returned_column.type
+        assert original_column.num_chunks == returned_column.num_chunks
+
+        for i in range(original_column.num_chunks):
+            original_chunk = original_column.chunk(i)
+            returned_chunk = returned_column.chunk(i)
+            assert original_chunk.equals(returned_chunk)
+
+        assert original_column.equals(returned_column)
+
+
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
     with pytest.raises(
         lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed."
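Aside, not part of the series: the subset path of get_data() leans on pyarrow.Table.take, which gathers rows by positional index and preserves the schema. A standalone illustration with a hypothetical one-column table:

    import pyarrow as pa

    table = pa.Table.from_arrays([pa.array([10.0, 20.0, 30.0, 40.0])], names=["x"])
    subset = table.take([0, 2])  # keeps rows 0 and 2; the schema is unchanged
    assert subset.schema == table.schema
    assert subset.column("x").to_pylist() == [10.0, 30.0]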
From b5395a0e8a203fd6f3300e8a64ca343571e8845d Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Fri, 30 May 2025 09:31:35 +0900
Subject: [PATCH 09/13] [python-package] improve PyArrow table subset tests for
 null values and chunking

---
 tests/python_package_test/test_arrow.py | 46 +++++++++++++++----------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 4474c06431d5..1743388dc386 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -118,7 +118,7 @@ def generate_random_arrow_array(
     chunks = [chunk for chunk in chunks if len(chunk) > 0]
 
     # Turn chunks into array
-    return pa.chunked_array([data], type=pa.float32())
+    return pa.chunked_array(chunks, type=pa.float32())
 
 
 def dummy_dataset_params() -> Dict[str, Any]:
@@ -456,6 +456,15 @@ def test_arrow_feature_name_manual():
     assert booster.feature_name() == ["c", "d"]
 
 
+def safe_array_equal_with_nulls(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
+    if len(arr1) != len(arr2):
+        return False
+
+    np1 = arr1.to_numpy(zero_copy_only=False)
+    np2 = arr2.to_numpy(zero_copy_only=False)
+    return np.array_equal(np1, np2, equal_nan=True)
+
+
 def test_get_data_arrow_table():
     original_table = generate_simple_arrow_table()
     dataset = lgb.Dataset(original_table, free_raw_data=False)
@@ -474,38 +483,37 @@ def test_get_data_arrow_table():
         assert original_column.type == returned_column.type
         assert original_column.num_chunks == returned_column.num_chunks
 
         for i in range(original_column.num_chunks):
-            original_chunk = original_column.chunk(i)
-            returned_chunk = returned_column.chunk(i)
-            assert original_chunk.equals(returned_chunk)
+            original_chunk_array = pa.chunked_array([original_column.chunk(i)])
+            returned_chunk_array = pa.chunked_array([returned_column.chunk(i)])
+            assert safe_array_equal_with_nulls(original_chunk_array, returned_chunk_array)
 
-        assert original_column.equals(returned_column)
+        assert safe_array_equal_with_nulls(original_column, returned_column)
 
 
 def test_get_data_arrow_table_subset(rng):
-    original_table = generate_simple_arrow_table()
+    original_table = generate_random_arrow_table(num_columns=3, num_datapoints=1000, seed=42)
     dataset = lgb.Dataset(original_table, free_raw_data=False)
     dataset.construct()
 
-    used_indices = rng.choice(a=original_table.shape[0], size=original_table.shape[0] // 3, replace=False)
+    subset_size = 100
+    used_indices = rng.choice(a=original_table.shape[0], size=subset_size, replace=False)
+    used_indices = sorted(used_indices)
+
     subset_dataset = dataset.subset(used_indices).construct()
     expected_subset = original_table.take(used_indices)
     subset_data = subset_dataset.get_data()
+
+    assert isinstance(subset_data, pa.Table)
     assert subset_data.schema == expected_subset.schema
     assert subset_data.shape == expected_subset.shape
     assert len(subset_data) == len(used_indices)
-    for column_name in expected_subset.column_names:
-        original_column = expected_subset[column_name]
-        returned_column = subset_data[column_name]
+    assert subset_data.shape == (subset_size, 3)
 
-        assert original_column.type == returned_column.type
-        assert original_column.num_chunks == returned_column.num_chunks
-
-        for i in range(original_column.num_chunks):
-            original_chunk = original_column.chunk(i)
-            returned_chunk = returned_column.chunk(i)
-            assert original_chunk.equals(returned_chunk)
-
-        assert original_column.equals(returned_column)
+    for column_name in expected_subset.column_names:
+        expected_col = expected_subset[column_name]
+        returned_col = subset_data[column_name]
+        assert expected_col.type == returned_col.type
+        assert safe_array_equal_with_nulls(expected_col, returned_col)
 
 
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):

From e569ac552c7a6894a4a49fb85cc781317254a568 Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Mon, 1 Sep 2025 13:49:03 +0900
Subject: [PATCH 10/13] [python-package] avoid TypeError:
 ChunkedArray.to_numpy() takes no keyword arguments

---
 tests/python_package_test/test_arrow.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index ce061f35abf2..b4aec9786774 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -460,8 +460,8 @@ def safe_array_equal_with_nulls(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
     if len(arr1) != len(arr2):
         return False
 
-    np1 = arr1.to_numpy(zero_copy_only=False)
-    np2 = arr2.to_numpy(zero_copy_only=False)
+    np1 = arr1.to_numpy()
+    np2 = arr2.to_numpy()
     return np.array_equal(np1, np2, equal_nan=True)

From e826a4f50aaaed0f9ed5ab0e331a4558057fa25d Mon Sep 17 00:00:00 2001
From: Yusuke Horibe <40166140+suk1yak1@users.noreply.github.com>
Date: Fri, 12 Sep 2025 15:08:57 +0900
Subject: [PATCH 11/13] [python-package] Move final assertion before the for
 loop to fail faster and group related comparisons

Co-authored-by: James Lamb
---
 tests/python_package_test/test_arrow.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index b4aec9786774..265217ceedcb 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -481,14 +481,13 @@ def test_get_data_arrow_table():
         returned_column = returned_data[column_name]
 
         assert original_column.type == returned_column.type
         assert original_column.num_chunks == returned_column.num_chunks
+        assert safe_array_equal_with_nulls(original_column, returned_column)
 
         for i in range(original_column.num_chunks):
             original_chunk_array = pa.chunked_array([original_column.chunk(i)])
             returned_chunk_array = pa.chunked_array([returned_column.chunk(i)])
             assert safe_array_equal_with_nulls(original_chunk_array, returned_chunk_array)
-
-        assert safe_array_equal_with_nulls(original_column, returned_column)
 
 
 def test_get_data_arrow_table_subset(rng):

From 17dbc43a9444202e8459ebb082fc125628513acb Mon Sep 17 00:00:00 2001
From: Yusuke Horibe <40166140+suk1yak1@users.noreply.github.com>
Date: Fri, 12 Sep 2025 15:10:12 +0900
Subject: [PATCH 12/13] [python-package] Rename test helper and add docstring
 to clarify purpose

Co-authored-by: James Lamb
---
 tests/python_package_test/test_arrow.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index 265217ceedcb..a1229b71ae0b 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -456,7 +456,12 @@ def test_arrow_feature_name_manual():
     assert booster.feature_name() == ["c", "d"]
 
 
-def safe_array_equal_with_nulls(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
+def pyarrow_array_equal(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
+    """Similar to ``np.array_equal()``, but for ``pyarrow.Array`` objects.
+    
+    ``pyarrow.Array`` objects with identical values do not compare equal if any of those
+    values are nulls. This function treats them as equal.
+    """
     if len(arr1) != len(arr2):
         return False

From 8135e92eba56037463e74a0080dadf16d13d3c2d Mon Sep 17 00:00:00 2001
From: suk1yak1
Date: Fri, 12 Sep 2025 15:25:48 +0900
Subject: [PATCH 13/13] [python-package] Rename test helper

---
 tests/python_package_test/test_arrow.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py
index a1229b71ae0b..1052b5fb7388 100644
--- a/tests/python_package_test/test_arrow.py
+++ b/tests/python_package_test/test_arrow.py
@@ -458,7 +458,7 @@ def test_arrow_feature_name_manual():
 
 def pyarrow_array_equal(arr1: pa.ChunkedArray, arr2: pa.ChunkedArray) -> bool:
     """Similar to ``np.array_equal()``, but for ``pyarrow.Array`` objects.
-    
+
     ``pyarrow.Array`` objects with identical values do not compare equal if any of those
     values are nulls. This function treats them as equal.
     """
@@ -486,12 +486,12 @@ def test_get_data_arrow_table():
 
         assert original_column.type == returned_column.type
         assert original_column.num_chunks == returned_column.num_chunks
-        assert safe_array_equal_with_nulls(original_column, returned_column)
+        assert pyarrow_array_equal(original_column, returned_column)
 
         for i in range(original_column.num_chunks):
             original_chunk_array = pa.chunked_array([original_column.chunk(i)])
             returned_chunk_array = pa.chunked_array([returned_column.chunk(i)])
-            assert safe_array_equal_with_nulls(original_chunk_array, returned_chunk_array)
+            assert pyarrow_array_equal(original_chunk_array, returned_chunk_array)
 
 
 def test_get_data_arrow_table_subset(rng):
@@ -517,7 +517,7 @@ def test_get_data_arrow_table_subset(rng):
         expected_col = expected_subset[column_name]
         returned_col = subset_data[column_name]
         assert expected_col.type == returned_col.type
-        assert safe_array_equal_with_nulls(expected_col, returned_col)
+        assert pyarrow_array_equal(expected_col, returned_col)
 
 
 def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi):
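Aside, not part of the series: a standalone illustration of why the renamed helper compares through NumPy with equal_nan=True. Nulls in a float Arrow array surface as NaN after ChunkedArray.to_numpy(), and NaN never compares equal to itself under plain elementwise equality; the arrays below are illustrative.

    import numpy as np
    import pyarrow as pa

    left = pa.chunked_array([[1.0, None, 3.0]], type=pa.float32())
    right = pa.chunked_array([[1.0, None], [3.0]], type=pa.float32())  # same values, different chunking

    # The null slot becomes NaN in NumPy, so a plain comparison fails ...
    assert not np.array_equal(left.to_numpy(), right.to_numpy())
    # ... while equal_nan=True treats matching NaN slots as equal.
    assert np.array_equal(left.to_numpy(), right.to_numpy(), equal_nan=True)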