From 3a881076511bb5ff36d0adf443771f8046205bd4 Mon Sep 17 00:00:00 2001 From: Malo OLIVIER Date: Mon, 16 Dec 2024 18:55:35 +0100 Subject: [PATCH] Revert "Add resolution_range parameter and update tests in generate_hnet_training_data.py" This reverts commit 1fd00a509955dac01f8f098f3b13e5ae7ba53265. --- generate_hnet_training_data.py | 52 ++++--------------- ...t_scenarios_generate_hnet_training_data.py | 12 +---- 2 files changed, 11 insertions(+), 53 deletions(-) diff --git a/generate_hnet_training_data.py b/generate_hnet_training_data.py index 7ee09f4..84ed9f8 100644 --- a/generate_hnet_training_data.py +++ b/generate_hnet_training_data.py @@ -11,7 +11,6 @@ from scipy.optimize import linear_sum_assignment from scipy.spatial import distance -default_sample_range = np.array([3000, 5000, 15000]) def sph2cart(azimuth, elevation, r): """ @@ -57,17 +56,15 @@ def compute_class_imbalance(data_dict): return class_counts -def generate_data( - max_doas, sample_range, data_type="train", resolution_range="standard_resolution" -): +def generate_data(pickle_filename, max_doas, sample_range, data_type="train"): """ Generates training or testing data based on the specified parameters. Args: + pickle_filename (str): Base name for the output pickle file. max_doas (int): Maximum number of Directions of Arrival (DOAs). sample_range (np.array): Array specifying the number of samples for each DOA combination. data_type (str): Type of data to generate ('train' or 'test'). - resolution_range (str): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Returns: dict: Generated data dictionary containing association matrices and related information. @@ -82,27 +79,9 @@ def generate_data( # Metadata Container combination_counts = {} - # Modify list_resolutions based on resolution_range and data_type - if data_type == "train": - if resolution_range == "standard_resolution": - list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30] - elif resolution_range == "fine_resolution": - list_resolutions = [1] * 9 # [1, 1, 1, 1, 1, 1, 1, 1, 1] - elif resolution_range == "coarse_resolution": - list_resolutions = [30] * 9 # [30, 30, 30, 30, 30, 30, 30, 30, 30] - else: - raise ValueError( - "Invalid resolution_range: choose 'standard_resolution', 'fine_resolution', or 'coarse_resolution'." - ) - else: - # For test data, always use standard resolutions - list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30] - - print(f"Resolution Range: {resolution_range}, List Resolutions: {list_resolutions}") - # For each combination of associations ('nb_ref', 'nb_pred') = [(0, 0), (0, 1), (1, 0) ... (max_doas, max_doas)] # Generate random reference ('ref_ang') and prediction ('pred_ang') DOAs at different 'resolution'. - for resolution in list_resolutions: # Different angular resolution + for resolution in [1, 2, 3, 4, 5, 10, 15, 20, 30]: # Different angular resolution azi_range = range(-180, 180, resolution) ele_range = range(-90, 91, resolution) @@ -212,7 +191,7 @@ def generate_data( os.makedirs(f"data/{current_date}/{data_type}", exist_ok=True) # Human-readable filename - out_filename = f"data/{current_date}/{data_type}/{resolution_range}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}" + out_filename = f"data/{current_date}/{data_type}/{pickle_filename}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}" print(f"Saving data in: {out_filename}, #examples: {len(data_dict)}") save_obj(data_dict, out_filename) @@ -236,10 +215,7 @@ def generate_data( def main( - sample_range=default_sample_range, - max_doas=2, - resolution_range="standard_resolution", - testing="default", + pickle_filename="hung_data", sample_range=np.array([3000, 5000, 15000]), max_doas=2 ): """ Generates and saves training and testing datasets for the Hungarian Network (HNet) model. @@ -249,12 +225,11 @@ def main( defined sample ranges for different Directions of Arrival (DOAs). Args: + pickle_filename (str, optional): Base name for the output pickle files. Defaults to "hung_data". sample_range (np.array, optional): Array specifying the number of samples for each DOA combination. Should correspond to the minimum of `nb_ref` and `nb_pred`. - Defaults to `default_sample_range`. + Defaults to [3000, 5000, 15000]. max_doas (int, optional): Maximum number of Directions of Arrival (DOAs) to consider. Defaults to 2. - resolution_range (str, optional): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Defaults to "standard_resolution". - testing (str, optional): Which data distribution to test, sample_range or `default_sample_range`. Defaults to "default". Returns: None @@ -282,22 +257,13 @@ def main( print("\nGenerating Training Data...") # Generate training data train_data_dict = generate_data( - max_doas, sample_range, data_type="train", resolution_range=resolution_range - ) - - sample_range_to_test = ( - sample_range - if testing != "default" - else default_sample_range # Default sample range for testing + pickle_filename, max_doas, sample_range, data_type="train" ) print("\nGenerating Testing Data...") # Generate testing data, same procedure as above test_data_dict = generate_data( - max_doas, - sample_range_to_test, - data_type="test", - resolution_range=resolution_range, + pickle_filename, max_doas, sample_range, data_type="test" ) print("\n=== Summary of Generated Datasets ===") diff --git a/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py b/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py index dd04871..97b4b5b 100644 --- a/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py +++ b/tests/scenarios_tests/generate_hnet_training_data/test_scenarios_generate_hnet_training_data.py @@ -6,13 +6,7 @@ @pytest.mark.scenarios_generate_data -@pytest.mark.parametrize( - "resolution_range", - [("standard_resolution"), ("fine_resolution"), ("coarse_resolution")], -) -def test_generate_data_with_various_distributions( - sample_range, max_doas, resolution_range -): +def test_generate_data_with_various_distributions(sample_range, max_doas): """ Parameterized test to generate data with different sample ranges and verify distributions. @@ -20,6 +14,4 @@ def test_generate_data_with_various_distributions( sample_range (np.array): Array specifying the number of samples for each DOA combination. max_doas (int): Maximum number of Directions of Arrival (DOAs). """ - main( - sample_range=sample_range, max_doas=max_doas, resolution_range=resolution_range - ) + main(sample_range=sample_range, max_doas=max_doas)