Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
MaloOLIVIER committed Nov 29, 2024
1 parent babb8fe commit 4a59ad5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,18 @@ jobs:
run: |
echo "PYTHONPATH=${PYTHONPATH}:${GITHUB_WORKSPACE}" >> $GITHUB_ENV
# - name: Run consistency tests
# run: |
# pytest -m consistency
# env:
# CI: true

- name: Run nonregression tests
- name: Run consistency tests
run: |
pytest -m nonregression
pytest -m consistency
env:
CI: true

# - name: Run nonregression tests
# run: |
# pytest -m nonregression
# env:
# CI: true

- name: Upload Coverage Report as Artifact
uses: actions/upload-artifact@v4
with:
Expand Down
10 changes: 7 additions & 3 deletions hungarian_net/plot_f1_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@

test_f = 0
nb_test_batches = 0
train_dataset = HungarianDataset(train=True, max_len=max_len, filename="data/reference/hung_data_train")
train_dataset = HungarianDataset(
train=True, max_len=max_len, filename="data/reference/hung_data_train"
)
batch_size = 256
f_score_weights = np.tile(train_dataset.get_f_wts(), batch_size)

# load test dataset
test_loader = DataLoader(
HungarianDataset(train=False, max_len=max_len, filename= "data/reference/hung_data_test"),
HungarianDataset(
train=False, max_len=max_len, filename="data/reference/hung_data_test"
),
batch_size=batch_size,
shuffle=True,
drop_last=True,
Expand Down Expand Up @@ -63,4 +67,4 @@
plot.ylim([0, 1])
plot.ylabel("F1 Score")
plot.title("F1 Score for HNet on Test Data")
plot.show()
plot.show()
2 changes: 1 addition & 1 deletion hungarian_net/train_hnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def main(
out_filename = f"models/{current_date}/hnet_model_DOA{max_len}_{'-'.join(map(str, sample_range_used))}.pt"

torch.save(model.state_dict(), out_filename)

model_to_return = model
print(
"Epoch: {}\t time: {:0.2f}/{:0.2f}\ttrain_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\ttest_loss: {:.4f} ({:.4f}, {:.4f}, {:.4f})\tf_scr: {:.4f}\tbest_epoch: {}\tbest_f_scr: {:.4f}\ttrue_positives: {}\tfalse_positives: {}\tweighted_accuracy: {:.4f}".format(
Expand Down

0 comments on commit 4a59ad5

Please sign in to comment.