From 07672625ca6a4fb66704406adf62a77e216e51e4 Mon Sep 17 00:00:00 2001 From: Sanjay Kumar Sakamuri Kamalakar Date: Sat, 13 Sep 2025 23:03:26 +0530 Subject: [PATCH] fix: correct float feature generation in generate_examples --- benchmarks/utils.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/benchmarks/utils.py b/benchmarks/utils.py index feb13d9c8fa..c55a543eae5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,7 +1,5 @@ import timeit - import numpy as np - import datasets from datasets.arrow_writer import ArrowWriter from datasets.features.features import _ArrayXD @@ -15,28 +13,41 @@ def wrapper(*args, **kwargs): return delta wrapper.__name__ = func.__name__ - return wrapper def generate_examples(features: dict, num_examples=100, seq_shapes=None): dummy_data = [] seq_shapes = seq_shapes or {} + for i in range(num_examples): example = {} for col_id, (k, v) in enumerate(features.items()): if isinstance(v, _ArrayXD): data = np.random.rand(*v.shape).astype(v.dtype) + elif isinstance(v, datasets.Value): if v.dtype == "string": data = "The small grey turtle was surprisingly fast when challenged." + elif "int" in v.dtype: + data = np.random.randint(0, 10, size=1).astype(v.dtype).item() + elif "float" in v.dtype: + data = np.random.rand(1).astype(v.dtype).item() else: - data = np.random.randint(10, size=1).astype(v.dtype).item() + raise TypeError(f"Unsupported dtype: {v.dtype}") + elif isinstance(v, datasets.Sequence): - while isinstance(v, datasets.Sequence): - v = v.feature - shape = seq_shapes[k] - data = np.random.rand(*shape).astype(v.dtype) + feature = v + while isinstance(feature, datasets.Sequence): + feature = feature.feature + shape = seq_shapes.get(k) + if shape is None: + raise ValueError(f"Shape for sequence feature '{k}' not provided in seq_shapes.") + data = np.random.rand(*shape).astype(feature.dtype) + + else: + raise TypeError(f"Unsupported feature type for key '{k}': {type(v)}") + example[k] = data dummy_data.append((i, example)) @@ -54,11 +65,14 @@ def generate_example_dataset(dataset_path, features, num_examples=100, seq_shape num_final_examples, num_bytes = writer.finalize() - if not num_final_examples == num_examples: + if num_final_examples != num_examples: raise ValueError( f"Error writing the dataset, wrote {num_final_examples} examples but should have written {num_examples}." ) - dataset = datasets.Dataset.from_file(filename=dataset_path, info=datasets.DatasetInfo(features=features)) + dataset = datasets.Dataset.from_file( + filename=dataset_path, + info=datasets.DatasetInfo(features=features) + ) return dataset