diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index 36858c7dab3..fb9100e1a0a 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -1,7 +1,6 @@ import itertools -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional -from typing import List as ListT +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional import numpy as np import pyarrow as pa @@ -19,7 +18,7 @@ _ArrayXD, _arrow_to_datasets_dtype, ) -from datasets.table import table_cast +from datasets.table import cast_table_to_features if TYPE_CHECKING: @@ -35,12 +34,8 @@ class HDF5Config(datasets.BuilderConfig): """BuilderConfig for HDF5.""" batch_size: Optional[int] = None - columns: Optional[ListT[str]] = None features: Optional[datasets.Features] = None - def __post_init__(self): - super().__post_init__() - class HDF5(datasets.ArrowBasedBuilder): """ArrowBasedBuilder that converts HDF5 files to Arrow tables using the HF extension types.""" @@ -48,15 +43,6 @@ class HDF5(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = HDF5Config def _info(self): - if ( - self.config.columns is not None - and self.config.features is not None - and set(self.config.columns) != set(self.config.features) - ): - raise ValueError( - "The columns and features argument must contain the same columns, but got ", - f"{self.config.columns} and {self.config.features}", - ) return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): @@ -76,35 +62,11 @@ def _split_generators(self, dl_manager): if self.info.features is None: for first_file in itertools.chain.from_iterable(files): with h5py.File(first_file, "r") as h5: - dataset_map = _traverse_datasets(h5) - features_dict = {} - - for path, dset in dataset_map.items(): - if _is_complex_dtype(dset.dtype): - complex_features = _create_complex_features(path, dset) - features_dict.update(complex_features) - elif _is_compound_dtype(dset.dtype): - compound_features = _create_compound_features(path, dset) - features_dict.update(compound_features) - elif _is_vlen_string_dtype(dset.dtype): - features_dict[path] = Value("string") - else: - feat = _infer_feature_from_dataset(dset) - features_dict[path] = feat - self.info.features = datasets.Features(features_dict) + self.info.features = _recursive_infer_features(h5) break splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) - if self.config.columns is not None and set(self.config.columns) != set(self.info.features): - self.info.features = datasets.Features( - {col: feat for col, feat in self.info.features.items() if col in self.config.columns} - ) return splits - def _cast_table(self, pa_table: pa.Table) -> pa.Table: - if self.info.features is not None: - pa_table = table_cast(pa_table, self.info.features.arrow_schema) - return pa_table - def _generate_tables(self, files): import h5py @@ -112,133 +74,70 @@ def _generate_tables(self, files): for file_idx, file in enumerate(itertools.chain.from_iterable(files)): try: with h5py.File(file, "r") as h5: - dataset_map = _traverse_datasets(h5) - if not dataset_map: - logger.warning(f"File '{file}' contains no data, skipping...") + # Infer features and lengths from first file + if self.info.features is None: + self.info.features = _recursive_infer_features(h5) + num_rows = _check_dataset_lengths(h5, self.info.features) + if num_rows is None: + logger.warning(f"File {file} contains no data, 
skipping...") continue - - if self.config.columns is not None: - filtered_dataset_map = { - path: dset for path, dset in dataset_map.items() if path in self.config.columns - } - if not filtered_dataset_map: - logger.warning( - f"No datasets match the specified columns {self.config.columns}, skipping..." - ) - continue - dataset_map = filtered_dataset_map - - # Sanity-check lengths for selected datasets - first_dset = next(iter(dataset_map.values())) - num_rows = first_dset.shape[0] - for path, dset in dataset_map.items(): - if dset.shape[0] != num_rows: - raise ValueError( - f"Dataset '{path}' length {dset.shape[0]} differs from {num_rows} in file '{file}'" - ) effective_batch = batch_size_cfg or self._writer_batch_size or num_rows for start in range(0, num_rows, effective_batch): end = min(start + effective_batch, num_rows) - batch_dict = {} - for path, dset in dataset_map.items(): - arr = dset[start:end] - - # Handle variable-length arrays - if _is_vlen_string_dtype(dset.dtype): - logger.debug( - f"Converting variable-length string data for '{path}' (shape: {arr.shape})" - ) - batch_dict[path] = _convert_vlen_string_to_array(arr) - elif ( - hasattr(dset.dtype, "metadata") - and dset.dtype.metadata - and "vlen" in dset.dtype.metadata - ): - # Handle other variable-length types (non-strings) - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) - batch_dict[path] = pa_arr - elif _is_complex_dtype(dset.dtype): - batch_dict.update(_convert_complex_to_nested(path, arr, dset)) - elif _is_compound_dtype(dset.dtype): - batch_dict.update(_convert_compound_to_nested(path, arr, dset)) - elif dset.dtype.kind == "O": - raise ValueError( - f"Object dtype dataset '{path}' is not supported. " - f"For variable-length data, please use h5py.vlen_dtype() " - f"when creating the HDF5 file. " - f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" - ) - else: - # If any non-batch dimension is zero, emit an unsized pa.list_ - # to avoid creating FixedSizeListArray with list_size=0. 
- if any(dim == 0 for dim in dset.shape[1:]): - inner_type = pa.from_numpy_dtype(dset.dtype) - pa_arr = pa.array([[] for _ in arr], type=pa.list_(inner_type)) - else: - pa_arr = datasets.features.features.numpy_to_pyarrow_listarray(arr) - batch_dict[path] = pa_arr - pa_table = pa.Table.from_pydict(batch_dict) - yield f"{file_idx}_{start}", self._cast_table(pa_table) + pa_table = _recursive_load_arrays(h5, self.info.features, start, end) + if pa_table is None: + logger.warning(f"File {file} contains no data, skipping...") + continue + yield f"{file_idx}_{start}", cast_table_to_features(pa_table, self.info.features) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise -def _traverse_datasets(h5_obj, prefix: str = "") -> Dict[str, "h5py.Dataset"]: - import h5py - - mapping: Dict[str, h5py.Dataset] = {} - - def collect_datasets(name, obj): - if isinstance(obj, h5py.Dataset): - full_path = f"{prefix}{name}" if prefix else name - mapping[full_path] = obj - - h5_obj.visititems(collect_datasets) - return mapping - - # ┌───────────┐ # │ Complex │ # └───────────┘ def _is_complex_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a complex number type.""" - return dtype.kind == "c" + if dtype.kind == "c": + return True + if dtype.subdtype is not None: + return _is_complex_dtype(dtype.subdtype[0]) + return False -def _create_complex_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: - """Create Features for complex data with real and imaginary parts `real` and `imag`. +def _create_complex_features(dset) -> Features: + if dset.dtype.subdtype is not None: + dtype, data_shape = dset.dtype.subdtype + else: + data_shape = dset.shape[1:] + dtype = dset.dtype + + if dtype == np.complex64: + # two float32s + value_type = Value("float32") + elif dtype == np.complex128: + # two float64s + value_type = Value("float64") + else: + logger.warning(f"Found complex dtype {dtype} that is not supported. Converting to float64...") + value_type = Value("float64") - NOTE: Always uses float64 for the real and imaginary parts. 
- """ - logger.info( - f"Complex dataset '{base_path}' (dtype: {dset.dtype}) represented as nested structure with 'real' and 'imag' fields" - ) - nested_features = Features( + return Features( { - "real": Value("float64"), - "imag": Value("float64"), + "real": _create_sized_feature_impl(data_shape, value_type), + "imag": _create_sized_feature_impl(data_shape, value_type), } ) - return {base_path: nested_features} - -def _convert_complex_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: - """Convert complex to Features with real and imaginary parts `real` and `imag`.""" - result = {} - - def _convert_complex_scalar(complex_val): - """Convert a complex scalar to a dictionary.""" - if complex_val.size == 1: - return {"real": float(complex_val.item().real), "imag": float(complex_val.item().imag)} - else: - # For multi-dimensional arrays, convert to list - return {"real": complex_val.real.tolist(), "imag": complex_val.imag.tolist()} - result[base_path] = pa.array([_convert_complex_scalar(complex_val) for complex_val in arr]) - return result +def _convert_complex_to_nested(arr: np.ndarray) -> pa.StructArray: + data = { + "real": datasets.features.features.numpy_to_pyarrow_listarray(arr.real), + "imag": datasets.features.features.numpy_to_pyarrow_listarray(arr.imag), + } + return pa.StructArray.from_arrays([data["real"], data["imag"]], names=["real", "imag"]) # ┌────────────┐ @@ -247,94 +146,66 @@ def _convert_complex_scalar(complex_val): def _is_compound_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a compound/structured type.""" - return dtype.names is not None + return dtype.kind == "V" -class _MockDataset: - def __init__(self, dtype): - self.dtype = dtype - self.names = dtype.names +@dataclass +class _CompoundGroup: + dset: "h5py.Dataset" + data: np.ndarray = None + def items(self): + for field_name in self.dset.dtype.names: + field_dtype = self.dset.dtype[field_name] + yield field_name, _CompoundField(self.data, field_name, field_dtype) -def _create_compound_features(base_path: str, dset: "h5py.Dataset") -> Dict[str, Features]: - """Create nested features for compound data with field names as keys.""" - field_names = list(dset.dtype.names) - logger.info( - f"Compound dataset '{base_path}' (dtype: {dset.dtype}) represented as nested Features with fields: {field_names}" - ) - nested_features_dict = {} - for field_name in field_names: - field_dtype = dset.dtype[field_name] - - if _is_complex_dtype(field_dtype): - nested_features_dict[field_name] = Features( - { - "real": Value("float64"), - "imag": Value("float64"), - } - ) - elif _is_compound_dtype(field_dtype): - mock_dset = _MockDataset(field_dtype) - nested_features_dict[field_name] = _create_compound_features(field_name, mock_dset)[field_name] - else: - nested_features_dict[field_name] = _np_to_pa_to_hf_value(field_dtype) +@dataclass +class _CompoundField: + data: Optional[np.ndarray] + name: str + dtype: np.dtype + shape: tuple[int, ...] 
= field(init=False) - nested_features = Features(nested_features_dict) - return {base_path: nested_features} + def __post_init__(self): + self.shape = (len(self.data) if self.data is not None else 0,) + self.dtype.shape + def __getitem__(self, key): + return self.data[key][self.name] -def _convert_compound_to_nested(base_path: str, arr: np.ndarray, dset: "h5py.Dataset") -> Dict[str, pa.Array]: - """Convert compound array to nested structure with field names as keys.""" - result = {} - def _convert_compound_recursive(compound_arr, compound_dtype): - """Recursively convert compound array to nested structure.""" - nested_data = [] - for row in compound_arr: - row_dict = {} - for field_name in compound_dtype.names: - field_dtype = compound_dtype[field_name] - field_data = row[field_name] +def _create_compound_features(dset) -> Features: + mock_group = _CompoundGroup(dset) + return _recursive_infer_features(mock_group) - if _is_complex_dtype(field_dtype): - row_dict[field_name] = {"real": float(field_data.real), "imag": float(field_data.imag)} - elif _is_compound_dtype(field_dtype): - row_dict[field_name] = _convert_compound_recursive([field_data], field_dtype)[0] - else: - row_dict[field_name] = field_data.item() if field_data.size == 1 else field_data.tolist() - nested_data.append(row_dict) - return nested_data - result[base_path] = pa.array(_convert_compound_recursive(arr, dset.dtype)) - return result +def _convert_compound_to_nested(arr, dset) -> pa.StructArray: + mock_group = _CompoundGroup(dset, data=arr) + features = _create_compound_features(dset) + return _recursive_load_arrays(mock_group, features, 0, len(arr)) -# ┌───────────────────────────┐ -# │ Variable-Length Strings │ -# └───────────────────────────┘ +# ┌───────────────────┐ +# │ Variable-Length │ +# └───────────────────┘ -def _is_vlen_string_dtype(dtype: np.dtype) -> bool: - """Check if dtype is a variable-length string type.""" - if hasattr(dtype, "metadata") and dtype.metadata and "vlen" in dtype.metadata: - vlen_dtype = dtype.metadata["vlen"] - return vlen_dtype in (str, bytes) +def _is_vlen_dtype(dtype: np.dtype) -> bool: + if dtype.metadata and "vlen" in dtype.metadata: + return True return False -def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: - list_of_items = [] - for item in arr: - if isinstance(item, bytes): - logger.info("Assuming variable-length bytes are utf-8 encoded strings") - list_of_items.append(item.decode("utf-8")) - elif isinstance(item, str): - list_of_items.append(item) - else: - raise ValueError(f"Unsupported variable-length string type: {type(item)}") - return pa.array(list_of_items) +def _create_vlen_features(dset) -> Features: + vlen_dtype = dset.dtype.metadata["vlen"] + if vlen_dtype in (str, bytes): + return Value("string") + inner_feature = _np_to_pa_to_hf_value(vlen_dtype) + return List(inner_feature) + + +def _convert_vlen_to_array(arr: np.ndarray) -> pa.Array: + return datasets.features.features.numpy_to_pyarrow_listarray(arr) # ┌───────────┐ @@ -342,20 +213,104 @@ def _convert_vlen_string_to_array(arr: np.ndarray) -> pa.Array: # └───────────┘ -def _infer_feature_from_dataset(dset: "h5py.Dataset"): - # non-string varlen - if hasattr(dset.dtype, "metadata") and dset.dtype.metadata and "vlen" in dset.dtype.metadata: - vlen_dtype = dset.dtype.metadata["vlen"] - inner_feature = _np_to_pa_to_hf_value(vlen_dtype) - return List(inner_feature) +def _recursive_infer_features(h5_obj) -> Features: + features_dict = {} + for path, dset in h5_obj.items(): + if _is_group(dset): + features = 
_recursive_infer_features(dset) + if features: + features_dict[path] = features + elif _is_dataset(dset): + features = _infer_feature(dset) + if features: + features_dict[path] = features + + return Features(features_dict) + + +def _infer_feature(dset): + if _is_complex_dtype(dset.dtype): + return _create_complex_features(dset) + elif _is_compound_dtype(dset.dtype) or dset.dtype.kind == "V": + return _create_compound_features(dset) + elif _is_vlen_dtype(dset.dtype): + return _create_vlen_features(dset) + return _create_sized_feature(dset) + + +def _load_array(dset, path: str, start: int, end: int) -> pa.Array: + arr = dset[start:end] + + if _is_vlen_dtype(dset.dtype): + return _convert_vlen_to_array(arr) + elif _is_complex_dtype(dset.dtype): + return _convert_complex_to_nested(arr) + elif _is_compound_dtype(dset.dtype): + return _convert_compound_to_nested(arr, dset) + elif dset.dtype.kind == "O": + raise ValueError( + f"Object dtype dataset '{path}' is not supported. " + f"For variable-length data, please use h5py.vlen_dtype() " + f"when creating the HDF5 file. " + f"See: https://docs.h5py.org/en/stable/special.html#variable-length-strings" + ) + else: + # If any non-batch dimension is zero, emit an unsized pa.list_ + # to avoid creating FixedSizeListArray with list_size=0. + if any(dim == 0 for dim in dset.shape[1:]): + inner_type = pa.from_numpy_dtype(dset.dtype) + return pa.array([[] for _ in arr], type=pa.list_(inner_type)) + else: + return datasets.features.features.numpy_to_pyarrow_listarray(arr) + + +def _recursive_load_arrays(h5_obj, features: Features, start: int, end: int): + batch_dict = {} + for path, dset in h5_obj.items(): + if path not in features: + continue + if _is_group(dset): + arr = _recursive_load_arrays(dset, features[path], start, end) + elif _is_dataset(dset): + arr = _load_array(dset, path, start, end) + else: + raise ValueError(f"Unexpected type {type(dset)}") - value_feature = _np_to_pa_to_hf_value(dset.dtype) - dtype_str = value_feature.dtype + if arr is not None: + batch_dict[path] = arr + + if _is_file(h5_obj): + return pa.Table.from_pydict(batch_dict) + + if batch_dict: + should_chunk, keys, values = False, [], [] + for k, v in batch_dict.items(): + if isinstance(v, pa.ChunkedArray): + should_chunk = True + v = v.combine_chunks() + keys.append(k) + values.append(v) + + sarr = pa.StructArray.from_arrays(values, names=keys) + return pa.chunked_array(sarr) if should_chunk else sarr + + +# ┌─────────────┐ +# │ Utilities │ +# └─────────────┘ + +def _create_sized_feature(dset): dset_shape = dset.shape[1:] + value_feature = _np_to_pa_to_hf_value(dset.dtype) + return _create_sized_feature_impl(dset_shape, value_feature) + + +def _create_sized_feature_impl(dset_shape, value_feature): + dtype_str = value_feature.dtype if any(dim == 0 for dim in dset_shape): logger.warning( - f"HDF5 to Arrow: Found a dataset named '{dset.name}' with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})." + f"HDF5 to Arrow: Found a dataset with shape {dset_shape} and dtype {dtype_str} that has a dimension with size 0. Shape information will be lost in the conversion to List({value_feature})." ) return List(value_feature) @@ -370,20 +325,65 @@ def _infer_feature_from_dataset(dset: "h5py.Dataset"): raise TypeError(f"Array{rank}D not supported. 
Maximum 5 dimensions allowed.") +def _sized_arrayxd(rank: int): + return {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}[rank] + + +def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: + return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) + + +def _first_dataset(h5_obj, features: Features, prefix=""): + for path, dset in h5_obj.items(): + if path not in features: + continue + if _is_group(dset): + found = _first_dataset(dset, features[path], prefix=f"{prefix}{path}/") + if found is not None: + return found + elif _is_dataset(dset): + return f"{prefix}{path}" + + +def _check_dataset_lengths(h5_obj, features: Features) -> int: + first_path = _first_dataset(h5_obj, features) + if first_path is None: + return None + + num_rows = h5_obj[first_path].shape[0] + for path, dset in h5_obj.items(): + if path not in features: + continue + if _is_dataset(dset): + if dset.shape[0] != num_rows: + raise ValueError(f"Dataset '{path}' has length {dset.shape[0]} but expected {num_rows}") + return num_rows + + +def _is_group(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.Group) or isinstance(h5_obj, _CompoundGroup) + + +def _is_dataset(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.Dataset) or isinstance(h5_obj, _CompoundField) + + +def _is_file(h5_obj) -> bool: + import h5py + + return isinstance(h5_obj, h5py.File) + + def _has_zero_dimensions(feature): if isinstance(feature, _ArrayXD): return any(dim == 0 for dim in feature.shape) - elif isinstance(feature, List): # also gets regular List + elif isinstance(feature, List): return feature.length == 0 or _has_zero_dimensions(feature.feature) elif isinstance(feature, LargeList): return _has_zero_dimensions(feature.feature) else: return False - - -def _sized_arrayxd(rank: int): - return {2: Array2D, 3: Array3D, 4: Array4D, 5: Array5D}[rank] - - -def _np_to_pa_to_hf_value(numpy_dtype: np.dtype) -> Value: - return Value(dtype=_arrow_to_datasets_dtype(pa.from_numpy_dtype(numpy_dtype))) diff --git a/tests/packaged_modules/test_hdf5.py b/tests/packaged_modules/test_hdf5.py index 06329a9c430..5bee3122274 100644 --- a/tests/packaged_modules/test_hdf5.py +++ b/tests/packaged_modules/test_hdf5.py @@ -4,8 +4,8 @@ from datasets import Array2D, Array3D, Array4D, Features, List, Value, load_dataset from datasets.builder import InvalidConfigName -from datasets.data_files import DataFilesDict, DataFilesList -from datasets.download.streaming_download_manager import StreamingDownloadManager +from datasets.data_files import DataFilesList +from datasets.exceptions import DatasetGenerationError from datasets.packaged_modules.hdf5.hdf5 import HDF5, HDF5Config @@ -183,6 +183,52 @@ def hdf5_file_with_compound_data(tmp_path): return str(filename) +@pytest.fixture +def hdf5_file_with_compound_complex_arrays(tmp_path): + """Create an HDF5 file with compound datasets containing complex arrays.""" + filename = tmp_path / "compound_complex_arrays.h5" + + with h5py.File(filename, "w") as f: + # Compound type with complex arrays + dt_complex_arrays = np.dtype( + [ + ("position", [("x", "i4"), ("y", "i4")]), + ("complex_field", "c8"), + ("complex_array", "c8", (2, 3)), + ("nested_complex", [("real", "f4"), ("imag", "f4")]), + ] + ) + + # Create data with complex numbers + compound_data = np.array( + [ + ( + (1, 2), + 1.0 + 2.0j, + [[1.0 + 2.0j, 3.0 + 4.0j, 5.0 + 6.0j], [7.0 + 8.0j, 9.0 + 10.0j, 11.0 + 12.0j]], + (1.5, 2.5), + ), + ( + (3, 4), + 3.0 + 4.0j, + [[13.0 + 14.0j, 15.0 + 16.0j, 17.0 + 18.0j], [19.0 + 20.0j, 21.0 
+ 22.0j, 23.0 + 24.0j]], + (3.5, 4.5), + ), + ( + (5, 6), + 5.0 + 6.0j, + [[25.0 + 26.0j, 27.0 + 28.0j, 29.0 + 30.0j], [31.0 + 32.0j, 33.0 + 34.0j, 35.0 + 36.0j]], + (5.5, 6.5), + ), + ], + dtype=dt_complex_arrays, + ) + + f.create_dataset("compound_with_complex", data=compound_data) + + return str(filename) + + @pytest.fixture def hdf5_file_with_mismatched_lengths(tmp_path): """Create an HDF5 file with datasets of different lengths (should raise error).""" @@ -272,61 +318,49 @@ def test_config_raises_when_invalid_data_files(data_files): def test_hdf5_basic_functionality(hdf5_file): """Test basic HDF5 loading with simple numeric datasets.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file], split="train") - _, table = tables[0] - assert "int32" in table.column_names - assert "float32" in table.column_names - assert "bool" in table.column_names + assert "int32" in dataset.column_names + assert "float32" in dataset.column_names + assert "bool" in dataset.column_names - # Check data - int32_data = table["int32"].to_pylist() - assert int32_data == [0, 1, 2, 3, 4] + assert np.asarray(dataset.data["int32"]).dtype == np.int32 + assert np.asarray(dataset.data["float32"]).dtype == np.float32 + assert np.asarray(dataset.data["bool"]).dtype == np.bool_ - float32_data = table["float32"].to_pylist() + assert dataset["int32"] == [0, 1, 2, 3, 4] + float32_data = dataset["float32"] expected_float32 = [0.0, 0.1, 0.2, 0.3, 0.4] np.testing.assert_allclose(float32_data, expected_float32, rtol=1e-6) def test_hdf5_nested_groups(hdf5_file_with_groups): """Test HDF5 loading with nested groups.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_groups]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_groups], split="train") - _, table = tables[0] - expected_columns = {"root_data", "group1/group_data", "group1/subgroup/sub_data"} - assert set(table.column_names) == expected_columns + expected_columns = {"root_data", "group1"} + assert set(dataset.column_names) == expected_columns # Check data - root_data = table["root_data"].to_pylist() + root_data = dataset["root_data"] + group1_data = dataset["group1"] assert root_data == [0, 1, 2] - - group_data = table["group1/group_data"].to_pylist() - expected_group_data = [0.0, 1.0, 2.0] - np.testing.assert_allclose(group_data, expected_group_data, rtol=1e-6) + assert group1_data == [ + {"group_data": 0.0, "subgroup": {"sub_data": 0}}, + {"group_data": 1.0, "subgroup": {"sub_data": 1}}, + {"group_data": 2.0, "subgroup": {"sub_data": 2}}, + ] def test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): """Test HDF5 loading with multi-dimensional arrays.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_arrays]]) - - tables = list(generator) - assert len(tables) == 1 + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train") - _, table = tables[0] expected_columns = {"matrix_2d", "tensor_3d", "tensor_4d", "tensor_5d", "vector_1d"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check shapes - matrix_2d = table["matrix_2d"].to_pylist() + matrix_2d = dataset["matrix_2d"] assert len(matrix_2d) == 4 # 4 rows assert len(matrix_2d[0]) == 3 # 3 rows in each matrix assert len(matrix_2d[0][0]) == 4 # 4 columns in each matrix @@ -334,18 +368,13 @@ def 
test_hdf5_multi_dimensional_arrays(hdf5_file_with_arrays): def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): """Test HDF5 loading with variable-length arrays (int32).""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_vlen_arrays]]) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_vlen_arrays], split="train") - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] expected_columns = {"vlen_ints", "mixed_data"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check vlen_ints data - vlen_ints = table["vlen_ints"].to_pylist() + vlen_ints = dataset["vlen_ints"] assert len(vlen_ints) == 4 assert vlen_ints[0] == [1, 2, 3] assert vlen_ints[1] == [4, 5] @@ -353,7 +382,7 @@ def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): assert vlen_ints[3] == [10] # Check mixed_data (with None values) - mixed_data = table["mixed_data"].to_pylist() + mixed_data = dataset["mixed_data"] assert len(mixed_data) == 4 assert mixed_data[0] == [1, 2, 3] assert mixed_data[1] == [] # Empty array instead of None @@ -363,18 +392,12 @@ def test_hdf5_vlen_arrays(hdf5_file_with_vlen_arrays): def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): """Test HDF5 loading with variable-length string datasets.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_variable_length_strings]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_variable_length_strings], split="train") expected_columns = {"var_strings", "var_bytes"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check variable-length strings (converted to strings for usability) - var_strings = table["var_strings"].to_pylist() + var_strings = dataset["var_strings"] assert len(var_strings) == 4 assert var_strings[0] == "short" assert var_strings[1] == "medium length string" @@ -382,7 +405,7 @@ def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): assert var_strings[3] == "tiny" # Check variable-length bytes (converted to strings for usability) - var_bytes = table["var_bytes"].to_pylist() + var_bytes = dataset["var_bytes"] assert len(var_bytes) == 4 assert var_bytes[0] == "short" assert var_bytes[1] == "medium length bytes" @@ -392,21 +415,15 @@ def test_hdf5_variable_length_strings(hdf5_file_with_variable_length_strings): def test_hdf5_different_dtypes(hdf5_file_with_different_dtypes): """Test HDF5 loading with various numeric dtypes.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_different_dtypes]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_different_dtypes], split="train") expected_columns = {"int8", "int16", "int64", "uint8", "uint16", "uint32", "uint64", "float16", "float64", "bytes"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check specific dtypes - int8_data = table["int8"].to_pylist() + int8_data = dataset["int8"] assert int8_data == [0, 1, 2] - bytes_data = table["bytes"].to_pylist() + bytes_data = dataset["bytes"] assert bytes_data == [b"row_0", b"row_1", b"row_2"] @@ -432,82 +449,55 @@ def test_hdf5_batch_processing(hdf5_file): def test_hdf5_column_filtering(hdf5_file_with_groups): """Test HDF5 loading with column filtering.""" - config = 
HDF5Config(columns=["root_data", "group1/group_data"]) - hdf5 = HDF5() - hdf5.config = config - generator = hdf5._generate_tables([[hdf5_file_with_groups]]) + features = Features({"root_data": Value("int32"), "group1": Features({"group_data": Value("float32")})}) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_groups], split="train", features=features) - tables = list(generator) - assert len(tables) == 1 + expected_columns = {"root_data", "group1"} + assert set(dataset.column_names) == expected_columns - _, table = tables[0] - expected_columns = {"root_data", "group1/group_data"} - assert set(table.column_names) == expected_columns - assert "group1/subgroup/sub_data" not in table.column_names + # Check that subgroup is filtered out + group1_data = dataset["group1"] + assert group1_data == [ + {"group_data": 0.0}, + {"group_data": 1.0}, + {"group_data": 2.0}, + ] def test_hdf5_feature_specification(hdf5_file): """Test HDF5 loading with explicit feature specification.""" - features = Features({"int32": Value("int32"), "float32": Value("float32"), "bool": Value("bool")}) + features = Features({"int32": Value("int32"), "float32": Value("float64"), "bool": Value("bool")}) + dataset = load_dataset("hdf5", data_files=[hdf5_file], split="train", features=features) - config = HDF5Config(features=features) - hdf5 = HDF5() - hdf5.config = config - generator = hdf5._generate_tables([[hdf5_file]]) - - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] # Check that features are properly cast - assert table.schema.field("int32").type == features["int32"].pa_type - assert table.schema.field("float32").type == features["float32"].pa_type - assert table.schema.field("bool").type == features["bool"].pa_type + assert np.asarray(dataset.data["float32"]).dtype == np.float64 + assert np.asarray(dataset.data["int32"]).dtype == np.int32 + assert np.asarray(dataset.data["bool"]).dtype == np.bool_ def test_hdf5_mismatched_lengths_error(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths raise an error.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) + with pytest.raises(DatasetGenerationError) as exc_info: + load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train") - with pytest.raises(ValueError, match="length.*differs from"): - for _ in generator: - pass + assert isinstance(exc_info.value.__cause__, ValueError) + assert "3 but expected 5" in str(exc_info.value.__cause__) def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): """Test that zero dimensions are handled gracefully.""" - # Trigger feature inference - data_files = DataFilesDict({"train": [hdf5_file_with_zero_dimensions]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) - - # Check that features were inferred - assert hdf5.info.features is not None + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_zero_dimensions], split="train") - # Test that the data can be loaded - generator = hdf5._generate_tables([[hdf5_file_with_zero_dimensions]]) - tables = list(generator) - assert len(tables) == 1 - - _, table = tables[0] expected_columns = {"zero_dim", "zero_middle", "zero_last"} - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check that the data is loaded (should be empty arrays) - 
zero_dim_data = table["zero_dim"].to_pylist() + zero_dim_data = dataset["zero_dim"] assert len(zero_dim_data) == 3 # 3 rows assert all(len(row) == 0 for row in zero_dim_data) # Each row is empty # Check that shape info is lost - caplog.clear() - ds = load_dataset("hdf5", data_files=[hdf5_file_with_zero_dimensions], split="train") - assert all(isinstance(col, List) and col.length == -1 for col in ds.features.values()) + assert all(isinstance(col, List) and col.length == -1 for col in dataset.features.values()) # Check for the warnings assert ( @@ -522,13 +512,9 @@ def test_hdf5_zero_dimensions_handling(hdf5_file_with_zero_dimensions, caplog): ) -def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): +def test_hdf5_empty_file_warning(empty_hdf5_file, hdf5_file_with_arrays, caplog): """Test that empty files (no datasets) are skipped with a warning.""" - hdf5 = HDF5() - generator = hdf5._generate_tables([[empty_hdf5_file]]) - - tables = list(generator) - assert len(tables) == 0 # No tables should be generated + load_dataset("hdf5", data_files=[hdf5_file_with_arrays, empty_hdf5_file], split="train") # Check that warning was logged assert any( @@ -538,20 +524,13 @@ def test_hdf5_empty_file_warning(empty_hdf5_file, caplog): def test_hdf5_feature_inference(hdf5_file_with_arrays): """Test automatic feature inference from HDF5 datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_arrays]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types - features = hdf5.info.features + features = dataset.features # (n_rows, 3, 4) -> Array2D with shape (3, 4) assert isinstance(features["matrix_2d"], Array2D) assert features["matrix_2d"].shape == (3, 4) @@ -568,20 +547,13 @@ def test_hdf5_feature_inference(hdf5_file_with_arrays): def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): """Test automatic feature inference from variable-length HDF5 datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_vlen_arrays]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_vlen_arrays], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types for variable-length arrays - features = hdf5.info.features + features = dataset.features # Variable-length arrays should become List features by default (for small datasets) assert isinstance(features["vlen_ints"], List) assert isinstance(features["mixed_data"], List) @@ -595,20 +567,13 @@ def test_hdf5_vlen_feature_inference(hdf5_file_with_vlen_arrays): def test_hdf5_variable_string_feature_inference(hdf5_file_with_variable_length_strings): """Test automatic feature inference from variable-length string datasets.""" - data_files = DataFilesDict({"train": [hdf5_file_with_variable_length_strings]}) - config = HDF5Config(data_files=data_files) - hdf5 = HDF5() - hdf5.config = config - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - 
hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_variable_length_strings], split="train") # Check that features were inferred - assert hdf5.info.features is not None + assert dataset.features is not None # Check specific feature types for variable-length strings - features = hdf5.info.features + features = dataset.features # Variable-length strings should become Value("string") features assert isinstance(features["var_strings"], Value) assert isinstance(features["var_bytes"], Value) @@ -618,21 +583,16 @@ def test_hdf5_variable_string_feature_inference(hdf5_file_with_variable_length_s assert features["var_bytes"].dtype == "string" -def test_hdf5_columns_features_mismatch(): - """Test that mismatched columns and features raise an error.""" - features = Features({"col1": Value("int32"), "col2": Value("float32")}) +def test_hdf5_invalid_features(hdf5_file_with_arrays): + """Test that invalid features raise an error.""" + features = Features({"fakefeature": Value("int32")}) + with pytest.raises(ValueError): + load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train", features=features) - config = HDF5Config( - name="test", - columns=["col1", "col3"], # col3 not in features - features=features, - ) - - hdf5 = HDF5() - hdf5.config = config - - with pytest.raises(ValueError, match="must contain the same columns"): - hdf5._info() + # try with one valid and one invalid feature + features = Features({"matrix_2d": Array2D(shape=(3, 4), dtype="float32"), "fakefeature": Value("int32")}) + with pytest.raises(DatasetGenerationError): + load_dataset("hdf5", data_files=[hdf5_file_with_arrays], split="train", features=features) def test_hdf5_no_data_files_error(): @@ -647,15 +607,7 @@ def test_hdf5_no_data_files_error(): def test_hdf5_complex_numbers(hdf5_file_with_complex_data): """Test HDF5 loading with complex number datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_complex_data]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_complex_data], split="train") # Check that complex numbers are represented as nested Features expected_columns = { @@ -663,28 +615,34 @@ def test_hdf5_complex_numbers(hdf5_file_with_complex_data): "complex_128", "complex_array", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check complex_64 data - complex_64_data = table["complex_64"].to_pylist() + complex_64_data = dataset["complex_64"] assert len(complex_64_data) == 4 assert complex_64_data[0] == {"real": 1.0, "imag": 2.0} assert complex_64_data[1] == {"real": 3.0, "imag": 4.0} assert complex_64_data[2] == {"real": 5.0, "imag": 6.0} assert complex_64_data[3] == {"real": 7.0, "imag": 8.0} + assert np.asarray(dataset.data["complex_64"].flatten()[0]).dtype == np.float32 + assert np.asarray(dataset.data["complex_64"].flatten()[1]).dtype == np.float32 + assert (np.asarray(dataset.data["complex_64"].flatten()[0]) == np.array([1, 3, 5, 7], dtype=np.float32)).all() + assert (np.asarray(dataset.data["complex_64"].flatten()[1]) == np.array([2, 4, 6, 8], dtype=np.float32)).all() -def test_hdf5_compound_types(hdf5_file_with_compound_data): - """Test HDF5 loading with compound/structured datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config + assert np.asarray(dataset.data["complex_128"].flatten()[0]).dtype == np.float64 + assert 
np.asarray(dataset.data["complex_128"].flatten()[1]).dtype == np.float64 + assert ( + np.asarray(dataset.data["complex_128"].flatten()[0]) == np.array([1.5, 3.5, 5.5, 7.5], dtype=np.float64) + ).all() + assert ( + np.asarray(dataset.data["complex_128"].flatten()[1]) == np.array([2.5, 4.5, 6.5, 8.5], dtype=np.float64) + ).all() - generator = hdf5._generate_tables([[hdf5_file_with_compound_data]]) - tables = list(generator) - assert len(tables) == 1 - _, table = tables[0] +def test_hdf5_compound_types(hdf5_file_with_compound_data): + """Test HDF5 loading with compound/structured datasets.""" + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_data], split="train") # Check that compound types are represented as nested structures expected_columns = { @@ -692,10 +650,10 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): "complex_compound", "nested_compound", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check simple compound data - simple_compound_data = table["simple_compound"].to_pylist() + simple_compound_data = dataset["simple_compound"] assert len(simple_compound_data) == 3 assert simple_compound_data[0] == {"x": 1, "y": 2.5} assert simple_compound_data[1] == {"x": 3, "y": 4.5} @@ -704,59 +662,39 @@ def test_hdf5_compound_types(hdf5_file_with_compound_data): def test_hdf5_feature_inference_complex(hdf5_file_with_complex_data): """Test automatic feature inference for complex datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_complex_data]}) - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_complex_data], split="train") # Check that features were inferred correctly - assert hdf5.info.features is not None - features = hdf5.info.features + assert dataset.features is not None + features = dataset.features # Check complex number features assert "complex_64" in features - assert isinstance(features["complex_64"], Features) - assert features["complex_64"]["real"] == Value("float64") - assert features["complex_64"]["imag"] == Value("float64") + # Complex features are represented as dict, not Features object + assert isinstance(features["complex_64"], dict) + assert features["complex_64"]["real"] == Value("float32") + assert features["complex_64"]["imag"] == Value("float32") def test_hdf5_feature_inference_compound(hdf5_file_with_compound_data): """Test automatic feature inference for compound datasets.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - hdf5.config.data_files = DataFilesDict({"train": [hdf5_file_with_compound_data]}) - - # Trigger feature inference - dl_manager = StreamingDownloadManager() - hdf5._split_generators(dl_manager) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_data], split="train") # Check that features were inferred correctly - assert hdf5.info.features is not None - features = hdf5.info.features + assert dataset.features is not None + features = dataset.features # Check compound type features assert "simple_compound" in features - assert isinstance(features["simple_compound"], Features) + # Compound features are represented as dict, not Features object + assert isinstance(features["simple_compound"], dict) assert features["simple_compound"]["x"] == Value("int32") assert features["simple_compound"]["y"] == Value("float64") def 
test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): """Test HDF5 loading with mixed data types in the same file.""" - config = HDF5Config() - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_mixed_data_types]]) - tables = list(generator) - - assert len(tables) == 1 - _, table = tables[0] + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_mixed_data_types], split="train") # Check all expected columns are present expected_columns = { @@ -765,63 +703,127 @@ def test_hdf5_mixed_data_types(hdf5_file_with_mixed_data_types): "complex_data", "compound_data", } - assert set(table.column_names) == expected_columns + assert set(dataset.column_names) == expected_columns # Check data types - assert table["regular_int"].to_pylist() == [0, 1, 2] - assert len(table["complex_data"].to_pylist()) == 3 - assert len(table["compound_data"].to_pylist()) == 3 + assert dataset["regular_int"] == [0, 1, 2] + assert len(dataset["complex_data"]) == 3 + assert len(dataset["compound_data"]) == 3 def test_hdf5_mismatched_lengths_with_column_filtering(hdf5_file_with_mismatched_lengths): """Test that mismatched dataset lengths are ignored when the mismatched dataset is excluded via columns config.""" - config = HDF5Config(columns=["data1"]) - hdf5 = HDF5() - hdf5.config = config - - generator = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) - tables = list(generator) + # Test 1: Include only the first dataset + features = Features({"data1": Value("int32")}) + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train", features=features) # Should work without error since we're only including the first dataset - assert len(tables) == 1 - _, table = tables[0] - - # Check that only the specified column is present expected_columns = {"data1"} - assert set(table.column_names) == expected_columns - assert "data2" not in table.column_names + assert set(dataset.column_names) == expected_columns + assert "data2" not in dataset.column_names # Check the data - data1_values = table["data1"].to_pylist() + data1_values = dataset["data1"] assert data1_values == [0, 1, 2, 3, 4] # Test 2: Include multiple compatible datasets (all with 5 rows) - config2 = HDF5Config(columns=["data1", "data3", "data4", "data5", "data6"]) - hdf5.config = config2 - - generator2 = hdf5._generate_tables([[hdf5_file_with_mismatched_lengths]]) - tables2 = list(generator2) + features = Features( + { + "data1": Value("int32"), + "data3": Array2D(shape=(3, 4), dtype="float32"), + "data4": Value("float64"), + "data5": Value("bool"), + "data6": Value("string"), + } + ) + dataset2 = load_dataset("hdf5", data_files=[hdf5_file_with_mismatched_lengths], split="train", features=features) # Should work without error since we're excluding the mismatched dataset - assert len(tables2) == 1 - _, table2 = tables2[0] - - # Check that all specified columns are present expected_columns2 = {"data1", "data3", "data4", "data5", "data6"} - assert set(table2.column_names) == expected_columns2 - assert "data2" not in table2.column_names + assert set(dataset2.column_names) == expected_columns2 + assert "data2" not in dataset2.column_names # Check data types and values - assert table2["data1"].to_pylist() == [0, 1, 2, 3, 4] # int32 - assert len(table2["data3"].to_pylist()) == 5 # Array2D - assert len(table2["data3"].to_pylist()[0]) == 3 # 3 rows in each 2D array - assert len(table2["data3"].to_pylist()[0][0]) == 4 # 4 columns in each 2D array - 
np.testing.assert_allclose(table2["data4"].to_pylist(), [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6) # float64 - assert table2["data5"].to_pylist() == [True, False, True, False, True] # boolean - assert table2["data6"].to_pylist() == [ + assert dataset2["data1"] == [0, 1, 2, 3, 4] # int32 + assert len(dataset2["data3"]) == 5 # Array2D + assert len(dataset2["data3"][0]) == 3 # 3 rows in each 2D array + assert len(dataset2["data3"][0][0]) == 4 # 4 columns in each 2D array + np.testing.assert_allclose(dataset2["data4"], [0.0, 0.1, 0.2, 0.3, 0.4], rtol=1e-6) # float64 + assert dataset2["data5"] == [True, False, True, False, True] # boolean + assert dataset2["data6"] == [ "short", "medium length", "very long string", "tiny", "another string", ] # vlen string + + +def test_hdf5_compound_with_complex_arrays(hdf5_file_with_compound_complex_arrays): + """Test HDF5 loading with compound datasets containing complex arrays.""" + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_complex_arrays], split="train") + + # Check that compound types with complex arrays are represented as nested structures + expected_columns = {"compound_with_complex"} + assert set(dataset.column_names) == expected_columns + + # Check compound data with complex arrays + compound_data = dataset["compound_with_complex"] + assert len(compound_data) == 3 + + # Check first row + first_row = compound_data[0] + assert first_row["position"]["x"] == 1 + assert first_row["position"]["y"] == 2 + + # Check complex field (should be represented as real/imag structure) + assert first_row["complex_field"]["real"] == 1.0 + assert first_row["complex_field"]["imag"] == 2.0 + + # Check complex array (should be represented as nested real/imag structures) + complex_array = first_row["complex_array"] + assert len(complex_array["real"]) == 2 # 2 rows + assert len(complex_array["real"][0]) == 3 # 3 columns + + # Check first element of complex array + assert complex_array["real"][0][0] == 1.0 + assert complex_array["imag"][0][0] == 2.0 + + # Check nested complex field + assert first_row["nested_complex"]["real"] == 1.5 + assert first_row["nested_complex"]["imag"] == 2.5 + + +def test_hdf5_feature_inference_compound_complex_arrays(hdf5_file_with_compound_complex_arrays): + """Test automatic feature inference for compound datasets with complex arrays.""" + dataset = load_dataset("hdf5", data_files=[hdf5_file_with_compound_complex_arrays], split="train") + + # Check that features were inferred correctly + assert dataset.features is not None + features = dataset.features + + # Check compound type features with complex arrays + assert "compound_with_complex" in features + + # Check nested structure + compound_features = features["compound_with_complex"] + assert "position" in compound_features + assert "complex_field" in compound_features + assert "complex_array" in compound_features + assert "nested_complex" in compound_features + + # Check position field (nested compound) + assert compound_features["position"]["x"] == Value("int32") + assert compound_features["position"]["y"] == Value("int32") + + # Check complex field (should be real/imag structure) + assert compound_features["complex_field"]["real"] == Value("float32") + assert compound_features["complex_field"]["imag"] == Value("float32") + + # Check complex array (should be nested real/imag structures) + assert compound_features["complex_array"]["real"] == Array2D(shape=(2, 3), dtype="float32") + assert compound_features["complex_array"]["imag"] == Array2D(shape=(2, 3), dtype="float32") + + # 
Check nested complex field + assert compound_features["nested_complex"]["real"] == Value("float32") + assert compound_features["nested_complex"]["imag"] == Value("float32")
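

For reference, a minimal usage sketch (not part of the diff) of the packaged "hdf5" builder after this change, mirroring the `load_dataset("hdf5", ...)` calls exercised in the tests above; the file name `example.h5` and the dataset names are illustrative only.

import h5py
import numpy as np
from datasets import load_dataset

# Build a small HDF5 file with a root dataset and a nested group,
# similar to the fixtures used in the tests above.
with h5py.File("example.h5", "w") as f:
    f.create_dataset("ints", data=np.arange(5, dtype=np.int32))
    group = f.create_group("group1")
    group.create_dataset("floats", data=np.linspace(0.0, 1.0, 5))

# Groups are inferred as nested (struct-like) features, so each row of
# "group1" comes back as a plain dict.
ds = load_dataset("hdf5", data_files=["example.h5"], split="train")
print(ds.features)
print(ds["group1"][0])  # e.g. {"floats": 0.0}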