@@ -471,9 +471,10 @@ class StreamParser(object):
471
471
Stream parser.
472
472
"""
473
473
474
- def __init__ (self , seq_names , stream ):
474
+ def __init__ (self , seq_names , stream , use_lazy_data_integrity_checks = False ):
475
475
self .seq_names = seq_names
476
476
self .stream = stream
477
+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
477
478
478
479
self .num_features = None
479
480
self .feature_type = None # 1 for sparse, 2 for dense
@@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs):
518
519
if self .dtype is None :
519
520
self .dtype = str (seq_data .dtype )
520
521
521
- assert seq_data .shape [1 ] == self .num_features
522
- assert str (seq_data .dtype ) == self .dtype
522
+ if self .use_lazy_data_integrity_checks :
523
+ break
524
+
525
+ self .check_data_integrity (seq_data , s )
523
526
524
527
self .feature_type = 2
525
528
@@ -528,7 +531,12 @@ def get_data(self, seq_name):
528
531
:param str seq_name:
529
532
:rtype: numpy.ndarray
530
533
"""
531
- return self .stream ["data" ][seq_name ][...]
534
+ data = self .stream ["data" ][seq_name ][...]
535
+
536
+ if self .use_lazy_data_integrity_checks :
537
+ self .check_data_integrity (data , seq_name )
538
+
539
+ return data
532
540
533
541
def get_seq_length (self , seq_name ):
534
542
"""
@@ -537,6 +545,18 @@ def get_seq_length(self, seq_name):
537
545
"""
538
546
return self .stream ["data" ][seq_name ].shape [0 ]
539
547
548
def check_data_integrity(self, data, seq_name):
    """
    Verify that the data of one sequence is consistent with this stream.

    The data must be a 2-dimensional array whose second axis has
    ``self.num_features`` entries and whose dtype name equals ``self.dtype``.

    :param numpy.ndarray data: data of the sequence, as read from the stream
    :param str seq_name: name of the sequence, only used in the error messages
    :raises AssertionError: if the rank, the feature dimension or the dtype does not match
    """
    # The rank is checked first so that ``data.shape[1]`` below cannot raise an IndexError.
    assert len(data.shape) == 2, (
        f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
    )

    feature_dim = data.shape[1]
    assert self.num_features == feature_dim, (
        f"feature dim mismatch in {seq_name}: {feature_dim} (should be {self.num_features})"
    )

    # ``self.dtype`` holds the dtype as a string, so the comparison is done on the string form.
    dtype_name = str(data.dtype)
    assert self.dtype == dtype_name, (
        f"dtype mismatch in {seq_name}: {dtype_name} (should be {self.dtype})"
    )
559
+
540
560
541
561
class SparseStreamParser (StreamParser ):
542
562
"""
@@ -552,7 +572,11 @@ def __init__(self, *args, **kwargs):
552
572
553
573
if self .dtype is None :
554
574
self .dtype = str (seq_data .dtype )
555
- assert str (seq_data .dtype ) == self .dtype
575
+
576
+ if self .use_lazy_data_integrity_checks :
577
+ break
578
+
579
+ self .check_data_integrity (seq_data , s )
556
580
557
581
self .num_features = self .stream ["feature_names" ].shape [0 ]
558
582
self .feature_type = 1
@@ -562,7 +586,12 @@ def get_data(self, seq_name):
562
586
:param str seq_name:
563
587
:rtype: numpy.ndarray
564
588
"""
565
- return self .stream ["data" ][seq_name ][:]
589
+ data = self .stream ["data" ][seq_name ][:]
590
+
591
+ if self .use_lazy_data_integrity_checks :
592
+ self .check_data_integrity (data , seq_name )
593
+
594
+ return data
566
595
567
596
def get_seq_length (self , seq_name ):
568
597
"""
@@ -571,6 +600,17 @@ def get_seq_length(self, seq_name):
571
600
"""
572
601
return self .stream ["data" ][seq_name ].shape [0 ]
573
602
603
def check_data_integrity(self, data, seq_name):
    """
    Verify that the data of one sequence is consistent with this sparse stream.

    The data must be a 1-dimensional array whose dtype name equals ``self.dtype``.

    :param numpy.ndarray data: data of the sequence, as read from the stream
    :param str seq_name: name of the sequence, only used in the error messages
    :raises AssertionError: if the rank or the dtype does not match
    """
    # The message must name the rank that is actually required here (1, not 2).
    assert len(data.shape) == 1, (
        f"shape length mismatch in {seq_name}: {data.shape} (should be 1-dimensional)"
    )

    # ``self.dtype`` holds the dtype as a string, so the comparison is done on the string form.
    dtype_name = str(data.dtype)
    assert self.dtype == dtype_name, (
        f"dtype mismatch in {seq_name}: {dtype_name} (should be {self.dtype})"
    )
