Skip to content
2 changes: 1 addition & 1 deletion config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ val_initial: False

loader_num_workers: 8
log_validation: 0
analysis_streams_output: ["ERA5"]
streams_output: ["ERA5"]

istep: 0
run_history: []
Expand Down
2 changes: 1 addition & 1 deletion packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ class OutputBatchData:
# fstep, stream, redundant dim (size 1)
targets_lens: list[list[list[int]]]

# stream name: index into data (only streams in analysis_streams_output)
# stream name: index into data (only streams in streams_output)
streams: dict[str, int]

# stream, channel name
Expand Down
2 changes: 1 addition & 1 deletion src/weathergen/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def inference_from_args(argl: list[str]):
end_date_val=args.end_date,
samples_per_validation=args.samples,
log_validation=args.samples if args.save_samples else 0,
analysis_streams_output=args.analysis_streams_output,
streams_output=args.streams_output,
)

cli_overwrite = config.from_cli_arglist(args.options)
Expand Down
5 changes: 2 additions & 3 deletions src/weathergen/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ def get_inference_parser() -> argparse.ArgumentParser:
help="Toggle saving of samples from inference. Default True",
)
parser.add_argument(
"--analysis_streams_output",
"--streams_output",
nargs="+",
default=["ERA5"],
help="Analysis output streams during inference.",
help="Output streams during inference.",
)

return parser
Expand Down
8 changes: 7 additions & 1 deletion src/weathergen/utils/validation_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ def write_output(
targets_lens,
):
stream_names = [stream.name for stream in cf.streams]
output_stream_names = cf.analysis_streams_output
if cf.streams_output is not None:
output_stream_names = cf.streams_output
elif cf.analysis_streams_output is not None: # --- to be removed at some point ---
output_stream_names = cf.analysis_streams_output # --- to be removed at some point ---
else:
output_stream_names = None

if output_stream_names is None:
output_stream_names = stream_names

Expand Down
12 changes: 6 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ def test_model_loading_has_params(parser):


@pytest.mark.parametrize("streams", [["ERA5", "FOO"], ["BAR"]])
def test_inference_analysis_streams_output(inference_parser, streams):
arglist = BASIC_ARGLIST + ["--analysis_streams_output", *streams]
def test_inference_streams_output(inference_parser, streams):
arglist = BASIC_ARGLIST + ["--streams_output", *streams]
args = inference_parser.parse_args(arglist)

assert args.analysis_streams_output == streams
assert args.streams_output == streams


def test_inference_analysis_streams_output_empty(inference_parser):
arglist = BASIC_ARGLIST + ["--analysis_streams_output", *[]]
def test_inference_streams_output_empty(inference_parser):
arglist = BASIC_ARGLIST + ["--streams_output", *[]]

with pytest.raises(SystemExit):
inference_parser.parse_args(arglist)
Expand All @@ -79,7 +79,7 @@ def test_inference_defaults(inference_parser):
"start_date",
"end_date",
"samples",
"analysis_streams_output",
"streams_output",
"epoch",
"private_config",
]
Expand Down
Loading