Skip to content

Commit a00d0ec

Browse files
committed
feat: Allow postponing dataset integrity checks to training time
1 parent e9a9920 commit a00d0ec

File tree

1 file changed

+88
-13
lines changed

1 file changed

+88
-13
lines changed

returnn/datasets/hdf.py

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -471,9 +471,10 @@ class StreamParser(object):
471471
Stream parser.
472472
"""
473473

474-
def __init__(self, seq_names, stream):
474+
def __init__(self, seq_names, stream, use_lazy_data_integrity_checks=False):
475475
self.seq_names = seq_names
476476
self.stream = stream
477+
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
477478

478479
self.num_features = None
479480
self.feature_type = None # 1 for sparse, 2 for dense
@@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs):
518519
if self.dtype is None:
519520
self.dtype = str(seq_data.dtype)
520521

521-
assert seq_data.shape[1] == self.num_features
522-
assert str(seq_data.dtype) == self.dtype
522+
if self.use_lazy_data_integrity_checks:
523+
break
524+
525+
self.check_data_integrity(seq_data, s)
523526

524527
self.feature_type = 2
525528

@@ -528,7 +531,12 @@ def get_data(self, seq_name):
528531
:param str seq_name:
529532
:rtype: numpy.ndarray
530533
"""
531-
return self.stream["data"][seq_name][...]
534+
data = self.stream["data"][seq_name][...]
535+
536+
if self.use_lazy_data_integrity_checks:
537+
self.check_data_integrity(data, seq_name)
538+
539+
return data
532540

533541
def get_seq_length(self, seq_name):
534542
"""
@@ -537,6 +545,18 @@ def get_seq_length(self, seq_name):
537545
"""
538546
return self.stream["data"][seq_name].shape[0]
539547

548+
def check_data_integrity(self, data, seq_name):
549+
"""
550+
:param numpy.ndarray data
551+
:param str seq_name
552+
"""
553+
554+
assert len(data.shape) == 2, f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
555+
assert (
556+
self.num_features == data.shape[1]
557+
), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be {self.num_features})"
558+
assert self.dtype == str(data.dtype), f"dtype mismatch {seq_name}: {str(data.dtype)} (should be {self.dtype})"
559+
540560

541561
class SparseStreamParser(StreamParser):
542562
"""
@@ -552,7 +572,11 @@ def __init__(self, *args, **kwargs):
552572

553573
if self.dtype is None:
554574
self.dtype = str(seq_data.dtype)
555-
assert str(seq_data.dtype) == self.dtype
575+
576+
if self.use_lazy_data_integrity_checks:
577+
break
578+
579+
self.check_data_integrity(seq_data, s)
556580

557581
self.num_features = self.stream["feature_names"].shape[0]
558582
self.feature_type = 1
@@ -562,7 +586,12 @@ def get_data(self, seq_name):
562586
:param str seq_name:
563587
:rtype: numpy.ndarray
564588
"""
565-
return self.stream["data"][seq_name][:]
589+
data = self.stream["data"][seq_name][:]
590+
591+
if self.use_lazy_data_integrity_checks:
592+
self.check_data_integrity(data, seq_name)
593+
594+
return data
566595

567596
def get_seq_length(self, seq_name):
568597
"""
@@ -571,6 +600,17 @@ def get_seq_length(self, seq_name):
571600
"""
572601
return self.stream["data"][seq_name].shape[0]
573602

