@@ -900,8 +900,12 @@ def __init__(self, path):
900900 ]
901901 # Allow us to find which partition a given record is in
902902 self .partition_record_index = np .cumsum ([0 , * partition_num_records ])
903+ self .gt_field = None
903904 for field in self .metadata .fields :
904905 self .fields [field .full_name ] = IntermediateColumnarFormatField (self , field )
906+ if field .name == "GT" :
907+ self .gt_field = field
908+
905909 logger .info (
906910 f"Loaded IntermediateColumnarFormat(partitions={ self .num_partitions } , "
907911 f"records={ self .num_records } , fields={ self .num_fields } )"
@@ -970,19 +974,6 @@ def root_attrs(self):
970974 "vcf_header" : self .vcf_header ,
971975 }
972976
973- def iter_alleles (self , start , stop , num_alleles ):
974- ref_field = self .fields ["REF" ]
975- alt_field = self .fields ["ALT" ]
976-
977- for ref , alt in zip (
978- ref_field .iter_values (start , stop ),
979- alt_field .iter_values (start , stop ),
980- ):
981- alleles = np .full (num_alleles , constants .STR_FILL , dtype = "O" )
982- alleles [0 ] = ref [0 ]
983- alleles [1 : 1 + len (alt )] = alt
984- yield alleles
985-
986977 def iter_id (self , start , stop ):
987978 for value in self .fields ["ID" ].iter_values (start , stop ):
988979 if value is not None :
@@ -1025,6 +1016,19 @@ def iter_field(self, field_name, shape, start, stop):
10251016 for value in source_field .iter_values (start , stop ):
10261017 yield sanitiser (value )
10271018
def iter_alleles(self, start, stop, num_alleles):
    """Yield one fixed-length allele array per record in ``[start, stop)``.

    Values are read in lockstep from the ``"REF"`` and ``"ALT"`` entries of
    ``self.fields``. Each yielded object-dtype array has ``num_alleles``
    slots: slot 0 holds the first element of the REF value, the following
    slots hold the ALT values, and every remaining slot keeps the
    ``constants.STR_FILL`` value the array was initialised with.
    """
    ref_source, alt_source = self.fields["REF"], self.fields["ALT"]
    paired_values = zip(
        ref_source.iter_values(start, stop),
        alt_source.iter_values(start, stop),
    )
    for ref_value, alt_value in paired_values:
        # Start from an all-fill row so unused trailing slots stay filled.
        row = np.full(num_alleles, constants.STR_FILL, dtype="O")
        row[0] = ref_value[0]
        alt_end = 1 + len(alt_value)
        row[1:alt_end] = alt_value
        yield row
1031+
10281032 def iter_genotypes (self , shape , start , stop ):
10291033 source_field = self .fields ["FORMAT/GT" ]
10301034 for value in source_field .iter_values (start , stop ):
@@ -1034,6 +1038,16 @@ def iter_genotypes(self, shape, start, stop):
10341038 sanitised_phased = sanitise_value_int_1d (shape [:- 1 ], phased )
10351039 yield sanitised_genotypes , sanitised_phased
10361040
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
    """Yield ``(alleles, genotypes)`` pairs for records in ``[start, stop)``.

    ``alleles`` comes from ``self.iter_alleles(start, stop, num_alleles)``.
    When ``self.gt_field`` is set and ``shape`` is not ``None``, the second
    item is whatever ``self.iter_genotypes(shape, start, stop)`` yields for
    the same position; otherwise it is the placeholder ``(None, None)``.
    """
    genotypes_available = self.gt_field is not None and shape is not None
    allele_rows = self.iter_alleles(start, stop, num_alleles)
    if not genotypes_available:
        # No genotype data to pair with: emit a placeholder for each record.
        for allele_row in allele_rows:
            yield allele_row, (None, None)
        return
    yield from zip(allele_rows, self.iter_genotypes(shape, start, stop))
1050+
10371051 def generate_schema (
10381052 self , variants_chunk_size = None , samples_chunk_size = None , local_alleles = None
10391053 ):
@@ -1128,15 +1142,13 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11281142 [spec_from_field (field ) for field in self .metadata .info_fields ]
11291143 )
11301144
1131- gt_field = None
11321145 for field in self .metadata .format_fields :
11331146 if field .name == "GT" :
1134- gt_field = field
11351147 continue
11361148 array_specs .append (spec_from_field (field ))
11371149
1138- if gt_field is not None and n > 0 :
1139- ploidy = max (gt_field .summary .max_number - 1 , 1 )
1150+ if self . gt_field is not None and n > 0 :
1151+ ploidy = max (self . gt_field .summary .max_number - 1 , 1 )
11401152 # Add ploidy dimension only when needed
11411153 schema_instance .dimensions ["ploidy" ] = vcz .VcfZarrDimension (size = ploidy )
11421154
@@ -1152,7 +1164,7 @@ def fixed_field_spec(name, dtype, source=None, dimensions=("variants",)):
11521164 array_specs .append (
11531165 vcz .ZarrArraySpec (
11541166 name = "call_genotype" ,
1155- dtype = gt_field .smallest_dtype (),
1167+ dtype = self . gt_field .smallest_dtype (),
11561168 dimensions = ["variants" , "samples" , "ploidy" ],
11571169 description = "" ,
11581170 compressor = vcz .DEFAULT_ZARR_COMPRESSOR_GENOTYPES .get_config (),
0 commit comments