Skip to content

Commit 966ecfc

Browse files
committed
calculate avg of mAP and recall
1 parent c920678 commit 966ecfc

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

mmda/baselines/emma_kitti_class.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,24 @@ def retrieve_data(
456456
precisions[20].append(precision_20)
457457
maps[5].append(ap_5)
458458
maps[20].append(ap_20)
459-
return maps, precisions, recalls
459+
460+
maps_dict = {5: np.mean(maps[5]), 20: np.mean(maps[20])}
461+
precisions_dict = {
462+
1: np.mean(precisions[1]),
463+
5: np.mean(precisions[5]),
464+
20: np.mean(precisions[20]),
465+
}
466+
recalls_dict = {
467+
1: np.mean(recalls[1]),
468+
5: np.mean(recalls[5]),
469+
20: np.mean(recalls[20]),
470+
}
471+
return maps_dict, precisions_dict, recalls_dict
460472

461473

462474
if __name__ == "__main__":
463475
# CUDA_VISIBLE_DEVICES=2 poetry run python mmda/baselines/emma_ds_class.py
476+
import pandas as pd
464477
from omegaconf import OmegaConf
465478

466479
cfg = OmegaConf.load("config/main.yaml")
@@ -477,3 +490,22 @@ def retrieve_data(
477490
ds.txtdata["test"] = txt_transformed
478491
maps, precisions, recalls = ds.retrieve_data()
479492
print(maps, precisions, recalls)
493+
# write the results to a csv file
494+
data = {
495+
"method": [
496+
"EMMA",
497+
],
498+
"mAP@5": [maps[5]],
499+
"mAP@20": [maps[20]],
500+
"Precision@1": [precisions[1]],
501+
"Precision@5": [precisions[5]],
502+
"Precision@20": [precisions[20]],
503+
"Recall@1": [recalls[1]],
504+
"Recall@5": [recalls[5]],
505+
"Recall@20": [recalls[20]],
506+
}
507+
df = pd.DataFrame(data)
508+
dir_path = Path(cfg.KITTI.paths.plots_path)
509+
df_path = dir_path / "emma_kitti_class.csv"
510+
df_path.parent.mkdir(parents=True, exist_ok=True)
511+
df.to_csv(df_path, index=False)

0 commit comments

Comments
 (0)