1717
1818from  __future__ import  annotations 
1919
20+ from  collections .abc  import  Mapping 
2021import  os 
2122
2223import  numpy  as  np 
@@ -44,13 +45,27 @@ def _info(self):
4445        homepage = 'https://github.com/deepmind/pg19' ,
4546    )
4647
def _get_paths(self, data_dir: str) -> Mapping[str, str]:
  """Builds the lookup of resource name -> path rooted at `data_dir`.

  Args:
    data_dir: Directory under which the metadata file and the split
      entries are expected to live.

  Returns:
    A mapping with the keys 'metadata', 'train', 'validation' and 'test'
    (in that order). 'metadata' points at `metadata.csv` inside `data_dir`;
    every other key points at the entry of the same name inside `data_dir`.
  """
  locations = {'metadata': os.path.join(data_dir, 'metadata.csv')}
  # Each split is stored under an entry named after the split itself.
  for split_name in ('train', 'validation', 'test'):
    locations[split_name] = os.path.join(data_dir, split_name)
  return locations
55+ 
4756  def  _split_generators (self , dl_manager ):
4857    """Returns SplitGenerators.""" 
4958    del  dl_manager   # Unused 
5059
5160    metadata_dict  =  dict ()
52-     metadata_path  =  os .path .join (_DATA_DIR , 'metadata.csv' )
53-     metadata  =  tf .io .gfile .GFile (metadata_path ).read ().splitlines ()
61+     if  self .data_dir  and  all (
62+         map (os .path .exists , self ._get_paths (self .data_dir ).values ())
63+     ):
64+       data_dir  =  self ._data_dir 
65+     else :
66+       data_dir  =  _DATA_DIR 
67+     paths  =  self ._get_paths (data_dir )
68+     metadata  =  tf .io .gfile .GFile (paths ['metadata' ]).read ().splitlines ()
5469
5570    for  row  in  metadata :
5671      row_split  =  row .split (',' )
@@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
6277            name = tfds .Split .TRAIN ,
6378            gen_kwargs = {
6479                'metadata' : metadata_dict ,
65-                 'filepath' : os . path . join ( _DATA_DIR ,  'train' ) ,
80+                 'filepath' : paths [ 'train' ] ,
6681            },
6782        ),
6883        tfds .core .SplitGenerator (
6984            name = tfds .Split .VALIDATION ,
7085            gen_kwargs = {
7186                'metadata' : metadata_dict ,
72-                 'filepath' : os . path . join ( _DATA_DIR ,  'validation' ) ,
87+                 'filepath' : paths [ 'validation' ] ,
7388            },
7489        ),
7590        tfds .core .SplitGenerator (
7691            name = tfds .Split .TEST ,
7792            gen_kwargs = {
7893                'metadata' : metadata_dict ,
79-                 'filepath' : os . path . join ( _DATA_DIR ,  'test' ) ,
94+                 'filepath' : paths [ 'test' ] ,
8095            },
8196        ),
8297    ]
0 commit comments