Skip to content

Commit

Permalink
Revert "Add resolution_range parameter and update tests in generate_h…
Browse files Browse the repository at this point in the history
…net_training_data.py"

This reverts commit 1fd00a5.
  • Loading branch information
MaloOLIVIER committed Dec 16, 2024
1 parent 320be4e commit 3a88107
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 53 deletions.
52 changes: 9 additions & 43 deletions generate_hnet_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance

default_sample_range = np.array([3000, 5000, 15000])

def sph2cart(azimuth, elevation, r):
"""
Expand Down Expand Up @@ -57,17 +56,15 @@ def compute_class_imbalance(data_dict):
return class_counts


def generate_data(
max_doas, sample_range, data_type="train", resolution_range="standard_resolution"
):
def generate_data(pickle_filename, max_doas, sample_range, data_type="train"):
"""
Generates training or testing data based on the specified parameters.
Args:
pickle_filename (str): Base name for the output pickle file.
max_doas (int): Maximum number of Directions of Arrival (DOAs).
sample_range (np.array): Array specifying the number of samples for each DOA combination.
data_type (str): Type of data to generate ('train' or 'test').
resolution_range (str): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution.
Returns:
dict: Generated data dictionary containing association matrices and related information.
Expand All @@ -82,27 +79,9 @@ def generate_data(
# Metadata Container
combination_counts = {}

# Modify list_resolutions based on resolution_range and data_type
if data_type == "train":
if resolution_range == "standard_resolution":
list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30]
elif resolution_range == "fine_resolution":
list_resolutions = [1] * 9 # [1, 1, 1, 1, 1, 1, 1, 1, 1]
elif resolution_range == "coarse_resolution":
list_resolutions = [30] * 9 # [30, 30, 30, 30, 30, 30, 30, 30, 30]
else:
raise ValueError(
"Invalid resolution_range: choose 'standard_resolution', 'fine_resolution', or 'coarse_resolution'."
)
else:
# For test data, always use standard resolutions
list_resolutions = [1, 2, 3, 4, 5, 10, 15, 20, 30]

print(f"Resolution Range: {resolution_range}, List Resolutions: {list_resolutions}")

# For each combination of associations ('nb_ref', 'nb_pred') = [(0, 0), (0, 1), (1, 0) ... (max_doas, max_doas)]
# Generate random reference ('ref_ang') and prediction ('pred_ang') DOAs at different 'resolution'.
for resolution in list_resolutions: # Different angular resolution
for resolution in [1, 2, 3, 4, 5, 10, 15, 20, 30]: # Different angular resolution
azi_range = range(-180, 180, resolution)
ele_range = range(-90, 91, resolution)

Expand Down Expand Up @@ -212,7 +191,7 @@ def generate_data(
os.makedirs(f"data/{current_date}/{data_type}", exist_ok=True)

# Human-readable filename
out_filename = f"data/{current_date}/{data_type}/{resolution_range}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}"
out_filename = f"data/{current_date}/{data_type}/{pickle_filename}_{data_type}_DOA{max_doas}_{'-'.join(map(str, sample_range))}"

print(f"Saving data in: {out_filename}, #examples: {len(data_dict)}")
save_obj(data_dict, out_filename)
Expand All @@ -236,10 +215,7 @@ def generate_data(


def main(
sample_range=default_sample_range,
max_doas=2,
resolution_range="standard_resolution",
testing="default",
pickle_filename="hung_data", sample_range=np.array([3000, 5000, 15000]), max_doas=2
):
"""
Generates and saves training and testing datasets for the Hungarian Network (HNet) model.
Expand All @@ -249,12 +225,11 @@ def main(
defined sample ranges for different Directions of Arrival (DOAs).
Args:
pickle_filename (str, optional): Base name for the output pickle files. Defaults to "hung_data".
sample_range (np.array, optional): Array specifying the number of samples for each DOA combination.
Should correspond to the minimum of `nb_ref` and `nb_pred`.
Defaults to `default_sample_range`.
Defaults to [3000, 5000, 15000].
max_doas (int, optional): Maximum number of Directions of Arrival (DOAs) to consider. Defaults to 2.
resolution_range (str, optional): Range of angular resolutions to consider : standard_resolution, fine_resolution or coarse_resolution. Defaults to "standard_resolution".
testing (str, optional): Which data distribution to test, sample_range or `default_sample_range`. Defaults to "default".
Returns:
None
Expand Down Expand Up @@ -282,22 +257,13 @@ def main(
print("\nGenerating Training Data...")
# Generate training data
train_data_dict = generate_data(
max_doas, sample_range, data_type="train", resolution_range=resolution_range
)

sample_range_to_test = (
sample_range
if testing != "default"
else default_sample_range # Default sample range for testing
pickle_filename, max_doas, sample_range, data_type="train"
)

print("\nGenerating Testing Data...")
# Generate testing data, same procedure as above
test_data_dict = generate_data(
max_doas,
sample_range_to_test,
data_type="test",
resolution_range=resolution_range,
pickle_filename, max_doas, sample_range, data_type="test"
)

print("\n=== Summary of Generated Datasets ===")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,12 @@


@pytest.mark.scenarios_generate_data
@pytest.mark.parametrize(
"resolution_range",
[("standard_resolution"), ("fine_resolution"), ("coarse_resolution")],
)
def test_generate_data_with_various_distributions(
sample_range, max_doas, resolution_range
):
def test_generate_data_with_various_distributions(sample_range, max_doas):
"""
Parameterized test to generate data with different sample ranges and verify distributions.
Args:
sample_range (np.array): Array specifying the number of samples for each DOA combination.
max_doas (int): Maximum number of Directions of Arrival (DOAs).
"""
main(
sample_range=sample_range, max_doas=max_doas, resolution_range=resolution_range
)
main(sample_range=sample_range, max_doas=max_doas)

0 comments on commit 3a88107

Please sign in to comment.