diff --git a/gtfs_loader/__init__.py b/gtfs_loader/__init__.py index ed01ab2..ac96273 100644 --- a/gtfs_loader/__init__.py +++ b/gtfs_loader/__init__.py @@ -11,12 +11,19 @@ class ParseError(ValueError): pass +def get_files(files): + return schema.FileCollection(*(schema.GTFS_FILENAMES[f] for f in files)).values() -def load(gtfs_dir, sorted_read=False): + +def load(gtfs_dir, sorted_read=False, files=None, verbose=True): gtfs_dir = Path(gtfs_dir) gtfs = types.Entity() - for file_schema in schema.GTFS_SUBSET_SCHEMA.values(): - print(f'Loading {file_schema.name}') + + files_to_load = get_files(files) if files else schema.GTFS_SUBSET_SCHEMA.values() + + for file_schema in files_to_load: + if verbose: + print(f'Loading {file_schema.name}') filepath = gtfs_dir / file_schema.filename gtfs[file_schema.name] = types.EntityDict( file_schema.get_declared_fields()) @@ -210,7 +217,7 @@ def sorted_entities(file_schema, entities): return sorted(entities.items(), key=lambda kv: kv[0]) -def patch(gtfs, gtfs_in_dir, gtfs_out_dir, sorted_output=False): +def patch(gtfs, gtfs_in_dir, gtfs_out_dir, files=None, sorted_output=False, verbose=True): gtfs_in_dir = Path(gtfs_in_dir) gtfs_out_dir = Path(gtfs_out_dir) gtfs_out_dir.mkdir(parents=True, exist_ok=True) @@ -222,8 +229,11 @@ def patch(gtfs, gtfs_in_dir, gtfs_out_dir, sorted_output=False): except shutil.SameFileError: pass # No need to copy if we're working in-place - for file_schema in schema.GTFS_SUBSET_SCHEMA.values(): - print(f'Writing {file_schema.name}') + files_to_patch = get_files(files) if files else schema.GTFS_SUBSET_SCHEMA.values() + + for file_schema in files_to_patch: + if verbose: + print(f'Writing {file_schema.name}') entities = gtfs.get(file_schema.name) if not entities: (gtfs_out_dir / file_schema.filename).unlink(missing_ok=True) diff --git a/gtfs_loader/schema.py b/gtfs_loader/schema.py index 92c71db..b80d760 100644 --- a/gtfs_loader/schema.py +++ b/gtfs_loader/schema.py @@ -326,3 +326,17 @@ def route(self): GTFS_SUBSET_SCHEMA = FileCollection(Agency, BookingRule, Calendar, CalendarDate, Locations, LocationGroups, Routes, Transfer, Trip, Stop, StopTime) + +GTFS_FILENAMES = { + Agency._schema.name: Agency, + BookingRule._schema.name: BookingRule, + Calendar._schema.name: Calendar, + CalendarDate._schema.name: CalendarDate, + Locations._schema.name: Locations, + LocationGroups._schema.name: LocationGroups, + Routes._schema.name: Routes, + Transfer._schema.name: Transfer, + Trip._schema.name: Trip, + Stop._schema.name: Stop, + StopTime._schema.name: StopTime, +} diff --git a/setup.py b/setup.py index 19a96ea..0fe3478 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup(name='py-gtfs-loader', - version='0.1.12', + version='0.1.13', description='Load GTFS', url='https://github.com/TransitApp/py-gtfs-loader', author='Nicholas Paun, Jonathan Milot', diff --git a/tests/test_runner.py b/tests/test_runner.py index 7231616..f5f3f92 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -16,6 +16,6 @@ def test_default(feed_dir): def do_test(feed_dir): work_dir = test_support.create_test_data(feed_dir) - gtfs = gtfs_loader.load(work_dir) - gtfs_loader.patch(gtfs, work_dir, work_dir) + gtfs = gtfs_loader.load(work_dir, verbose=False) + gtfs_loader.patch(gtfs, work_dir, work_dir, verbose=False) test_support.check_expected_output(feed_dir, work_dir)