Skip to content

Commit

Permalink
Add resolution_range parameter and update tests in generate_hnet_trai…
Browse files Browse the repository at this point in the history
…ning_data.py

- Introduce resolution_range to handle standard, fine, and coarse resolutions
- Modify generate_data function to set list_resolutions based on resolution_range and data_type
- Update output filename to include resolution_range
- Add default_sample_range variable
- Update main function to accept and pass resolution_range and testing parameters
- Enhance tests by parameterizing resolution_range with various resolution options
  • Loading branch information
MaloOLIVIER committed Dec 13, 2024
1 parent 3a36ce9 commit 1fd00a5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
52 changes: 43 additions & 9 deletions generate_hnet_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance

default_sample_range = np.array([3000, 5000, 15000])

def sph2cart(azimuth, elevation, r):
"""
Expand Down Expand Up @@ -56,15 +57,17 @@ def compute_class_imbalance(data_dict):
return class_counts


def generate_data(pickle_filename, max_doas, sample_range, data_type="train"):
def generate_data(
max_doas, sample_range, data_type="train", resolution_range="standard_resolution"
):
"""
Generates training or testing data based on the specified parameters.
Args:
pickle_filename (str): Base name for the output pickle file.
max_doas (int): Maximum number of Directions of Arrival (DOAs).
sample_range (np.array): Array specifying the number of samples for each DOA combination.
data_type (str): Type of data to generate ('train' or 'test').
resolution_range (str): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution.
Returns:
dict: Generated data dictionary containing association matrices and related information.
Expand All @@ -79,9 +82,27 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"):
# Metadata Container
combination_counts = {}

# Modify list_resolutions based on resolution_range and data_type
if data_type == "train":
if resolution_range == "standard_resolution":
list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30]
elif resolution_range == "fine_resolution":
list_resolutions = [1] * 9 # [1, 1, 1, 1, 1, 1, 1, 1, 1]
elif resolution_range == "coarse_resolution":
list_resolutions = [30] * 9 # [30, 30, 30, 30, 30, 30, 30, 30, 30]
else:
raise ValueError(
"Invalid resolution_range: choose 'standard_resolution', 'fine_resolution', or 'coarse_resolution'."
)
else:
# For test data, always use standard resolutions
list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30]

print(f"Resolution Range: {resolution_range}, List Resolutions: {list_resolutions}")

# For each combination of associations ('nb_ref', 'nb_pred') = [(0, 0), (0, 1), (1, 0) ... (max_doas, max_doas)]
# Generate random reference ('ref_ang') and prediction ('pred_ang') DOAs at different 'resolution'.
for resolution in [1, 2, 3, 4, 5, 10, 15, 20, 30]: # Different angular resolution
for resolution in list_resolutions: # Different angular resolution
azi_range = range(-180, 180, resolution)
ele_range = range(-90, 91, resolution)

Expand Down Expand Up @@ -191,7 +212,7 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"):
os.makedirs(f"data/{current_date}/{data_type}", exist_ok=True)

# Human-readable filename
out_filename = f"data/{current_date}/{data_type}/{pickle_filename}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}"
out_filename = f"data/{current_date}/{data_type}/{resolution_range}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}"

print(f"Saving data in: {out_filename}, #examples: {len(data_dict)}")
save_obj(data_dict, out_filename)
Expand All @@ -215,7 +236,10 @@ def generate_data(pickle_filename, max_doas, sample_range, data_type="train"):


def main(
pickle_filename="hung_data", sample_range=np.array([3000, 5000, 15000]), max_doas=2
sample_range=default_sample_range,
max_doas=2,
resolution_range="standard_resolution",
testing="default",
):
"""
Generates and saves training and testing datasets for the Hungarian Network (HNet) model.
Expand All @@ -225,11 +249,12 @@ def main(
defined sample ranges for different Directions of Arrival (DOAs).
Args:
pickle_filename (str, optional): Base name for the output pickle files. Defaults to "hung_data".
sample_range (np.array, optional): Array specifying the number of samples for each DOA combination.
Should correspond to the minimum of `nb_ref` and `nb_pred`.
Defaults to [3000, 5000, 15000].
Defaults to `default_sample_range`.
max_doas (int, optional): Maximum number of Directions of Arrival (DOAs) to consider. Defaults to 2.
resolution_range (str, optional): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Defaults to "standard_resolution".
testing (str, optional): Which data distribution to test, sample_range or `default_sample_range`. Defaults to "default".
Returns:
None
Expand Down Expand Up @@ -257,13 +282,22 @@ def main(
print("\nGenerating Training Data...")
# Generate training data
train_data_dict = generate_data(
pickle_filename, max_doas, sample_range, data_type="train"
max_doas, sample_range, data_type="train", resolution_range=resolution_range
)

sample_range_to_test = (
sample_range
if testing != "default"
else default_sample_range # Default sample range for testing
)

print("\nGenerating Testing Data...")
# Generate testing data, same procedure as above
test_data_dict = generate_data(
pickle_filename, max_doas, sample_range, data_type="test"
max_doas,
sample_range_to_test,
data_type="test",
resolution_range=resolution_range,
)

print("\n=== Summary of Generated Datasets ===")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@


@pytest.mark.scenarios_generate_data
def test_generate_data_with_various_distributions(sample_range, max_doas):
@pytest.mark.parametrize(
"resolution_range",
[("standard_resolution"), ("fine_resolution"), ("coarse_resolution")],
)
def test_generate_data_with_various_distributions(
sample_range, max_doas, resolution_range
):
"""
Parameterized test to generate data with different sample ranges and verify distributions.
Args:
sample_range (np.array): Array specifying the number of samples for each DOA combination.
max_doas (int): Maximum number of Directions of Arrival (DOAs).
"""
main(sample_range=sample_range, max_doas=max_doas)
main(
sample_range=sample_range, max_doas=max_doas, resolution_range=resolution_range
)

0 comments on commit 1fd00a5

Please sign in to comment.