diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..72c5322b4 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -147,7 +147,7 @@ val_initial: False loader_num_workers: 8 log_validation: 0 -analysis_streams_output: ["ERA5"] +streams_output: ["ERA5"] istep: 0 run_history: [] diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 3e7594d1c..428cd5d47 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -355,7 +355,7 @@ class OutputBatchData: # fstep, stream, redundant dim (size 1) targets_lens: list[list[list[int]]] - # stream name: index into data (only streams in analysis_streams_output) + # stream name: index into data (only streams in streams_output) streams: dict[str, int] # stream, channel name diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index eb2cab895..4a650c778 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -47,7 +47,7 @@ def inference_from_args(argl: list[str]): end_date_val=args.end_date, samples_per_validation=args.samples, log_validation=args.samples if args.save_samples else 0, - analysis_streams_output=args.analysis_streams_output, + streams_output=args.streams_output, ) cli_overwrite = config.from_cli_arglist(args.options) diff --git a/src/weathergen/utils/cli.py b/src/weathergen/utils/cli.py index 9e7b8f562..ad0641c27 100644 --- a/src/weathergen/utils/cli.py +++ b/src/weathergen/utils/cli.py @@ -59,10 +59,9 @@ def get_inference_parser() -> argparse.ArgumentParser: help="Toggle saving of samples from inference. Default True", ) parser.add_argument( - "--analysis_streams_output", + "--streams_output", nargs="+", - default=["ERA5"], - help="Analysis output streams during inference.", + help="Output streams during inference.", ) return parser diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..fe12c1106 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -27,7 +27,13 @@ def write_output( targets_lens, ): stream_names = [stream.name for stream in cf.streams] - output_stream_names = cf.analysis_streams_output + if cf.streams_output is not None: + output_stream_names = cf.streams_output + elif cf.analysis_streams_output is not None: # --- to be removed at some point --- + output_stream_names = cf.analysis_streams_output # --- to be removed at some point --- + else: + output_stream_names = None + if output_stream_names is None: output_stream_names = stream_names diff --git a/tests/test_cli.py b/tests/test_cli.py index e39535c31..9b0623fcd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -60,15 +60,15 @@ def test_model_loading_has_params(parser): @pytest.mark.parametrize("streams", [["ERA5", "FOO"], ["BAR"]]) -def test_inference_analysis_streams_output(inference_parser, streams): - arglist = BASIC_ARGLIST + ["--analysis_streams_output", *streams] +def test_inference_streams_output(inference_parser, streams): + arglist = BASIC_ARGLIST + ["--streams_output", *streams] args = inference_parser.parse_args(arglist) - assert args.analysis_streams_output == streams + assert args.streams_output == streams -def test_inference_analysis_streams_output_empty(inference_parser): - arglist = BASIC_ARGLIST + ["--analysis_streams_output", *[]] +def test_inference_streams_output_empty(inference_parser): + arglist = BASIC_ARGLIST + ["--streams_output", *[]] with pytest.raises(SystemExit): inference_parser.parse_args(arglist) @@ -79,7 +79,7 @@ def test_inference_defaults(inference_parser): "start_date", "end_date", "samples", - "analysis_streams_output", + "streams_output", "epoch", "private_config", ]