603+
def check_data_integrity(self, data, seq_name):
    """
    Validate that a sparse label sequence matches this stream's expected
    layout, raising AssertionError with a descriptive message otherwise.

    :param numpy.ndarray data: sparse index sequence, expected shape (time,)
    :param str seq_name: sequence name, used only in error messages
    """
    # Sparse data is a 1-D sequence of class indices. The original message
    # incorrectly claimed "should be 2-dimensional" although the check
    # requires exactly one dimension.
    assert len(data.shape) == 1, f"shape length mismatch in {seq_name}: {data.shape} (should be 1-dimensional)"
    assert self.dtype == str(
        data.dtype
    ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"
613+
574614

575615
class SegmentAlignmentStreamParser(StreamParser):
576616
"""
@@ -585,10 +625,11 @@ def __init__(self, *args, **kwargs):
585625

586626
if self.dtype is None:
587627
self.dtype = str(seq_data.dtype)
588-
assert str(seq_data.dtype) == self.dtype
589628

590-
assert len(seq_data.shape) == 2
591-
assert seq_data.shape[1] == 2
629+
if self.use_lazy_data_integrity_checks:
630+
break
631+
632+
self.check_data_integrity(seq_data, s)
592633

593634
self.num_features = self.stream["feature_names"].shape[0]
594635
self.feature_type = 1
@@ -602,6 +643,9 @@ def get_data(self, seq_name):
602643
length = self.get_seq_length(seq_name) // 2
603644
segments = self.stream["data"][seq_name][:]
604645

646+
if self.use_lazy_data_integrity_checks:
647+
self.check_data_integrity(segments, seq_name)
648+
605649
alignment = numpy.zeros((length, 2), dtype=self.dtype)
606650
num_segments = segments.shape[0]
607651
seg_end = 0
@@ -621,6 +665,22 @@ def get_seq_length(self, seq_name):
621665
"""
622666
return 2 * sum(self.stream["data"][seq_name][:, 1])
623667

668+
def check_data_integrity(self, data, seq_name):
    """
    Validate that a segment-alignment sequence matches this stream's
    expected layout, raising AssertionError with a descriptive message
    otherwise.

    :param numpy.ndarray data: segment array, expected shape (num_segments, 2)
        where each row is (class index, segment length)
    :param str seq_name: sequence name, used only in error messages
    """
    assert (
        len(data.shape) == 2
    ), f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
    # The second axis holds (class, length) pairs, so its size must be
    # exactly 2 — the original message misleadingly said "2-dimensional".
    assert (
        data.shape[1] == 2
    ), f"feature dim mismatch in {seq_name}: {data.shape[1]} (should be 2)"
    assert self.dtype == str(
        data.dtype
    ), f"dtype mismatch in {seq_name}: {str(data.dtype)} (should be {self.dtype})"
683+
624684

625685
class NextGenHDFDataset(CachedDataset2):
626686
"""
@@ -633,7 +693,7 @@ class NextGenHDFDataset(CachedDataset2):
633693
"segment_alignment": SegmentAlignmentStreamParser,
634694
}
635695

636-
def __init__(self, input_stream_name, files=None, **kwargs):
696+
def __init__(self, input_stream_name, files=None, use_lazy_data_integrity_checks=False, **kwargs):
637697
"""
638698
:param str input_stream_name:
639699
:param None|list[str] files:
@@ -649,6 +709,7 @@ def __init__(self, input_stream_name, files=None, **kwargs):
649709
self.file_indices = []
650710
self.seq_order = []
651711
self.all_parsers = collections.defaultdict(list)
712+
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
652713

653714
if files:
654715
for fn in files:
@@ -684,7 +745,9 @@ def add_file(self, path):
684745
)
685746

686747
parsers = {
687-
name: NextGenHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
748+
name: NextGenHDFDataset.parsers[stream.attrs["parser"]](
749+
norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
750+
)
688751
for name, stream in cur_file["streams"].items()
689752
}
690753
for k, v in parsers.items():
@@ -807,7 +870,15 @@ class SiameseHDFDataset(CachedDataset2):
807870
"segment_alignment": SegmentAlignmentStreamParser,
808871
}
809872

810-
def __init__(self, input_stream_name, seq_label_stream="words", class_distribution=None, files=None, **kwargs):
873+
def __init__(
874+
self,
875+
input_stream_name,
876+
seq_label_stream="words",
877+
class_distribution=None,
878+
files=None,
879+
use_lazy_data_integrity_checks=False,
880+
**kwargs,
881+
):
811882
"""
812883
:param str input_stream_name: name of a feature stream
813884
:param str seq_label_stream: name of a stream with labels
@@ -833,6 +904,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi
833904
self.target_to_seqs = {} # (int) class_index -> (string) sequence_names
834905
self.curr_epoch_triplets = []
835906
self.targets_stream = seq_label_stream
907+
self.use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
908+
836909
if files:
837910
for fn in files:
838911
self.add_file(fn)
@@ -872,7 +945,9 @@ def add_file(self, path):
872945
)
873946

874947
parsers = {
875-
name: SiameseHDFDataset.parsers[stream.attrs["parser"]](norm_seqs, stream)
948+
name: SiameseHDFDataset.parsers[stream.attrs["parser"]](
949+
norm_seqs, stream, use_lazy_data_integrity_checks=self.use_lazy_data_integrity_checks
950+
)
876951
for name, stream in cur_file["streams"].items()
877952
} # name - stream name (words, features, orth_features)
878953
for k, v in parsers.items():

0 commit comments

Comments
 (0)