From a3f946bc23ea4d03d3304a03ffb4135c3fda2584 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 13:27:43 -0400 Subject: [PATCH 01/19] recursive --- src/datasets/packaged_modules/hdf5/hdf5.py | 284 +++++++++++---------- tests/packaged_modules/test_hdf5.py | 45 ++-- 2 files changed, 183 insertions(+), 146 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 36858c7dab3..4d89a8a29df 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -76,22 +76,7 @@ def _split_generators(self, dl_manager): if self.info.features is None: for first_file in itertools.chain.from_iterable(files): with h5py.File(first_file, "r") as h5: - dataset_map = _traverse_datasets(h5) - features_dict = {} - - for path, dset in dataset_map.items(): - if _is_complex_dtype(dset.dtype): - complex_features = _create_complex_features(path, dset) - features_dict.update(complex_features) - elif _is_compound_dtype(dset.dtype): - compound_features = _create_compound_features(path, dset) - features_dict.update(compound_features) - elif _is_vlen_string_dtype(dset.dtype): - features_dict[path] = Value("string") - else: - feat = _infer_feature_from_dataset(dset) - features_dict[path] = feat - self.info.features = datasets.Features(features_dict) + self.info.features = _recursive_infer_features(h5) break splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) if self.config.columns is not None and set(self.config.columns) != set(self.info.features): @@ -112,92 +97,22 @@ def _generate_tables(self, files): for file_idx, file in enumerate(itertools.chain.from_iterable(files)): try: with h5py.File(file, "r") as h5: - dataset_map = _traverse_datasets(h5) - if not dataset_map: - logger.warning(f"File '{file}' contains no data, skipping...") + if self.info.features is None: + self.info.features = _recursive_infer_features(h5) + num_rows = _check_dataset_lengths(h5, self.info.features) + if num_rows is None: + logger.warning(f"File {file} contains no data, skipping...") continue - - if self.config.columns is not None: - filtered_dataset_map = { - path: dset for path, dset in dataset_map.items() if path in self.config.columns - } - if not filtered_dataset_map: - logger.warning( - f"No datasets match the specified columns {self.config.columns}, skipping..." - ) - continue - dataset_map = filtered_dataset_map - - # Sanity-check lengths for selected datasets - first_dset = next(iter(dataset_map.values())) - num_rows = first_dset.shape[0] - for path, dset in dataset_map.items(): - if dset.shape[0] != num_rows: - raise ValueError( - f"Dataset '{path}' length {dset.shape[0]} differs from {num_rows} in file '{file}'" - ) effective_batch = batch_size_cfg or self._writer_batch_size or num_rows for start in range(0, num_rows, effective_batch): end = min(start + effective_batch, num_rows) - batch_dict = {} - for path, dset in dataset_map.items(): - arr = dset[start:end] - - # Handle variable-length arrays - if _is_vlen_string_dtype(dset.dtype): - logger.debug( - f"Converting variable-length string data for '{path}' (shape: {arr.shape})" - ) - batch_dict[path] = _convert_vlen_string_to_array(arr) - elif ( - hasattr(dset.dtype, "metadata") - and dset.dtype.metadata - and "vlen" in dset.dtype.metadata - ): - # Handle other variable-length types (non-strings) - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) - batch_dict[path] = pa_arr - elif _is_complex_dtype(dset.dtype): - batch_dict.update(_convert_complex_to_nested(path, arr, dset)) - elif _is_compound_dtype(dset.dtype): - batch_dict.update(_convert_compound_to_nested(path, arr, dset)) - elif dset.dtype.kind == "O": - raise ValueError( - f"Object dtype dataset '{path}' is not supported. " - f"For variable-length data, please use h5py.vlen_dtype() " - f"when creating the HDF5 file. " - f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" - ) - else: - # If any non-batch dimension is zero, emit an unsized pa.list_ - # to avoid creating FixedSizeListArray with list_size=0. - if any(dim == 0 for dim in dset.shape[1:]): - inner_type = pa.from_numpy_dtype(dset.dtype) - pa_arr = pa.array([[] for _ in arr], type=pa.list_(inner_type)) - else: - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) - batch_dict[path] = pa_arr - pa_table = pa.Table.from_pydict(batch_dict) + pa_table = _recursive_load_data(h5, self.info.features, start, end) yield f"{file_idx}_{start}", self._cast_table(pa_table) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise -def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, "h5py.Dataset"]: - import h5py - - mapping: Dict[str, h5py.Dataset] = {} - - def collect_datasets(name, obj): - if isinstance(obj, h5py.Dataset): - full_path = f"{prefix}{name}" if prefix else name - mapping[full_path] = obj - - h5_obj.visititems(collect_datasets) - return mapping - - # ┌───────────┐ # │ Complex │ # └───────────┘ @@ -208,37 +123,46 @@ def _is_complex_dtype(dtype: np.dtype) -> bool: return dtype.kind == "c" -def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: +def _create_complex_features(dset: "h5py.Dataset") -> Dict[str, Features]: """Create Features for complex data with real and imaginary parts `real` and `imag`. NOTE: Always uses float64 for the real and imaginary parts. """ logger.info( - f"Complex dataset '{base_path}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields" - ) - nested_features = Features( - { - "real": Value("float64"), - "imag": Value("float64"), - } + f"Complex dataset '{dset.name}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields" ) - return {base_path: nested_features} + + dset_shape = dset.shape[1:] + rank = len(dset_shape) + if rank == 0: + return Features({"real": Value("float64"), "imag": Value("float64")}) + elif rank == 1: + return Features({ + "real": List(Value("float64"), length=dset_shape[0]), + "imag": List(Value("float64"), length=dset_shape[0]) + }) + elif rank <= 5: + array_feature = _sized_arrayxd(rank) + return Features({ + "real": array_feature(shape=dset_shape, dtype="float64"), + "imag": array_feature(shape=dset_shape, dtype="float64") + }) + else: + raise TypeError(f"Complex Array{rank}D not supported. Maximum 5 dimensions allowed.") -def _convert_complex_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: - """Convert complex to Features with real and imaginary parts `real` and `imag`.""" - result = {} - def _convert_complex_scalar(complex_val): - """Convert a complex scalar to a dictionary.""" - if complex_val.size == 1: - return {"real": float(complex_val.item().real), "imag": float(complex_val.item().imag)} - else: - # For multi-dimensional arrays, convert to list - return {"real": complex_val.real.tolist(), "imag": complex_val.imag.tolist()} +def _convert_complex_to_nested(arr: np.ndarray) -> Dict[str, pa.Array]: + """Convert complex to Features with real and imaginary parts `real` and `imag`.""" + if arr.size > 1: + data = { + "real": datasets.features.features.numpy_to_pyarrow_listarray(arr.real), + "imag": datasets.features.features.numpy_to_pyarrow_listarray(arr.imag) + } + else: + data = {"real": float(arr.item().real), "imag": float(arr.item().imag)} - result[base_path] = pa.array([_convert_complex_scalar(complex_val) for complex_val in arr]) - return result + return pa.StructArray.from_arrays(data.values(), names=data.keys()) # ┌────────────┐ @@ -252,16 +176,16 @@ def _is_compound_dtype(dtype: np.dtype) -> bool: class _MockDataset: - def __init__(self, dtype): + def __init__(self, dtype, name): self.dtype = dtype - self.names = dtype.names + self.name = name -def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: +def _create_compound_features(dset: "h5py.Dataset") -> Dict[str, Features]: """Create nested features for compound data with field names as keys.""" field_names = list(dset.dtype.names) logger.info( - f"Compound dataset '{base_path}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}" + f"Compound dataset '{dset.name}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}" ) nested_features_dict = {} @@ -276,19 +200,16 @@ def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, } ) elif _is_compound_dtype(field_dtype): - mock_dset = _MockDataset(field_dtype) - nested_features_dict[field_name] = _create_compound_features(field_name, mock_dset)[field_name] + mock_dset = _MockDataset(field_dtype, f"subfield {field_name} of {dset.name}") + nested_features_dict[field_name] = _create_compound_features(mock_dset) else: nested_features_dict[field_name] = _np_to_pa_to_hf_value(field_dtype) - nested_features = Features(nested_features_dict) - return {base_path: nested_features} + return Features(nested_features_dict) -def _convert_compound_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: +def _convert_compound_to_nested(arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: """Convert compound array to nested structure with field names as keys.""" - result = {} - def _convert_compound_recursive(compound_arr, compound_dtype): """Recursively convert compound array to nested structure.""" nested_data = [] @@ -306,9 +227,7 @@ def _convert_compound_recursive(compound_arr, compound_dtype): row_dict[field_name] = field_data.item() if field_data.size == 1 else field_data.tolist() nested_data.append(row_dict) return nested_data - - result[base_path] = pa.array(_convert_compound_recursive(arr, dset.dtype)) - return result + return pa.array(_convert_compound_recursive(arr, dset.dtype)) # ┌───────────────────────────┐ @@ -324,6 +243,13 @@ def _is_vlen_string_dtype(dtype: np.dtype) -> bool: return False +def _is_vlen_not_string_dtype(dtype: np.dtype) -> bool: + if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: + vlen_dtype = dtype.metadata["vlen"] + return vlen_dtype not in (str, bytes) + return False + + def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: list_of_items = [] for item in arr: @@ -342,9 +268,25 @@ def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: # └───────────┘ -def _infer_feature_from_dataset(dset: "h5py.Dataset"): - # non-string varlen - if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata: +def _recursive_infer_features(h5_obj) -> Features: + import h5py + + features_dict = {} + for path, dset in h5_obj.items(): + if isinstance(dset, h5py.Group): + features_dict[path] = _recursive_infer_features(dset) + elif isinstance(dset, h5py.Dataset): + features_dict[path] = _infer_feature(dset) + return Features(features_dict) + +def _infer_feature(dset: "h5py.Dataset"): + if _is_complex_dtype(dset.dtype): + return _create_complex_features(dset) + elif _is_compound_dtype(dset.dtype): + return _create_compound_features(dset) + elif _is_vlen_string_dtype(dset.dtype): + return Value("string") + elif _is_vlen_not_string_dtype(dset.dtype): vlen_dtype = dset.dtype.metadata["vlen"] inner_feature = _np_to_pa_to_hf_value(vlen_dtype) return List(inner_feature) @@ -370,6 +312,61 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): raise TypeError(f"Array{rank}D not supported. Maximum 5 dimensions allowed.") +def _load_array(dset: "h5py.Dataset", path: str, start: int, end: int) -> Dict[str, any]: + arr = dset[start:end] + + if _is_vlen_string_dtype(dset.dtype): + logger.debug( + f"Converting variable-length string data for '{path}' (shape: {arr.shape})" + ) + return _convert_vlen_string_to_array(arr) + elif _is_vlen_not_string_dtype(dset.dtype): + return datasets.features.features.numpy_to_pyarrow_listarray(arr) + elif _is_complex_dtype(dset.dtype): + return _convert_complex_to_nested(arr) + elif _is_compound_dtype(dset.dtype): + return _convert_compound_to_nested(arr, dset) + elif dset.dtype.kind == "O": + raise ValueError( + f"Object dtype dataset '{path}' is not supported. " + f"For variable-length data, please use h5py.vlen_dtype() " + f"when creating the HDF5 file. " + f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" + ) + else: + # If any non-batch dimension is zero, emit an unsized pa.list_ + # to avoid creating FixedSizeListArray with list_size=0. + if any(dim == 0 for dim in dset.shape[1:]): + inner_type = pa.from_numpy_dtype(dset.dtype) + return pa.array([[] for _ in arr], type=pa.list_(inner_type)) + else: + return datasets.features.features.numpy_to_pyarrow_listarray(arr) + +def _recursive_load_data(h5_obj, features: Features, start: int, end: int): + import h5py + + batch_dict = {} + for path, dset in h5_obj.items(): + if path not in features: + print(f"skipping {path} not in features: {features}") + continue + else: + print(f"checked {path} in features: {features}") + if isinstance(dset, h5py.Group): + batch_dict[path] = _recursive_load_data(dset, features[path], start, end) + elif isinstance(dset, h5py.Dataset): + batch_dict[path] = _load_array(dset, path, start, end) + + if isinstance(h5_obj, h5py.File): + return pa.Table.from_pydict(batch_dict) + else: + return pa.StructArray.from_arrays(batch_dict.values(), names=batch_dict.keys()) + + +# ┌─────────────┐ +# │ Utilities │ +# └─────────────┘ + def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) @@ -387,3 +384,30 @@ def _sized_arrayxd(rank: int): def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) + +def _first_nongroup_dataset(h5_obj, features: Features, prefix="") -> str: + import h5py + + for path, dset in h5_obj.items(): + if path not in features: + continue + if isinstance(dset, h5py.Group): + return _first_nongroup_dataset(dset, features[path], prefix=f"{path}/") + elif isinstance(dset, h5py.Dataset): + return f"{prefix}{path}" + +def _check_dataset_lengths(h5_obj, features: Features) -> int: + import h5py + + first_path = _first_nongroup_dataset(h5_obj, features) + if first_path is None: + return None + + num_rows = h5_obj[first_path].shape[0] + for path, dset in h5_obj.items(): + if path not in features: + continue + if isinstance(dset, h5py.Dataset): + if dset.shape[0] != num_rows: + raise ValueError(f"Dataset '{path}' has length {dset.shape[0]} but expected {num_rows}") + return num_rows \ No newline at end of file diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 06329a9c430..9a9b2c56f94 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -301,16 +301,20 @@ def test_hdf5_nested_groups(hdf5_file_with_groups): assert len(tables) == 1 _, table = tables[0] - expected_columns = {"root_data", "group1/group_data", "group1/subgroup/sub_data"} + + expected_columns = {"root_data", "group1"} assert set(table.column_names) == expected_columns + group1_columns_expected = {"group_data", "subgroup"} + assert set(fi.name for fi in table["group1"].type.fields) == group1_columns_expected + # Check data root_data = table["root_data"].to_pylist() + group_data = table["group1"].to_pylist() assert root_data == [0, 1, 2] - - group_data = table["group1/group_data"].to_pylist() - expected_group_data = [0.0, 1.0, 2.0] - np.testing.assert_allclose(group_data, expected_group_data, rtol=1e-6) + assert group_data == [{'group_data': 0.0, 'subgroup': {'sub_data': 0}}, + {'group_data': 1.0, 'subgroup': {'sub_data': 1}}, + {'group_data': 2.0, 'subgroup': {'sub_data': 2}}] def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): @@ -432,18 +436,19 @@ def test_hdf5_batch_processing(hdf5_file): def test_hdf5_column_filtering(hdf5_file_with_groups): """Test HDF5 loading with column filtering.""" - config = HDF5Config(columns=["root_data", "group1/group_data"]) - hdf5 = HDF5() - hdf5.config = config + # config = HDF5Config() + hdf5 = HDF5(features=Features({"root_data": Value("int32"), "group1": Features({"group_data": Value("float32")})})) + # hdf5.config = config generator = hdf5._generate_tables([[hdf5_file_with_groups]]) tables = list(generator) assert len(tables) == 1 _, table = tables[0] - expected_columns = {"root_data", "group1/group_data"} + expected_columns = {"root_data", "group1"} assert set(table.column_names) == expected_columns - assert "group1/subgroup/sub_data" not in table.column_names + expected_group1_columns = {"group_data"} + assert set(fi.name for fi in table["group1"].type.fields) == expected_group1_columns def test_hdf5_feature_specification(hdf5_file): @@ -470,7 +475,7 @@ def test_hdf5_mismatched_lengths_error(hdf5_file_with_mismatched_lengths): hdf5 = HDF5() generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) - with pytest.raises(ValueError, match="length.*differs from"): + with pytest.raises(ValueError, match="length.*but expected"): for _ in generator: pass @@ -775,9 +780,10 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config.""" - config = HDF5Config(columns=["data1"]) - hdf5 = HDF5() - hdf5.config = config + # config = HDF5Config(columns=["data1"]) + # hdf5 = HDF5(columns=["data1"]) + hdf5 = HDF5(features=Features({"data1": Value("int32")})) + # hdf5.config = config generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) tables = list(generator) @@ -796,8 +802,15 @@ def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched assert data1_values == [0, 1, 2, 3, 4] # Test 2: Include multiple compatible datasets (all with 5 rows) - config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"]) - hdf5.config = config2 + # config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"]) + hdf5 = HDF5(features=Features({ + "data1": Value("int32"), + "data3": Array2D(shape=(3, 4), dtype="float32"), + "data4": Value("float64"), + "data5": Value("bool"), + "data6": Value("string"), + })) + # hdf5.config = config2 generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) tables2 = list(generator2) From 973784110b1762b8a4982f2b0d6d355b5dc94a13 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:13:00 -0400 Subject: [PATCH 02/19] refactor --- src/datasets/packaged_modules/hdf5/hdf5.py | 259 ++++++++++----------- tests/packaged_modules/test_hdf5.py | 173 ++++++++++++-- 2 files changed, 276 insertions(+), 156 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 4d89a8a29df..66d6920ba0d 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,5 +1,5 @@ import itertools -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, Optional from typing import List as ListT @@ -119,50 +119,38 @@ def _generate_tables(self, files): def _is_complex_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a complex number type.""" - return dtype.kind == "c" + """Check if dtype is a complex number or array of complex numbers.""" + if dtype.kind == "c": + return True + if dtype.subdtype is not None: + return _is_complex_dtype(dtype.subdtype[0]) + return False -def _create_complex_features(dset: "h5py.Dataset") -> Dict[str, Features]: +def _create_complex_features(dset) -> Dict[str, Features]: """Create Features for complex data with real and imaginary parts `real` and `imag`. NOTE: Always uses float64 for the real and imaginary parts. """ - logger.info( - f"Complex dataset '{dset.name}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields" - ) - - dset_shape = dset.shape[1:] - rank = len(dset_shape) - - if rank == 0: - return Features({"real": Value("float64"), "imag": Value("float64")}) - elif rank == 1: - return Features({ - "real": List(Value("float64"), length=dset_shape[0]), - "imag": List(Value("float64"), length=dset_shape[0]) - }) - elif rank <= 5: - array_feature = _sized_arrayxd(rank) - return Features({ - "real": array_feature(shape=dset_shape, dtype="float64"), - "imag": array_feature(shape=dset_shape, dtype="float64") - }) + if dset.dtype.subdtype is not None: + data_shape = dset.dtype.subdtype[1] else: - raise TypeError(f"Complex Array{rank}D not supported. Maximum 5 dimensions allowed.") - + data_shape = dset.shape[1:] -def _convert_complex_to_nested(arr: np.ndarray) -> Dict[str, pa.Array]: - """Convert complex to Features with real and imaginary parts `real` and `imag`.""" - if arr.size > 1: - data = { - "real": datasets.features.features.numpy_to_pyarrow_listarray(arr.real), - "imag": datasets.features.features.numpy_to_pyarrow_listarray(arr.imag) + return Features( + { + "real": _create_sized_feature_impl(data_shape, Value("float64")), + "imag": _create_sized_feature_impl(data_shape, Value("float64")), } - else: - data = {"real": float(arr.item().real), "imag": float(arr.item().imag)} + ) - return pa.StructArray.from_arrays(data.values(), names=data.keys()) + +def _convert_complex_to_nested(arr: np.ndarray) -> pa.StructArray: + data = { + "real": datasets.features.features.numpy_to_pyarrow_listarray(arr.real), + "imag": datasets.features.features.numpy_to_pyarrow_listarray(arr.imag), + } + return pa.StructArray.from_arrays([data["real"], data["imag"]], names=["real", "imag"]) # ┌────────────┐ @@ -172,62 +160,43 @@ def _convert_complex_to_nested(arr: np.ndarray) -> Dict[str, pa.Array]: def _is_compound_dtype(dtype: np.dtype) -> bool: """Check if dtype is a compound/structured type.""" - return dtype.names is not None + return dtype.kind == "V" -class _MockDataset: - def __init__(self, dtype, name): - self.dtype = dtype - self.name = name +@dataclass +class _CompoundGroup: + dset: "h5py.Dataset" + data: np.ndarray = None + def items(self): + for field_name in self.dset.dtype.names: + field_dtype = self.dset.dtype[field_name] + yield field_name, _CompoundField(self.data, field_name, field_dtype) -def _create_compound_features(dset: "h5py.Dataset") -> Dict[str, Features]: - """Create nested features for compound data with field names as keys.""" - field_names = list(dset.dtype.names) - logger.info( - f"Compound dataset '{dset.name}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}" - ) - nested_features_dict = {} - for field_name in field_names: - field_dtype = dset.dtype[field_name] +@dataclass +class _CompoundField: + data: Optional[np.ndarray] + name: str + dtype: np.dtype + shape: tuple[int, ...] = field(init=False) - if _is_complex_dtype(field_dtype): - nested_features_dict[field_name] = Features( - { - "real": Value("float64"), - "imag": Value("float64"), - } - ) - elif _is_compound_dtype(field_dtype): - mock_dset = _MockDataset(field_dtype, f"subfield {field_name} of {dset.name}") - nested_features_dict[field_name] = _create_compound_features(mock_dset) - else: - nested_features_dict[field_name] = _np_to_pa_to_hf_value(field_dtype) + def __post_init__(self): + self.shape = (len(self.data) if self.data is not None else 0,) + self.dtype.shape - return Features(nested_features_dict) + def __getitem__(self, key): + return self.data[key][self.name] -def _convert_compound_to_nested(arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: - """Convert compound array to nested structure with field names as keys.""" - def _convert_compound_recursive(compound_arr, compound_dtype): - """Recursively convert compound array to nested structure.""" - nested_data = [] - for row in compound_arr: - row_dict = {} - for field_name in compound_dtype.names: - field_dtype = compound_dtype[field_name] - field_data = row[field_name] +def _create_compound_features(dset) -> Features: + mock_group = _CompoundGroup(dset) + return _recursive_infer_features(mock_group) - if _is_complex_dtype(field_dtype): - row_dict[field_name] = {"real": float(field_data.real), "imag": float(field_data.imag)} - elif _is_compound_dtype(field_dtype): - row_dict[field_name] = _convert_compound_recursive([field_data], field_dtype)[0] - else: - row_dict[field_name] = field_data.item() if field_data.size == 1 else field_data.tolist() - nested_data.append(row_dict) - return nested_data - return pa.array(_convert_compound_recursive(arr, dset.dtype)) + +def _convert_compound_to_nested(arr, dset) -> pa.StructArray: + mock_group = _CompoundGroup(dset, data=arr) + features = _create_compound_features(dset) + return _recursive_load_data(mock_group, features, 0, len(arr)) # ┌───────────────────────────┐ @@ -269,20 +238,20 @@ def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: def _recursive_infer_features(h5_obj) -> Features: - import h5py - features_dict = {} for path, dset in h5_obj.items(): - if isinstance(dset, h5py.Group): + if _is_group(dset): features_dict[path] = _recursive_infer_features(dset) - elif isinstance(dset, h5py.Dataset): + elif _is_dataset(dset): features_dict[path] = _infer_feature(dset) + return Features(features_dict) -def _infer_feature(dset: "h5py.Dataset"): + +def _infer_feature(dset): if _is_complex_dtype(dset.dtype): return _create_complex_features(dset) - elif _is_compound_dtype(dset.dtype): + elif _is_compound_dtype(dset.dtype) or dset.dtype.kind == "V": return _create_compound_features(dset) elif _is_vlen_string_dtype(dset.dtype): return Value("string") @@ -290,35 +259,13 @@ def _infer_feature(dset: "h5py.Dataset"): vlen_dtype = dset.dtype.metadata["vlen"] inner_feature = _np_to_pa_to_hf_value(vlen_dtype) return List(inner_feature) + return _create_sized_feature(dset) - value_feature = _np_to_pa_to_hf_value(dset.dtype) - dtype_str = value_feature.dtype - dset_shape = dset.shape[1:] - if any(dim == 0 for dim in dset_shape): - logger.warning( - f"HDF5 to Arrow: Found a dataset named '{dset.name}' with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})." - ) - return List(value_feature) - - rank = len(dset_shape) - if rank == 0: - return value_feature - elif rank == 1: - return List(value_feature, length=dset_shape[0]) - elif rank <= 5: - return _sized_arrayxd(rank)(shape=dset_shape, dtype=dtype_str) - else: - raise TypeError(f"Array{rank}D not supported. Maximum 5 dimensions allowed.") - - -def _load_array(dset: "h5py.Dataset", path: str, start: int, end: int) -> Dict[str, any]: +def _load_array(dset, path: str, start: int, end: int) -> Dict[str, any]: arr = dset[start:end] if _is_vlen_string_dtype(dset.dtype): - logger.debug( - f"Converting variable-length string data for '{path}' (shape: {arr.shape})" - ) return _convert_vlen_string_to_array(arr) elif _is_vlen_not_string_dtype(dset.dtype): return datasets.features.features.numpy_to_pyarrow_listarray(arr) @@ -342,40 +289,51 @@ def _load_array(dset: "h5py.Dataset", path: str, start: int, end: int) -> Dict[s else: return datasets.features.features.numpy_to_pyarrow_listarray(arr) -def _recursive_load_data(h5_obj, features: Features, start: int, end: int): - import h5py +def _recursive_load_data(h5_obj, features: Features, start: int, end: int): batch_dict = {} for path, dset in h5_obj.items(): if path not in features: - print(f"skipping {path} not in features: {features}") continue - else: - print(f"checked {path} in features: {features}") - if isinstance(dset, h5py.Group): + if _is_group(dset): batch_dict[path] = _recursive_load_data(dset, features[path], start, end) - elif isinstance(dset, h5py.Dataset): + elif _is_dataset(dset): batch_dict[path] = _load_array(dset, path, start, end) - if isinstance(h5_obj, h5py.File): + if _is_file(h5_obj): return pa.Table.from_pydict(batch_dict) - else: - return pa.StructArray.from_arrays(batch_dict.values(), names=batch_dict.keys()) + + return pa.StructArray.from_arrays(batch_dict.values(), names=batch_dict.keys()) # ┌─────────────┐ # │ Utilities │ # └─────────────┘ -def _has_zero_dimensions(feature): - if isinstance(feature, _ArrayXD): - return any(dim == 0 for dim in feature.shape) - elif isinstance(feature, List): # also gets regular List - return feature.length == 0 or _has_zero_dimensions(feature.feature) - elif isinstance(feature, LargeList): - return _has_zero_dimensions(feature.feature) + +def _create_sized_feature(dset): + dset_shape = dset.shape[1:] + value_feature = _np_to_pa_to_hf_value(dset.dtype) + return _create_sized_feature_impl(dset_shape, value_feature) + + +def _create_sized_feature_impl(dset_shape, value_feature): + dtype_str = value_feature.dtype + if any(dim == 0 for dim in dset_shape): + logger.warning( + f"HDF5 to Arrow: Found a dataset with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})." + ) + return List(value_feature) + + rank = len(dset_shape) + if rank == 0: + return value_feature + elif rank == 1: + return List(value_feature, length=dset_shape[0]) + elif rank <= 5: + return _sized_arrayxd(rank)(shape=dset_shape, dtype=dtype_str) else: - return False + raise TypeError(f"Array{rank}D not supported. Maximum 5 dimensions allowed.") def _sized_arrayxd(rank: int): @@ -385,20 +343,18 @@ def _sized_arrayxd(rank: int): def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) + def _first_nongroup_dataset(h5_obj, features: Features, prefix="") -> str: - import h5py - for path, dset in h5_obj.items(): if path not in features: continue - if isinstance(dset, h5py.Group): + if _is_group(dset): return _first_nongroup_dataset(dset, features[path], prefix=f"{path}/") - elif isinstance(dset, h5py.Dataset): + elif _is_dataset(dset): return f"{prefix}{path}" -def _check_dataset_lengths(h5_obj, features: Features) -> int: - import h5py +def _check_dataset_lengths(h5_obj, features: Features) -> int: first_path = _first_nongroup_dataset(h5_obj, features) if first_path is None: return None @@ -407,7 +363,36 @@ def _check_dataset_lengths(h5_obj, features: Features) -> int: for path, dset in h5_obj.items(): if path not in features: continue - if isinstance(dset, h5py.Dataset): + if _is_dataset(dset): if dset.shape[0] != num_rows: raise ValueError(f"Dataset '{path}' has length {dset.shape[0]} but expected {num_rows}") - return num_rows \ No newline at end of file + return num_rows + + +def _is_group(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.Group) or isinstance(h5_obj, _CompoundGroup) + + +def _is_dataset(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.Dataset) or isinstance(h5_obj, _CompoundField) + + +def _is_file(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.File) + + +def _has_zero_dimensions(feature): + if isinstance(feature, _ArrayXD): + return any(dim == 0 for dim in feature.shape) + elif isinstance(feature, List): # also gets regular List + return feature.length == 0 or _has_zero_dimensions(feature.feature) + elif isinstance(feature, LargeList): + return _has_zero_dimensions(feature.feature) + else: + return False diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 9a9b2c56f94..a1cd149312c 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -183,6 +183,52 @@ def hdf5_file_with_compound_data(tmp_path): return str(filename) +@pytest.fixture +def hdf5_file_with_compound_complex_arrays(tmp_path): + """Create an HDF5 file with compound datasets containing complex arrays.""" + filename = tmp_path / "compound_complex_arrays.h5" + + with h5py.File(filename, "w") as f: + # Compound type with complex arrays + dt_complex_arrays = np.dtype( + [ + ("position", [("x", "i4"), ("y", "i4")]), + ("complex_field", "c8"), + ("complex_array", "c8", (2, 3)), + ("nested_complex", [("real", "f4"), ("imag", "f4")]), + ] + ) + + # Create data with complex numbers + compound_data = np.array( + [ + ( + (1, 2), + 1.0 + 2.0j, + [[1.0 + 2.0j, 3.0 + 4.0j, 5.0 + 6.0j], [7.0 + 8.0j, 9.0 + 10.0j, 11.0 + 12.0j]], + (1.5, 2.5), + ), + ( + (3, 4), + 3.0 + 4.0j, + [[13.0 + 14.0j, 15.0 + 16.0j, 17.0 + 18.0j], [19.0 + 20.0j, 21.0 + 22.0j, 23.0 + 24.0j]], + (3.5, 4.5), + ), + ( + (5, 6), + 5.0 + 6.0j, + [[25.0 + 26.0j, 27.0 + 28.0j, 29.0 + 30.0j], [31.0 + 32.0j, 33.0 + 34.0j, 35.0 + 36.0j]], + (5.5, 6.5), + ), + ], + dtype=dt_complex_arrays, + ) + + f.create_dataset("compound_with_complex", data=compound_data) + + return str(filename) + + @pytest.fixture def hdf5_file_with_mismatched_lengths(tmp_path): """Create an HDF5 file with datasets of different lengths (should raise error).""" @@ -306,15 +352,17 @@ def test_hdf5_nested_groups(hdf5_file_with_groups): assert set(table.column_names) == expected_columns group1_columns_expected = {"group_data", "subgroup"} - assert set(fi.name for fi in table["group1"].type.fields) == group1_columns_expected + assert {fi.name for fi in table["group1"].type.fields} == group1_columns_expected # Check data root_data = table["root_data"].to_pylist() group_data = table["group1"].to_pylist() assert root_data == [0, 1, 2] - assert group_data == [{'group_data': 0.0, 'subgroup': {'sub_data': 0}}, - {'group_data': 1.0, 'subgroup': {'sub_data': 1}}, - {'group_data': 2.0, 'subgroup': {'sub_data': 2}}] + assert group_data == [ + {"group_data": 0.0, "subgroup": {"sub_data": 0}}, + {"group_data": 1.0, "subgroup": {"sub_data": 1}}, + {"group_data": 2.0, "subgroup": {"sub_data": 2}}, + ] def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): @@ -436,9 +484,7 @@ def test_hdf5_batch_processing(hdf5_file): def test_hdf5_column_filtering(hdf5_file_with_groups): """Test HDF5 loading with column filtering.""" - # config = HDF5Config() hdf5 = HDF5(features=Features({"root_data": Value("int32"), "group1": Features({"group_data": Value("float32")})})) - # hdf5.config = config generator = hdf5._generate_tables([[hdf5_file_with_groups]]) tables = list(generator) @@ -448,7 +494,7 @@ def test_hdf5_column_filtering(hdf5_file_with_groups): expected_columns = {"root_data", "group1"} assert set(table.column_names) == expected_columns expected_group1_columns = {"group_data"} - assert set(fi.name for fi in table["group1"].type.fields) == expected_group1_columns + assert {fi.name for fi in table["group1"].type.fields} == expected_group1_columns def test_hdf5_feature_specification(hdf5_file): @@ -780,10 +826,7 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config.""" - # config = HDF5Config(columns=["data1"]) - # hdf5 = HDF5(columns=["data1"]) hdf5 = HDF5(features=Features({"data1": Value("int32")})) - # hdf5.config = config generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) tables = list(generator) @@ -802,15 +845,17 @@ def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched assert data1_values == [0, 1, 2, 3, 4] # Test 2: Include multiple compatible datasets (all with 5 rows) - # config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"]) - hdf5 = HDF5(features=Features({ - "data1": Value("int32"), - "data3": Array2D(shape=(3, 4), dtype="float32"), - "data4": Value("float64"), - "data5": Value("bool"), - "data6": Value("string"), - })) - # hdf5.config = config2 + hdf5 = HDF5( + features=Features( + { + "data1": Value("int32"), + "data3": Array2D(shape=(3, 4), dtype="float32"), + "data4": Value("float64"), + "data5": Value("bool"), + "data6": Value("string"), + } + ) + ) generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) tables2 = list(generator2) @@ -838,3 +883,93 @@ def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched "tiny", "another string", ] # vlen string + + +def test_hdf5_compound_with_complex_arrays(hdf5_file_with_compound_complex_arrays): + """Test HDF5 loading with compound datasets containing complex arrays.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + + generator = hdf5._generate_tables([[hdf5_file_with_compound_complex_arrays]]) + tables = list(generator) + + assert len(tables) == 1 + _, table = tables[0] + + # Check that compound types with complex arrays are represented as nested structures + expected_columns = {"compound_with_complex"} + assert set(table.column_names) == expected_columns + + # Check compound data with complex arrays + compound_data = table["compound_with_complex"].to_pylist() + assert len(compound_data) == 3 + + # Check first row + first_row = compound_data[0] + assert first_row["position"]["x"] == 1 + assert first_row["position"]["y"] == 2 + + # Check complex field (should be represented as real/imag structure) + assert first_row["complex_field"]["real"] == 1.0 + assert first_row["complex_field"]["imag"] == 2.0 + + # Check complex array (should be represented as nested real/imag structures) + complex_array = first_row["complex_array"] + assert len(complex_array["real"]) == 2 # 2 rows + assert len(complex_array["real"][0]) == 3 # 3 columns + + # Check first element of complex array + assert complex_array["real"][0][0] == 1.0 + assert complex_array["imag"][0][0] == 2.0 + + # Check nested complex field + assert first_row["nested_complex"]["real"] == 1.5 + assert first_row["nested_complex"]["imag"] == 2.5 + + +def test_hdf5_feature_inference_compound_complex_arrays(hdf5_file_with_compound_complex_arrays): + """Test automatic feature inference for compound datasets with complex arrays.""" + config = HDF5Config() + hdf5 = HDF5() + hdf5.config = config + hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_complex_arrays]}) + + # Trigger feature inference + dl_manager = StreamingDownloadManager() + hdf5._split_generators(dl_manager) + + # Check that features were inferred correctly + assert hdf5.info.features is not None + features = hdf5.info.features + + # Check compound type features with complex arrays + assert "compound_with_complex" in features + assert isinstance(features["compound_with_complex"], Features) + + # Check nested structure + compound_features = features["compound_with_complex"] + assert "position" in compound_features + assert "complex_field" in compound_features + assert "complex_array" in compound_features + assert "nested_complex" in compound_features + + # Check position field (nested compound) + assert isinstance(compound_features["position"], Features) + assert compound_features["position"]["x"] == Value("int32") + assert compound_features["position"]["y"] == Value("int32") + + # Check complex field (should be real/imag structure) + assert isinstance(compound_features["complex_field"], Features) + assert compound_features["complex_field"]["real"] == Value("float64") + assert compound_features["complex_field"]["imag"] == Value("float64") + + # Check complex array (should be nested real/imag structures) + assert isinstance(compound_features["complex_array"], Features) + assert compound_features["complex_array"]["real"] == Array2D(shape=(2, 3), dtype="float64") + assert compound_features["complex_array"]["imag"] == Array2D(shape=(2, 3), dtype="float64") + + # Check nested complex field + assert isinstance(compound_features["nested_complex"], Features) + assert compound_features["nested_complex"]["real"] == Value("float32") + assert compound_features["nested_complex"]["imag"] == Value("float32") From dfae20ad1d05637b9f06a7231bf25a7fca104893 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:30:07 -0400 Subject: [PATCH 03/19] simplify vlen --- src/datasets/packaged_modules/hdf5/hdf5.py | 43 +++++++--------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 66d6920ba0d..279114ff012 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -204,32 +204,23 @@ def _convert_compound_to_nested(arr, dset) -> pa.StructArray: # └───────────────────────────┘ -def _is_vlen_string_dtype(dtype: np.dtype) -> bool: +def _is_vlen_dtype(dtype: np.dtype) -> bool: """Check if dtype is a variable-length string type.""" if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: - vlen_dtype = dtype.metadata["vlen"] - return vlen_dtype in (str, bytes) + return True return False -def _is_vlen_not_string_dtype(dtype: np.dtype) -> bool: - if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: - vlen_dtype = dtype.metadata["vlen"] - return vlen_dtype not in (str, bytes) - return False +def _create_vlen_features(dset) -> Features: + vlen_dtype = dset.dtype.metadata["vlen"] + if vlen_dtype in (str, bytes): + return Value("string") + inner_feature = _np_to_pa_to_hf_value(vlen_dtype) + return List(inner_feature) -def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: - list_of_items = [] - for item in arr: - if isinstance(item, bytes): - logger.info("Assuming variable-length bytes are utf-8 encoded strings") - list_of_items.append(item.decode("utf-8")) - elif isinstance(item, str): - list_of_items.append(item) - else: - raise ValueError(f"Unsupported variable-length string type: {type(item)}") - return pa.array(list_of_items) +def _convert_vlen_to_array(arr: np.ndarray) -> pa.Array: + return datasets.features.features.numpy_to_pyarrow_listarray(arr) # ┌───────────┐ @@ -253,22 +244,16 @@ def _infer_feature(dset): return _create_complex_features(dset) elif _is_compound_dtype(dset.dtype) or dset.dtype.kind == "V": return _create_compound_features(dset) - elif _is_vlen_string_dtype(dset.dtype): - return Value("string") - elif _is_vlen_not_string_dtype(dset.dtype): - vlen_dtype = dset.dtype.metadata["vlen"] - inner_feature = _np_to_pa_to_hf_value(vlen_dtype) - return List(inner_feature) + elif _is_vlen_dtype(dset.dtype): + return _create_vlen_features(dset) return _create_sized_feature(dset) def _load_array(dset, path: str, start: int, end: int) -> Dict[str, any]: arr = dset[start:end] - if _is_vlen_string_dtype(dset.dtype): - return _convert_vlen_string_to_array(arr) - elif _is_vlen_not_string_dtype(dset.dtype): - return datasets.features.features.numpy_to_pyarrow_listarray(arr) + if _is_vlen_dtype(dset.dtype): + return _convert_vlen_to_array(arr) elif _is_complex_dtype(dset.dtype): return _convert_complex_to_nested(arr) elif _is_compound_dtype(dset.dtype): From ed3819f6b5bc7560b230fc462c04e2128300b770 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:32:35 -0400 Subject: [PATCH 04/19] rename --- src/datasets/packaged_modules/hdf5/hdf5.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 279114ff012..d3ff9825540 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -106,7 +106,7 @@ def _generate_tables(self, files): effective_batch = batch_size_cfg or self._writer_batch_size or num_rows for start in range(0, num_rows, effective_batch): end = min(start + effective_batch, num_rows) - pa_table = _recursive_load_data(h5, self.info.features, start, end) + pa_table = _recursive_load_arrays(h5, self.info.features, start, end) yield f"{file_idx}_{start}", self._cast_table(pa_table) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") @@ -196,7 +196,7 @@ def _create_compound_features(dset) -> Features: def _convert_compound_to_nested(arr, dset) -> pa.StructArray: mock_group = _CompoundGroup(dset, data=arr) features = _create_compound_features(dset) - return _recursive_load_data(mock_group, features, 0, len(arr)) + return _recursive_load_arrays(mock_group, features, 0, len(arr)) # ┌───────────────────────────┐ @@ -275,13 +275,13 @@ def _load_array(dset, path: str, start: int, end: int) -> Dict[str, any]: return datasets.features.features.numpy_to_pyarrow_listarray(arr) -def _recursive_load_data(h5_obj, features: Features, start: int, end: int): +def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int): batch_dict = {} for path, dset in h5_obj.items(): if path not in features: continue if _is_group(dset): - batch_dict[path] = _recursive_load_data(dset, features[path], start, end) + batch_dict[path] = _recursive_load_arrays(dset, features[path], start, end) elif _is_dataset(dset): batch_dict[path] = _load_array(dset, path, start, end) From 1c83dd2ac7dee2f7877affdf7c919e02a01a831a Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:35:55 -0400 Subject: [PATCH 05/19] rm empty __post_init__ --- src/datasets/packaged_modules/hdf5/hdf5.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index d3ff9825540..4b26890d976 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -38,9 +38,6 @@ class HDF5Config(datasets.BuilderConfig): columns: Optional[ListT[str]] = None features: Optional[datasets.Features] = None - def __post_init__(self): - super().__post_init__() - class HDF5(datasets.ArrowBasedBuilder): """ArrowBasedBuilder that converts HDF5 files to Arrow tables using the HF extension types.""" From 89c1970e9141a0e15ed2a111defe2436356e37fc Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:36:39 -0400 Subject: [PATCH 06/19] always cast --- src/datasets/packaged_modules/hdf5/hdf5.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 4b26890d976..fb3fd28487b 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -82,11 +82,6 @@ def _split_generators(self, dl_manager): ) return splits - def _cast_table(self, pa_table: pa.Table) -> pa.Table: - if self.info.features is not None: - pa_table = table_cast(pa_table, self.info.features.arrow_schema) - return pa_table - def _generate_tables(self, files): import h5py @@ -104,7 +99,7 @@ def _generate_tables(self, files): for start in range(0, num_rows, effective_batch): end = min(start + effective_batch, num_rows) pa_table = _recursive_load_arrays(h5, self.info.features, start, end) - yield f"{file_idx}_{start}", self._cast_table(pa_table) + yield f"{file_idx}_{start}", table_cast(pa_table, self.info.features.arrow_schema) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise From 3b7c700e8bc061dc4457d5822712a2b35323dc0c Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 15:49:16 -0400 Subject: [PATCH 07/19] don't always cast complex parts to float64 --- src/datasets/packaged_modules/hdf5/hdf5.py | 33 +++++++++++++--------- tests/packaged_modules/test_hdf5.py | 17 +++++++---- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index fb3fd28487b..784e53b13d4 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -89,6 +89,7 @@ def _generate_tables(self, files): for file_idx, file in enumerate(itertools.chain.from_iterable(files)): try: with h5py.File(file, "r") as h5: + # Infer features and lengths from first file if self.info.features is None: self.info.features = _recursive_infer_features(h5) num_rows = _check_dataset_lengths(h5, self.info.features) @@ -111,7 +112,6 @@ def _generate_tables(self, files): def _is_complex_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a complex number or array of complex numbers.""" if dtype.kind == "c": return True if dtype.subdtype is not None: @@ -119,20 +119,27 @@ def _is_complex_dtype(dtype: np.dtype) -> bool: return False -def _create_complex_features(dset) -> Dict[str, Features]: - """Create Features for complex data with real and imaginary parts `real` and `imag`. - - NOTE: Always uses float64 for the real and imaginary parts. - """ +def _create_complex_features(dset) -> Features: if dset.dtype.subdtype is not None: - data_shape = dset.dtype.subdtype[1] + dtype, data_shape = dset.dtype.subdtype else: data_shape = dset.shape[1:] + dtype = dset.dtype + + if dtype == np.complex64: + # two float32s + value_type = Value("float32") + elif dtype == np.complex128: + # two float64s + value_type = Value("float64") + else: + logger.warning(f"Found complex dtype {dtype} that is not supported. Converting to float64...") + value_type = Value("float64") return Features( { - "real": _create_sized_feature_impl(data_shape, Value("float64")), - "imag": _create_sized_feature_impl(data_shape, Value("float64")), + "real": _create_sized_feature_impl(data_shape, value_type), + "imag": _create_sized_feature_impl(data_shape, value_type), } ) @@ -151,7 +158,6 @@ def _convert_complex_to_nested(arr: np.ndarray) -> pa.StructArray: def _is_compound_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a compound/structured type.""" return dtype.kind == "V" @@ -197,7 +203,6 @@ def _convert_compound_to_nested(arr, dset) -> pa.StructArray: def _is_vlen_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a variable-length string type.""" if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: return True return False @@ -241,7 +246,7 @@ def _infer_feature(dset): return _create_sized_feature(dset) -def _load_array(dset, path: str, start: int, end: int) -> Dict[str, any]: +def _load_array(dset, path: str, start: int, end: int) -> pa.Array: arr = dset[start:end] if _is_vlen_dtype(dset.dtype): @@ -321,7 +326,7 @@ def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) -def _first_nongroup_dataset(h5_obj, features: Features, prefix="") -> str: +def _first_nongroup_dataset(h5_obj, features: Features, prefix=""): for path, dset in h5_obj.items(): if path not in features: continue @@ -367,7 +372,7 @@ def _is_file(h5_obj) -> bool: def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) - elif isinstance(feature, List): # also gets regular List + elif isinstance(feature, List): return feature.length == 0 or _has_zero_dimensions(feature.feature) elif isinstance(feature, LargeList): return _has_zero_dimensions(feature.feature) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index a1cd149312c..19af40bdd6f 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -724,6 +724,11 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} + assert hdf5.info.features["complex_64"]["real"].dtype == "float32" + assert hdf5.info.features["complex_64"]["imag"].dtype == "float32" + assert hdf5.info.features["complex_128"]["real"].dtype == "float64" + assert hdf5.info.features["complex_128"]["imag"].dtype == "float64" + def test_hdf5_compound_types(hdf5_file_with_compound_data): """Test HDF5 loading with compound/structured datasets.""" @@ -771,8 +776,8 @@ def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): # Check complex number features assert "complex_64" in features assert isinstance(features["complex_64"], Features) - assert features["complex_64"]["real"] == Value("float64") - assert features["complex_64"]["imag"] == Value("float64") + assert features["complex_64"]["real"] == Value("float32") + assert features["complex_64"]["imag"] == Value("float32") def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): @@ -961,13 +966,13 @@ def test_hdf5_feature_inference_compound_complex_arrays(hdf5_file_with_compound_ # Check complex field (should be real/imag structure) assert isinstance(compound_features["complex_field"], Features) - assert compound_features["complex_field"]["real"] == Value("float64") - assert compound_features["complex_field"]["imag"] == Value("float64") + assert compound_features["complex_field"]["real"] == Value("float32") + assert compound_features["complex_field"]["imag"] == Value("float32") # Check complex array (should be nested real/imag structures) assert isinstance(compound_features["complex_array"], Features) - assert compound_features["complex_array"]["real"] == Array2D(shape=(2, 3), dtype="float64") - assert compound_features["complex_array"]["imag"] == Array2D(shape=(2, 3), dtype="float64") + assert compound_features["complex_array"]["real"] == Array2D(shape=(2, 3), dtype="float32") + assert compound_features["complex_array"]["imag"] == Array2D(shape=(2, 3), dtype="float32") # Check nested complex field assert isinstance(compound_features["nested_complex"], Features) From 612a10bfceca90548f45f91f6958af859cdc0131 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 16:27:24 -0400 Subject: [PATCH 08/19] format --- src/datasets/packaged_modules/hdf5/hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 784e53b13d4..1014f8b9bb7 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional from typing import List as ListT import numpy as np From 66e5b56aefa1ca02ca8ac180cd87b367d183bdaa Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 16:30:46 -0400 Subject: [PATCH 09/19] rm hasattr metadata --- src/datasets/packaged_modules/hdf5/hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 1014f8b9bb7..4ec285a9a08 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -203,7 +203,7 @@ def _convert_compound_to_nested(arr, dset) -> pa.StructArray: def _is_vlen_dtype(dtype: np.dtype) -> bool: - if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: + if dtype.metadata and "vlen" in dtype.metadata: return True return False From 80b3017be422cb0a993ff3a68f89cf303f2b3f3c Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 17:49:42 -0400 Subject: [PATCH 10/19] update tests --- src/datasets/packaged_modules/hdf5/hdf5.py | 39 +- tests/packaged_modules/test_hdf5.py | 424 +++++++-------------- 2 files changed, 152 insertions(+), 311 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 4ec285a9a08..89acf9c1aaf 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,7 +1,6 @@ import itertools from dataclasses import dataclass, field from typing import TYPE_CHECKING, Optional -from typing import List as ListT import numpy as np import pyarrow as pa @@ -19,7 +18,7 @@ _ArrayXD, _arrow_to_datasets_dtype, ) -from datasets.table import table_cast +from datasets.table import cast_table_to_features if TYPE_CHECKING: @@ -35,7 +34,6 @@ class HDF5Config(datasets.BuilderConfig): """BuilderConfig for HDF5.""" batch_size: Optional[int] = None - columns: Optional[ListT[str]] = None features: Optional[datasets.Features] = None @@ -45,15 +43,6 @@ class HDF5(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = HDF5Config def _info(self): - if ( - self.config.columns is not None - and self.config.features is not None - and set(self.config.columns) != set(self.config.features) - ): - raise ValueError( - "The columns and features argument must contain the same columns, but got ", - f"{self.config.columns} and {self.config.features}", - ) return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): @@ -76,10 +65,6 @@ def _split_generators(self, dl_manager): self.info.features = _recursive_infer_features(h5) break splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) - if self.config.columns is not None and set(self.config.columns) != set(self.info.features): - self.info.features = datasets.Features( - {col: feat for col, feat in self.info.features.items() if col in self.config.columns} - ) return splits def _generate_tables(self, files): @@ -100,7 +85,10 @@ def _generate_tables(self, files): for start in range(0, num_rows, effective_batch): end = min(start + effective_batch, num_rows) pa_table = _recursive_load_arrays(h5, self.info.features, start, end) - yield f"{file_idx}_{start}", table_cast(pa_table, self.info.features.arrow_schema) + if pa_table is None: + logger.warning(f"File {file} contains no data, skipping...") + continue + yield f"{file_idx}_{start}", cast_table_to_features(pa_table, self.info.features) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise @@ -278,14 +266,21 @@ def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int): if path not in features: continue if _is_group(dset): - batch_dict[path] = _recursive_load_arrays(dset, features[path], start, end) + arr = _recursive_load_arrays(dset, features[path], start, end) elif _is_dataset(dset): - batch_dict[path] = _load_array(dset, path, start, end) + arr = _load_array(dset, path, start, end) + else: + raise ValueError(f"Unexpected type {type(dset)}") + + if arr is not None: + batch_dict[path] = arr if _is_file(h5_obj): return pa.Table.from_pydict(batch_dict) - return pa.StructArray.from_arrays(batch_dict.values(), names=batch_dict.keys()) + if batch_dict: + keys, values = zip(*batch_dict.items()) + return pa.StructArray.from_arrays(values, names=keys) # ┌─────────────┐ @@ -331,7 +326,9 @@ def _first_nongroup_dataset(h5_obj, features: Features, prefix=""): if path not in features: continue if _is_group(dset): - return _first_nongroup_dataset(dset, features[path], prefix=f"{path}/") + found = _first_nongroup_dataset(dset, features[path], prefix=f"{path}/") + if found is not None: + return found elif _is_dataset(dset): return f"{prefix}{path}" diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 19af40bdd6f..8e28baba9f4 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -4,8 +4,8 @@ from datasets import Array2D, Array3D, Array4D, Features, List, Value, load_dataset from datasets.builder import InvalidConfigName -from datasets.data_files import DataFilesDict, DataFilesList -from datasets.download.streaming_download_manager import StreamingDownloadManager +from datasets.data_files import DataFilesList +from datasets.exceptions import DatasetGenerationError from datasets.packaged_modules.hdf5.hdf5 import HDF5, HDF5Config @@ -318,47 +318,34 @@ def test_config_raises_when_invalid_data_files(data_files): def test_hdf5_basic_functionality(hdf5_file): """Test basic HDF5 loading with simple numeric datasets.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file], split="train") - _, table = tables[0] - assert "int32" in table.column_names - assert "float32" in table.column_names - assert "bool" in table.column_names + assert "int32" in dataset.column_names + assert "float32" in dataset.column_names + assert "bool" in dataset.column_names - # Check data - int32_data = table["int32"].to_pylist() - assert int32_data == [0, 1, 2, 3, 4] + assert np.asarray(dataset.data["int32"]).dtype == np.int32 + assert np.asarray(dataset.data["float32"]).dtype == np.float32 + assert np.asarray(dataset.data["bool"]).dtype == np.bool_ - float32_data = table["float32"].to_pylist() + assert dataset["int32"] == [0, 1, 2, 3, 4] + float32_data = dataset["float32"] expected_float32 = [0.0, 0.1, 0.2, 0.3, 0.4] np.testing.assert_allclose(float32_data, expected_float32, rtol=1e-6) def test_hdf5_nested_groups(hdf5_file_with_groups): """Test HDF5 loading with nested groups.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_groups]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_groups], split="train") expected_columns = {"root_data", "group1"} - assert set(table.column_names) == expected_columns - - group1_columns_expected = {"group_data", "subgroup"} - assert {fi.name for fi in table["group1"].type.fields} == group1_columns_expected + assert set(dataset.column_names) == expected_columns # Check data - root_data = table["root_data"].to_pylist() - group_data = table["group1"].to_pylist() + root_data = dataset["root_data"] + group1_data = dataset["group1"] assert root_data == [0, 1, 2] - assert group_data == [ + assert group1_data == [ {"group_data": 0.0, "subgroup": {"sub_data": 0}}, {"group_data": 1.0, "subgroup": {"sub_data": 1}}, {"group_data": 2.0, "subgroup": {"sub_data": 2}}, @@ -367,18 +354,13 @@ def test_hdf5_nested_groups(hdf5_file_with_groups): def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): """Test HDF5 loading with multi-dimensional arrays.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_arrays]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train") - _, table = tables[0] expected_columns = {"matrix_2d", "tensor_3d", "tensor_4d", "tensor_5d", "vector_1d"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check shapes - matrix_2d = table["matrix_2d"].to_pylist() + matrix_2d = dataset["matrix_2d"] assert len(matrix_2d) == 4 # 4 rows assert len(matrix_2d[0]) == 3 # 3 rows in each matrix assert len(matrix_2d[0][0]) == 4 # 4 columns in each matrix @@ -386,18 +368,13 @@ def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): """Test HDF5 loading with variable-length arrays (int32).""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_vlen_arrays]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_vlen_arrays], split="train") - _, table = tables[0] expected_columns = {"vlen_ints", "mixed_data"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check vlen_ints data - vlen_ints = table["vlen_ints"].to_pylist() + vlen_ints = dataset["vlen_ints"] assert len(vlen_ints) == 4 assert vlen_ints[0] == [1, 2, 3] assert vlen_ints[1] == [4, 5] @@ -405,7 +382,7 @@ def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): assert vlen_ints[3] == [10] # Check mixed_data (with None values) - mixed_data = table["mixed_data"].to_pylist() + mixed_data = dataset["mixed_data"] assert len(mixed_data) == 4 assert mixed_data[0] == [1, 2, 3] assert mixed_data[1] == [] # Empty array instead of None @@ -415,18 +392,12 @@ def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): """Test HDF5 loading with variable-length string datasets.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_variable_length_strings]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_variable_length_strings], split="train") expected_columns = {"var_strings", "var_bytes"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check variable-length strings (converted to strings for usability) - var_strings = table["var_strings"].to_pylist() + var_strings = dataset["var_strings"] assert len(var_strings) == 4 assert var_strings[0] == "short" assert var_strings[1] == "medium length string" @@ -434,7 +405,7 @@ def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): assert var_strings[3] == "tiny" # Check variable-length bytes (converted to strings for usability) - var_bytes = table["var_bytes"].to_pylist() + var_bytes = dataset["var_bytes"] assert len(var_bytes) == 4 assert var_bytes[0] == "short" assert var_bytes[1] == "medium length bytes" @@ -444,21 +415,15 @@ def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): def test_hdf5_different_dtypes(hdf5_file_with_different_dtypes): """Test HDF5 loading with various numeric dtypes.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_different_dtypes]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_different_dtypes], split="train") expected_columns = {"int8", "int16", "int64", "uint8", "uint16", "uint32", "uint64", "float16", "float64", "bytes"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check specific dtypes - int8_data = table["int8"].to_pylist() + int8_data = dataset["int8"] assert int8_data == [0, 1, 2] - bytes_data = table["bytes"].to_pylist() + bytes_data = dataset["bytes"] assert bytes_data == [b"row_0", b"row_1", b"row_2"] @@ -484,81 +449,59 @@ def test_hdf5_batch_processing(hdf5_file): def test_hdf5_column_filtering(hdf5_file_with_groups): """Test HDF5 loading with column filtering.""" - hdf5 = HDF5(features=Features({"root_data": Value("int32"), "group1": Features({"group_data": Value("float32")})})) - generator = hdf5._generate_tables([[hdf5_file_with_groups]]) + features = Features({"root_data": Value("int32"), "group1": Features({"group_data": Value("float32")})}) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_groups], split="train", features=features) - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] expected_columns = {"root_data", "group1"} - assert set(table.column_names) == expected_columns - expected_group1_columns = {"group_data"} - assert {fi.name for fi in table["group1"].type.fields} == expected_group1_columns + assert set(dataset.column_names) == expected_columns + + # Check that subgroup is filtered out + group1_data = dataset["group1"] + assert group1_data == [ + {"group_data": 0.0}, + {"group_data": 1.0}, + {"group_data": 2.0}, + ] def test_hdf5_feature_specification(hdf5_file): """Test HDF5 loading with explicit feature specification.""" - features = Features({"int32": Value("int32"), "float32": Value("float32"), "bool": Value("bool")}) - - config = HDF5Config(features=features) - hdf5 = HDF5() - hdf5.config = config - generator = hdf5._generate_tables([[hdf5_file]]) - - tables = list(generator) - assert len(tables) == 1 + features = Features({"int32": Value("int32"), "float32": Value("float64"), "bool": Value("bool")}) + dataset = load_dataset("hdf5", data_files=[hdf5_file], split="train", features=features) - _, table = tables[0] # Check that features are properly cast - assert table.schema.field("int32").type == features["int32"].pa_type - assert table.schema.field("float32").type == features["float32"].pa_type - assert table.schema.field("bool").type == features["bool"].pa_type + assert np.asarray(dataset.data["float32"]).dtype == np.float64 + assert np.asarray(dataset.data["int32"]).dtype == np.int32 + assert np.asarray(dataset.data["bool"]).dtype == np.bool_ def test_hdf5_mismatched_lengths_error(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths raise an error.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) + with pytest.raises(Exception) as exc_info: + load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train") + + # The error can be either ValueError or DatasetGenerationError + error_str = str(exc_info.value) + if hasattr(exc_info.value, "__cause__") and exc_info.value.__cause__: + error_str += str(exc_info.value.__cause__) - with pytest.raises(ValueError, match="length.*but expected"): - for _ in generator: - pass + assert any(error_type in error_str for error_type in ["length", "Dataset 'data2' has length 3 but expected 5"]) def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): """Test that zero dimensions are handled gracefully.""" - # Trigger feature inference - data_files = DataFilesDict({"train": [hdf5_file_with_zero_dimensions]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) - - # Check that features were inferred - assert hdf5.info.features is not None - - # Test that the data can be loaded - generator = hdf5._generate_tables([[hdf5_file_with_zero_dimensions]]) - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_zero_dimensions], split="train") - _, table = tables[0] expected_columns = {"zero_dim", "zero_middle", "zero_last"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check that the data is loaded (should be empty arrays) - zero_dim_data = table["zero_dim"].to_pylist() + zero_dim_data = dataset["zero_dim"] assert len(zero_dim_data) == 3 # 3 rows assert all(len(row) == 0 for row in zero_dim_data) # Each row is empty # Check that shape info is lost - caplog.clear() - ds = load_dataset("hdf5", data_files=[hdf5_file_with_zero_dimensions], split="train") - assert all(isinstance(col, List) and col.length == -1 for col in ds.features.values()) + assert all(isinstance(col, List) and col.length == -1 for col in dataset.features.values()) # Check for the warnings assert ( @@ -575,11 +518,8 @@ def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): """Test that empty files (no datasets) are skipped with a warning.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[empty_hdf5_file]]) - - tables = list(generator) - assert len(tables) == 0 # No tables should be generated + with pytest.raises(ValueError, match="corresponds to no data"): + load_dataset("hdf5", data_files=[empty_hdf5_file], split="train") # Check that warning was logged assert any( @@ -589,20 +529,13 @@ def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): def test_hdf5_feature_inference(hdf5_file_with_arrays): """Test automatic feature inference from HDF5 datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_arrays]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types - features = hdf5.info.features + features = dataset.features # (n_rows, 3, 4) -> Array2D with shape (3, 4) assert isinstance(features["matrix_2d"], Array2D) assert features["matrix_2d"].shape == (3, 4) @@ -619,20 +552,13 @@ def test_hdf5_feature_inference(hdf5_file_with_arrays): def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): """Test automatic feature inference from variable-length HDF5 datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_vlen_arrays]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_vlen_arrays], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types for variable-length arrays - features = hdf5.info.features + features = dataset.features # Variable-length arrays should become List features by default (for small datasets) assert isinstance(features["vlen_ints"], List) assert isinstance(features["mixed_data"], List) @@ -646,20 +572,13 @@ def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): def test_hdf5_variable_string_feature_inference(hdf5_file_with_variable_length_strings): """Test automatic feature inference from variable-length string datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_variable_length_strings]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_variable_length_strings], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types for variable-length strings - features = hdf5.info.features + features = dataset.features # Variable-length strings should become Value("string") features assert isinstance(features["var_strings"], Value) assert isinstance(features["var_bytes"], Value) @@ -669,21 +588,16 @@ def test_hdf5_variable_string_feature_inference(hdf5_file_with_variable_length_s assert features["var_bytes"].dtype == "string" -def test_hdf5_columns_features_mismatch(): - """Test that mismatched columns and features raise an error.""" - features = Features({"col1": Value("int32"), "col2": Value("float32")}) +def test_hdf5_invalid_features(hdf5_file_with_arrays): + """Test that invalid features raise an error.""" + features = Features({"fakefeature": Value("int32")}) + with pytest.raises(ValueError): + load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train", features=features) - config = HDF5Config( - name="test", - columns=["col1", "col3"], # col3 not in features - features=features, - ) - - hdf5 = HDF5() - hdf5.config = config - - with pytest.raises(ValueError, match="must contain the same columns"): - hdf5._info() + # try with one valid and one invalid feature + features = Features({"matrix_2d": Array2D(shape=(3, 4), dtype="float32"), "fakefeature": Value("int32")}) + with pytest.raises(DatasetGenerationError): + load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train", features=features) def test_hdf5_no_data_files_error(): @@ -698,15 +612,7 @@ def test_hdf5_no_data_files_error(): def test_hdf5_complex_numbers(hdf5_file_with_complex_data): """Test HDF5 loading with complex number datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_complex_data]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_complex_data], split="train") # Check that complex numbers are represented as nested Features expected_columns = { @@ -714,33 +620,24 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): "complex_128", "complex_array", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check complex_64 data - complex_64_data = table["complex_64"].to_pylist() + complex_64_data = dataset["complex_64"] assert len(complex_64_data) == 4 assert complex_64_data[0] == {"real": 1.0, "imag": 2.0} assert complex_64_data[1] == {"real": 3.0, "imag": 4.0} assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} - - assert hdf5.info.features["complex_64"]["real"].dtype == "float32" - assert hdf5.info.features["complex_64"]["imag"].dtype == "float32" - assert hdf5.info.features["complex_128"]["real"].dtype == "float64" - assert hdf5.info.features["complex_128"]["imag"].dtype == "float64" + assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 + assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 + assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 + assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 def test_hdf5_compound_types(hdf5_file_with_compound_data): """Test HDF5 loading with compound/structured datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_compound_data]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_data], split="train") # Check that compound types are represented as nested structures expected_columns = { @@ -748,10 +645,10 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): "complex_compound", "nested_compound", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check simple compound data - simple_compound_data = table["simple_compound"].to_pylist() + simple_compound_data = dataset["simple_compound"] assert len(simple_compound_data) == 3 assert simple_compound_data[0] == {"x": 1, "y": 2.5} assert simple_compound_data[1] == {"x": 3, "y": 4.5} @@ -760,59 +657,39 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): """Test automatic feature inference for complex datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_data]}) - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_complex_data], split="train") # Check that features were inferred correctly - assert hdf5.info.features is not None - features = hdf5.info.features + assert dataset.features is not None + features = dataset.features # Check complex number features assert "complex_64" in features - assert isinstance(features["complex_64"], Features) + # Complex features are represented as dict, not Features object + assert isinstance(features["complex_64"], dict) assert features["complex_64"]["real"] == Value("float32") assert features["complex_64"]["imag"] == Value("float32") def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): """Test automatic feature inference for compound datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_data]}) - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_data], split="train") # Check that features were inferred correctly - assert hdf5.info.features is not None - features = hdf5.info.features + assert dataset.features is not None + features = dataset.features # Check compound type features assert "simple_compound" in features - assert isinstance(features["simple_compound"], Features) + # Compound features are represented as dict, not Features object + assert isinstance(features["simple_compound"], dict) assert features["simple_compound"]["x"] == Value("int32") assert features["simple_compound"]["y"] == Value("float64") def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): """Test HDF5 loading with mixed data types in the same file.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_mixed_data_types]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_mixed_data_types], split="train") # Check all expected columns are present expected_columns = { @@ -821,67 +698,54 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): "complex_data", "compound_data", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check data types - assert table["regular_int"].to_pylist() == [0, 1, 2] - assert len(table["complex_data"].to_pylist()) == 3 - assert len(table["compound_data"].to_pylist()) == 3 + assert dataset["regular_int"] == [0, 1, 2] + assert len(dataset["complex_data"]) == 3 + assert len(dataset["compound_data"]) == 3 def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config.""" - hdf5 = HDF5(features=Features({"data1": Value("int32")})) - - generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) - tables = list(generator) + # Test 1: Include only the first dataset + features = Features({"data1": Value("int32")}) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train", features=features) # Should work without error since we're only including the first dataset - assert len(tables) == 1 - _, table = tables[0] - - # Check that only the specified column is present expected_columns = {"data1"} - assert set(table.column_names) == expected_columns - assert "data2" not in table.column_names + assert set(dataset.column_names) == expected_columns + assert "data2" not in dataset.column_names # Check the data - data1_values = table["data1"].to_pylist() + data1_values = dataset["data1"] assert data1_values == [0, 1, 2, 3, 4] # Test 2: Include multiple compatible datasets (all with 5 rows) - hdf5 = HDF5( - features=Features( - { - "data1": Value("int32"), - "data3": Array2D(shape=(3, 4), dtype="float32"), - "data4": Value("float64"), - "data5": Value("bool"), - "data6": Value("string"), - } - ) + features = Features( + { + "data1": Value("int32"), + "data3": Array2D(shape=(3, 4), dtype="float32"), + "data4": Value("float64"), + "data5": Value("bool"), + "data6": Value("string"), + } ) - - generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) - tables2 = list(generator2) + dataset2 = load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train", features=features) # Should work without error since we're excluding the mismatched dataset - assert len(tables2) == 1 - _, table2 = tables2[0] - - # Check that all specified columns are present expected_columns2 = {"data1", "data3", "data4", "data5", "data6"} - assert set(table2.column_names) == expected_columns2 - assert "data2" not in table2.column_names + assert set(dataset2.column_names) == expected_columns2 + assert "data2" not in dataset2.column_names # Check data types and values - assert table2["data1"].to_pylist() == [0, 1, 2, 3, 4] # int32 - assert len(table2["data3"].to_pylist()) == 5 # Array2D - assert len(table2["data3"].to_pylist()[0]) == 3 # 3 rows in each 2D array - assert len(table2["data3"].to_pylist()[0][0]) == 4 # 4 columns in each 2D array - np.testing.assert_allclose(table2["data4"].to_pylist(), [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6) # float64 - assert table2["data5"].to_pylist() == [True, False, True, False, True] # boolean - assert table2["data6"].to_pylist() == [ + assert dataset2["data1"] == [0, 1, 2, 3, 4] # int32 + assert len(dataset2["data3"]) == 5 # Array2D + assert len(dataset2["data3"][0]) == 3 # 3 rows in each 2D array + assert len(dataset2["data3"][0][0]) == 4 # 4 columns in each 2D array + np.testing.assert_allclose(dataset2["data4"], [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6) # float64 + assert dataset2["data5"] == [True, False, True, False, True] # boolean + assert dataset2["data6"] == [ "short", "medium length", "very long string", @@ -892,22 +756,14 @@ def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched def test_hdf5_compound_with_complex_arrays(hdf5_file_with_compound_complex_arrays): """Test HDF5 loading with compound datasets containing complex arrays.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_compound_complex_arrays]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_complex_arrays], split="train") # Check that compound types with complex arrays are represented as nested structures expected_columns = {"compound_with_complex"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check compound data with complex arrays - compound_data = table["compound_with_complex"].to_pylist() + compound_data = dataset["compound_with_complex"] assert len(compound_data) == 3 # Check first row @@ -935,22 +791,14 @@ def test_hdf5_compound_with_complex_arrays(hdf5_file_with_compound_complex_array def test_hdf5_feature_inference_compound_complex_arrays(hdf5_file_with_compound_complex_arrays): """Test automatic feature inference for compound datasets with complex arrays.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_complex_arrays]}) - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_complex_arrays], split="train") # Check that features were inferred correctly - assert hdf5.info.features is not None - features = hdf5.info.features + assert dataset.features is not None + features = dataset.features # Check compound type features with complex arrays assert "compound_with_complex" in features - assert isinstance(features["compound_with_complex"], Features) # Check nested structure compound_features = features["compound_with_complex"] @@ -960,21 +808,17 @@ def test_hdf5_feature_inference_compound_complex_arrays(hdf5_file_with_compound_ assert "nested_complex" in compound_features # Check position field (nested compound) - assert isinstance(compound_features["position"], Features) assert compound_features["position"]["x"] == Value("int32") assert compound_features["position"]["y"] == Value("int32") # Check complex field (should be real/imag structure) - assert isinstance(compound_features["complex_field"], Features) assert compound_features["complex_field"]["real"] == Value("float32") assert compound_features["complex_field"]["imag"] == Value("float32") # Check complex array (should be nested real/imag structures) - assert isinstance(compound_features["complex_array"], Features) assert compound_features["complex_array"]["real"] == Array2D(shape=(2, 3), dtype="float32") assert compound_features["complex_array"]["imag"] == Array2D(shape=(2, 3), dtype="float32") # Check nested complex field - assert isinstance(compound_features["nested_complex"], Features) assert compound_features["nested_complex"]["real"] == Value("float32") assert compound_features["nested_complex"]["imag"] == Value("float32") From e58c80dc85588acd64d86f5efbdb2f61eead08d5 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 17:51:08 -0400 Subject: [PATCH 11/19] update comment --- src/datasets/packaged_modules/hdf5/hdf5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 89acf9c1aaf..43fa5603720 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -185,9 +185,9 @@ def _convert_compound_to_nested(arr, dset) -> pa.StructArray: return _recursive_load_arrays(mock_group, features, 0, len(arr)) -# ┌───────────────────────────┐ -# │ Variable-Length Strings │ -# └───────────────────────────┘ +# ┌───────────────────┐ +# │ Variable-Length │ +# └───────────────────┘ def _is_vlen_dtype(dtype: np.dtype) -> bool: From 8b97f9ae1514a655ef0a722c1fbbef15ce9cc876 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 17:52:26 -0400 Subject: [PATCH 12/19] rename --- src/datasets/packaged_modules/hdf5/hdf5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 43fa5603720..92066f04b82 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -321,12 +321,12 @@ def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) -def _first_nongroup_dataset(h5_obj, features: Features, prefix=""): +def _first_dataset(h5_obj, features: Features, prefix=""): for path, dset in h5_obj.items(): if path not in features: continue if _is_group(dset): - found = _first_nongroup_dataset(dset, features[path], prefix=f"{path}/") + found = _first_dataset(dset, features[path], prefix=f"{path}/") if found is not None: return found elif _is_dataset(dset): @@ -334,7 +334,7 @@ def _first_nongroup_dataset(h5_obj, features: Features, prefix=""): def _check_dataset_lengths(h5_obj, features: Features) -> int: - first_path = _first_nongroup_dataset(h5_obj, features) + first_path = _first_dataset(h5_obj, features) if first_path is None: return None From 979093561ba1c5df3a35a2821f47dd70bb693041 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 17:58:18 -0400 Subject: [PATCH 13/19] fix test --- tests/packaged_modules/test_hdf5.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 8e28baba9f4..2b0feee9095 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -477,15 +477,11 @@ def test_hdf5_feature_specification(hdf5_file): def test_hdf5_mismatched_lengths_error(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths raise an error.""" - with pytest.raises(Exception) as exc_info: + with pytest.raises(DatasetGenerationError) as exc_info: load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train") - # The error can be either ValueError or DatasetGenerationError - error_str = str(exc_info.value) - if hasattr(exc_info.value, "__cause__") and exc_info.value.__cause__: - error_str += str(exc_info.value.__cause__) - - assert any(error_type in error_str for error_type in ["length", "Dataset 'data2' has length 3 but expected 5"]) + assert isinstance(exc_info.value.__cause__, ValueError) + assert "3 but expected 5" in str(exc_info.value.__cause__) def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): From 83f872fff5a676aefa546413e7cb7ac3706e8dce Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 17:59:10 -0400 Subject: [PATCH 14/19] fix copypaste --- tests/packaged_modules/test_hdf5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 2b0feee9095..8728a79a8f6 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -626,9 +626,9 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 - assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 - assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 + assert np.asarray(dataset.data["complex_64"].flatten()[1]).dtype == np.float32 assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 + assert np.asarray(dataset.data["complex_128"].flatten()[1]).dtype == np.float64 def test_hdf5_compound_types(hdf5_file_with_compound_data): From 361056325727c17b1d0a9340a7670591c1131227 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 18:02:27 -0400 Subject: [PATCH 15/19] check complex splits data into real/imag --- tests/packaged_modules/test_hdf5.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 8728a79a8f6..a27a3707eb5 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -625,10 +625,16 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): assert complex_64_data[1] == {"real": 3.0, "imag": 4.0} assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} + assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 assert np.asarray(dataset.data["complex_64"].flatten()[1]).dtype == np.float32 + assert (np.asarray(dataset.data["complex_64"].flatten()[0]) == np.array([1, 3, 5, 7], dtype=np.float32)).all() + assert (np.asarray(dataset.data["complex_64"].flatten()[1]) == np.array([2, 4, 6, 8], dtype=np.float32)).all() + assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 assert np.asarray(dataset.data["complex_128"].flatten()[1]).dtype == np.float64 + assert (np.asarray(dataset.data["complex_128"].flatten()[0]) == np.array([1.5, 3.5, 5.5, 7.5], dtype=np.float64)).all() + assert (np.asarray(dataset.data["complex_128"].flatten()[1]) == np.array([2.5, 4.5, 6.5, 8.5], dtype=np.float64)).all() def test_hdf5_compound_types(hdf5_file_with_compound_data): From f5fb3e5f3628c04ba6f74a2de54e32b22a72fa98 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 19:08:36 -0400 Subject: [PATCH 16/19] fix prefix, empty groups --- src/datasets/packaged_modules/hdf5/hdf5.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 92066f04b82..a6915661672 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -217,9 +217,13 @@ def _recursive_infer_features(h5_obj) -> Features: features_dict = {} for path, dset in h5_obj.items(): if _is_group(dset): - features_dict[path] = _recursive_infer_features(dset) + features = _recursive_infer_features(dset) + if features: + features_dict[path] = features elif _is_dataset(dset): - features_dict[path] = _infer_feature(dset) + features = _infer_feature(dset) + if features: + features_dict[path] = features return Features(features_dict) @@ -326,7 +330,7 @@ def _first_dataset(h5_obj, features: Features, prefix=""): if path not in features: continue if _is_group(dset): - found = _first_dataset(dset, features[path], prefix=f"{path}/") + found = _first_dataset(dset, features[path], prefix=f"{prefix}{path}/") if found is not None: return found elif _is_dataset(dset): From eea88d919dd2f6957cb6f4789ba6868289d8a4ff Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 19:52:56 -0400 Subject: [PATCH 17/19] handle chunkedarray gracefully --- src/datasets/packaged_modules/hdf5/hdf5.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index a6915661672..c4eb4cdec9d 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -283,8 +283,17 @@ def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int): return pa.Table.from_pydict(batch_dict) if batch_dict: - keys, values = zip(*batch_dict.items()) - return pa.StructArray.from_arrays(values, names=keys) + should_chunk, keys, values = False, [], [] + for k, v in batch_dict.items(): + if isinstance(v, pa.ChunkedArray): + should_chunk = True + v = v.combine_chunks() + keys.append(k) + values.append(v) + + sarr = pa.StructArray.from_arrays(values, names=keys) + return pa.chunked_array(sarr) if should_chunk else sarr + # ┌─────────────┐ From c331e4997bce846c137ed46b30dfe2555aebe016 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 22:20:07 -0400 Subject: [PATCH 18/19] fix empty file test --- tests/packaged_modules/test_hdf5.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index a27a3707eb5..9533ffad930 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -512,10 +512,9 @@ def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): ) -def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): +def test_hdf5_empty_file_warning(empty_hdf5_file, hdf5_file_with_arrays, caplog): """Test that empty files (no datasets) are skipped with a warning.""" - with pytest.raises(ValueError, match="corresponds to no data"): - load_dataset("hdf5", data_files=[empty_hdf5_file], split="train") + load_dataset("hdf5", data_files=[hdf5_file_with_arrays, empty_hdf5_file], split="train") # Check that warning was logged assert any( From bd80590dd29d70916f8964d96578ce29739925e2 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 21 Aug 2025 22:21:05 -0400 Subject: [PATCH 19/19] format --- src/datasets/packaged_modules/hdf5/hdf5.py | 1 - tests/packaged_modules/test_hdf5.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index c4eb4cdec9d..fb9100e1a0a 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -293,7 +293,6 @@ def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int): sarr = pa.StructArray.from_arrays(values, names=keys) return pa.chunked_array(sarr) if should_chunk else sarr - # ┌─────────────┐ diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 9533ffad930..5bee3122274 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -632,8 +632,12 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 assert np.asarray(dataset.data["complex_128"].flatten()[1]).dtype == np.float64 - assert (np.asarray(dataset.data["complex_128"].flatten()[0]) == np.array([1.5, 3.5, 5.5, 7.5], dtype=np.float64)).all() - assert (np.asarray(dataset.data["complex_128"].flatten()[1]) == np.array([2.5, 4.5, 6.5, 8.5], dtype=np.float64)).all() + assert ( + np.asarray(dataset.data["complex_128"].flatten()[0]) == np.array([1.5, 3.5, 5.5, 7.5], dtype=np.float64) + ).all() + assert ( + np.asarray(dataset.data["complex_128"].flatten()[1]) == np.array([2.5, 4.5, 6.5, 8.5], dtype=np.float64) + ).all() def test_hdf5_compound_types(hdf5_file_with_compound_data):