613
+
574
614
575
615
class SegmentAlignmentStreamParser (StreamParser ):
576
616
"""
@@ -585,10 +625,11 @@ def __init__(self, *args, **kwargs):
585
625
586
626
if self .dtype is None :
587
627
self .dtype = str (seq_data .dtype )
588
- assert str (seq_data .dtype ) == self .dtype
589
628
590
- assert len (seq_data .shape ) == 2
591
- assert seq_data .shape [1 ] == 2
629
+ if self .use_lazy_data_integrity_checks :
630
+ break
631
+
632
+ self .check_data_integrity (seq_data , s )
592
633
593
634
self .num_features = self .stream ["feature_names" ].shape [0 ]
594
635
self .feature_type = 1
@@ -602,6 +643,9 @@ def get_data(self, seq_name):
602
643
length = self .get_seq_length (seq_name ) // 2
603
644
segments = self .stream ["data" ][seq_name ][:]
604
645
646
+ if self .use_lazy_data_integrity_checks :
647
+ self .check_data_integrity (segments , seq_name )
648
+
605
649
alignment = numpy .zeros ((length , 2 ), dtype = self .dtype )
606
650
num_segments = segments .shape [0 ]
607
651
seg_end = 0
@@ -621,6 +665,22 @@ def get_seq_length(self, seq_name):
621
665
"""
622
666
return 2 * sum (self .stream ["data" ][seq_name ][:, 1 ])
623
667
668
def check_data_integrity(self, data, seq_name):
    """
    Verify that the segment data of one sequence is consistent with this stream.

    The data must be a 2-dimensional array with exactly two columns and a
    dtype name equal to ``self.dtype``.

    :param numpy.ndarray data: segment data of the sequence, as read from the stream
    :param str seq_name: name of the sequence, only used in the error messages
    :raises AssertionError: if the rank, the number of columns or the dtype does not match
    """
    # The rank is checked first so that ``data.shape[1]`` below cannot raise an IndexError.
    assert len(data.shape) == 2, (
        f"shape length mismatch in {seq_name}: {data.shape} (should be 2-dimensional)"
    )

    # This check is about the size of the second axis, not about the rank of the array.
    num_columns = data.shape[1]
    assert num_columns == 2, (
        f"feature dim mismatch in {seq_name}: {num_columns} (should be 2)"
    )

    # ``self.dtype`` holds the dtype as a string, so the comparison is done on the string form.
    dtype_name = str(data.dtype)
    assert self.dtype == dtype_name, (
        f"dtype mismatch in {seq_name}: {dtype_name} (should be {self.dtype})"
    )
683
+
624
684
625
685
class NextGenHDFDataset (CachedDataset2 ):
626
686
"""
@@ -633,7 +693,7 @@ class NextGenHDFDataset(CachedDataset2):
633
693
"segment_alignment" : SegmentAlignmentStreamParser ,
634
694
}
635
695
636
- def __init__ (self , input_stream_name , files = None , ** kwargs ):
696
+ def __init__ (self , input_stream_name , files = None , use_lazy_data_integrity_checks = False , ** kwargs ):
637
697
"""
638
698
:param str input_stream_name:
639
699
:param None|list[str] files:
@@ -649,6 +709,7 @@ def __init__(self, input_stream_name, files=None, **kwargs):
649
709
self .file_indices = []
650
710
self .seq_order = []
651
711
self .all_parsers = collections .defaultdict (list )
712
+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
652
713
653
714
if files :
654
715
for fn in files :
@@ -684,7 +745,9 @@ def add_file(self, path):
684
745
)
685
746
686
747
parsers = {
687
- name : NextGenHDFDataset .parsers [stream .attrs ["parser" ]](norm_seqs , stream )
748
+ name : NextGenHDFDataset .parsers [stream .attrs ["parser" ]](
749
+ norm_seqs , stream , use_lazy_data_integrity_checks = self .use_lazy_data_integrity_checks
750
+ )
688
751
for name , stream in cur_file ["streams" ].items ()
689
752
}
690
753
for k , v in parsers .items ():
@@ -807,7 +870,15 @@ class SiameseHDFDataset(CachedDataset2):
807
870
"segment_alignment" : SegmentAlignmentStreamParser ,
808
871
}
809
872
810
- def __init__ (self , input_stream_name , seq_label_stream = "words" , class_distribution = None , files = None , ** kwargs ):
873
+ def __init__ (
874
+ self ,
875
+ input_stream_name ,
876
+ seq_label_stream = "words" ,
877
+ class_distribution = None ,
878
+ files = None ,
879
+ use_lazy_data_integrity_checks = False ,
880
+ ** kwargs ,
881
+ ):
811
882
"""
812
883
:param str input_stream_name: name of a feature stream
813
884
:param str seq_label_stream: name of a stream with labels
@@ -833,6 +904,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi
833
904
self .target_to_seqs = {} # (int) class_index -> (string) sequence_names
834
905
self .curr_epoch_triplets = []
835
906
self .targets_stream = seq_label_stream
907
+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
908
+
836
909
if files :
837
910
for fn in files :
838
911
self .add_file (fn )
@@ -872,7 +945,9 @@ def add_file(self, path):
872
945
)
873
946
874
947
parsers = {
875
- name : SiameseHDFDataset .parsers [stream .attrs ["parser" ]](norm_seqs , stream )
948
+ name : SiameseHDFDataset .parsers [stream .attrs ["parser" ]](
949
+ norm_seqs , stream , use_lazy_data_integrity_checks = self .use_lazy_data_integrity_checks
950
+ )
876
951
for name , stream in cur_file ["streams" ].items ()
877
952
} # name - stream name (words, features, orth_features)
878
953
for k , v in parsers .items ():
0 commit comments