From 1fd00a509955dac01f8f098f3b13e5ae7ba53265 Mon Sep 17 00:00:00 2001 From: Malo OLIVIER Date: Fri, 13 Dec 2024 11:03:12 +0100 Subject: [PATCH] Add resolution_range parameter and update tests in generate_hnet_training_data.py - Introduce resolution_range to handle standard, fine, and coarse resolutions - Modify generate_data function to set list_resolutions based on resolution_range and data_type - Update output filename to include resolution_range - Add default_sample_range variable - Update main function to accept and pass resolution_range and testing parameters - Enhance tests by parameterizing resolution_range with various resolution options --- generate_hnet_training_data.py | 52 +++++++++++++++---- ...t_scenarios_generate_hnet_training_data.py | 12 ++++- 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/generate_hnet_training_data.py b/generate_hnet_training_data.py index 84ed9f8..7ee09f4 100644 --- a/generate_hnet_training_data.py +++ b/generate_hnet_training_data.py @@ -11,6 +11,7 @@ from scipy.optimize import linear_sum_assignment from scipy.spatial import distance +default_sample_range = np.array([3000, 5000, 15000]) def sph2cart(azimuth, elevation, r): """ @@ -56,15 +57,17 @@ def compute_class_imbalance(data_dict): return class_counts -def generate_data(pickle_filename, max_doas, sample_range, data_type="train"): +def generate_data( + max_doas, sample_range, data_type="train", resolution_range="standard_resolution" +): """ Generates training or testing data based on the specified parameters. Args: - pickle_filename (str): Base name for the output pickle file. max_doas (int): Maximum number of Directions of Arrival (DOAs). sample_range (np.array): Array specifying the number of samples for each DOA combination. data_type (str): Type of data to generate ('train' or 'test'). + resolution_range (str): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Returns: dict: Generated data dictionary containing association matrices and related information. @@ -79,9 +82,27 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"): # Metadata Container combination_counts = {} + # Modify list_resolutions based on resolution_range and data_type + if data_type == "train": + if resolution_range == "standard_resolution": + list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30] + elif resolution_range == "fine_resolution": + list_resolutions = [1] * 9 # [1, 1, 1, 1, 1, 1, 1, 1, 1] + elif resolution_range == "coarse_resolution": + list_resolutions = [30] * 9 # [30, 30, 30, 30, 30, 30, 30, 30, 30] + else: + raise ValueError( + "Invalid resolution_range: choose 'standard_resolution', 'fine_resolution', or 'coarse_resolution'." + ) + else: + # For test data, always use standard resolutions + list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30] + + print(f"Resolution Range: {resolution_range}, List Resolutions: {list_resolutions}") + # For each combination of associations ('nb_ref', 'nb_pred') = [(0, 0), (0, 1), (1, 0) ... (max_doas, max_doas)] # Generate random reference ('ref_ang') and prediction ('pred_ang') DOAs at different 'resolution'. - for resolution in [1, 2, 3, 4, 5, 10, 15, 20, 30]: # Different angular resolution + for resolution in list_resolutions: # Different angular resolution azi_range = range(-180, 180, resolution) ele_range = range(-90, 91, resolution) @@ -191,7 +212,7 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"): os.makedirs(f"data/{current_date}/{data_type}", exist_ok=True) # Human-readable filename - out_filename = f"data/{current_date}/{data_type}/{pickle_filename}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}" + out_filename = f"data/{current_date}/{data_type}/{resolution_range}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}" print(f"Saving data in: {out_filename}, #examples: {len(data_dict)}") save_obj(data_dict, out_filename) @@ -215,7 +236,10 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"): def main( - pickle_filename="hung_data", sample_range=np.array([3000, 5000, 15000]), max_doas=2 + sample_range=default_sample_range, + max_doas=2, + resolution_range="standard_resolution", + testing="default", ): """ Generates and saves training and testing datasets for the Hungarian Network (HNet) model. @@ -225,11 +249,12 @@ def main( defined sample ranges for different Directions of Arrival (DOAs). Args: - pickle_filename (str, optional): Base name for the output pickle files. Defaults to "hung_data". sample_range (np.array, optional): Array specifying the number of samples for each DOA combination. Should correspond to the minimum of `nb_ref` and `nb_pred`. - Defaults to [3000, 5000, 15000]. + Defaults to `default_sample_range`. max_doas (int, optional): Maximum number of Directions of Arrival (DOAs) to consider. Defaults to 2. + resolution_range (str, optional): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Defaults to "standard_resolution". + testing (str, optional): Which data distribution to test, sample_range or `default_sample_range`. Defaults to "default". Returns: None @@ -257,13 +282,22 @@ def main( print("\nGenerating Training Data...") # Generate training data train_data_dict = generate_data( - pickle_filename, max_doas, sample_range, data_type="train" + max_doas, sample_range, data_type="train", resolution_range=resolution_range + ) + + sample_range_to_test = ( + sample_range + if testing != "default" + else default_sample_range # Default sample range for testing ) print("\nGenerating Testing Data...") # Generate testing data, same procedure as above test_data_dict = generate_data( - pickle_filename, max_doas, sample_range, data_type="test" + max_doas, + sample_range_to_test, + data_type="test", + resolution_range=resolution_range, ) print("\n=== Summary of Generated Datasets ===") diff --git a/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py b/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py index 97b4b5b..dd04871 100644 --- a/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py +++ b/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py @@ -6,7 +6,13 @@ @pytest.mark.scenarios_generate_data -def test_generate_data_with_various_distributions(sample_range, max_doas): +@pytest.mark.parametrize( + "resolution_range", + [("standard_resolution"), ("fine_resolution"), ("coarse_resolution")], +) +def test_generate_data_with_various_distributions( + sample_range, max_doas, resolution_range +): """ Parameterized test to generate data with different sample ranges and verify distributions. @@ -14,4 +20,6 @@ def test_generate_data_with_various_distributions(sample_range, max_doas): sample_range (np.array): Array specifying the number of samples for each DOA combination. max_doas (int): Maximum number of Directions of Arrival (DOAs). """ - main(sample_range=sample_range, max_doas=max_doas) + main( + sample_range=sample_range, max_doas=max_doas, resolution_range=resolution_range + )