|
| 1 | +import os |
| 2 | +import pandas as pd |
| 3 | +import numpy as np |
| 4 | +import matplotlib.pyplot as plt |
| 5 | + |
| 6 | + |
| 7 | +save_dir = "main_curves" |
| 8 | +if not os.path.exists(save_dir): |
| 9 | + os.makedirs(save_dir) |
| 10 | + |
| 11 | +path = "Results.xlsx" # this is the excel file containing the results (like the one we released) |
| 12 | +file = pd.read_excel(path, sheet_name="imcls_fewshot") |
| 13 | + |
| 14 | +datasets = [ |
| 15 | + "OxfordPets", "Flowers102", "FGVCAircraft", "DTD", |
| 16 | + "EuroSAT", "StanfordCars", "Food101", "SUN397", |
| 17 | + "Caltech101", "UCF101", "ImageNet" |
| 18 | +] |
| 19 | + |
| 20 | +shots = [1, 2, 4, 8, 16] |
| 21 | + |
| 22 | +COLORS = { |
| 23 | + "zs": "C4", |
| 24 | + "linear": "C4", |
| 25 | + "ours_v16_end": "C0", |
| 26 | + "ours_v16_mid": "C2", |
| 27 | + "ours_v16_end_csc": "C1", |
| 28 | + "ours_v16_mid_csc": "C3" |
| 29 | +} |
| 30 | +MS = 3 |
| 31 | +ALPHA = 1 |
| 32 | +plt.rcParams.update({"font.size": 12}) |
| 33 | + |
| 34 | +average = { |
| 35 | + "zs": 0., |
| 36 | + "ours_v16_end": np.array([0., 0., 0., 0., 0.]), |
| 37 | + "ours_v16_mid": np.array([0., 0., 0., 0., 0.]), |
| 38 | + "ours_v16_end_csc": np.array([0., 0., 0., 0., 0.]), |
| 39 | + "ours_v16_mid_csc": np.array([0., 0., 0., 0., 0.]), |
| 40 | + "linear": np.array([0., 0., 0., 0., 0.]) |
| 41 | +} |
| 42 | + |
| 43 | +for dataset in datasets: |
| 44 | + print(f"Processing {dataset} ...") |
| 45 | + |
| 46 | + zs = file[dataset][0] |
| 47 | + |
| 48 | + ours_v16_end = file[dataset][2:7] |
| 49 | + ours_v16_end = [float(num) for num in ours_v16_end] |
| 50 | + |
| 51 | + ours_v16_mid = file[dataset][7:12] |
| 52 | + ours_v16_mid = [float(num) for num in ours_v16_mid] |
| 53 | + |
| 54 | + ours_v16_end_csc = file[dataset][12:17] |
| 55 | + ours_v16_end_csc = [float(num) for num in ours_v16_end_csc] |
| 56 | + |
| 57 | + ours_v16_mid_csc = file[dataset][17:22] |
| 58 | + ours_v16_mid_csc = [float(num) for num in ours_v16_mid_csc] |
| 59 | + |
| 60 | + linear = file[dataset][22:27] |
| 61 | + linear = [float(num) for num in linear] |
| 62 | + |
| 63 | + average["zs"] += zs |
| 64 | + average["ours_v16_end"] += np.array(ours_v16_end) |
| 65 | + average["ours_v16_mid"] += np.array(ours_v16_mid) |
| 66 | + average["ours_v16_end_csc"] += np.array(ours_v16_end_csc) |
| 67 | + average["ours_v16_mid_csc"] += np.array(ours_v16_mid_csc) |
| 68 | + average["linear"] += np.array(linear) |
| 69 | + |
| 70 | + # Plot |
| 71 | + values = [zs] |
| 72 | + values += linear |
| 73 | + values += ours_v16_end |
| 74 | + values += ours_v16_mid |
| 75 | + values += ours_v16_end_csc |
| 76 | + values += ours_v16_mid_csc |
| 77 | + val_min, val_max = min(values), max(values) |
| 78 | + diff = val_max - val_min |
| 79 | + val_bot = val_min - diff*0.05 |
| 80 | + val_top = val_max + diff*0.05 |
| 81 | + |
| 82 | + fig, ax = plt.subplots() |
| 83 | + ax.set_facecolor("#EBEBEB") |
| 84 | + |
| 85 | + ax.set_xticks([0] + shots) |
| 86 | + ax.set_xticklabels([0] + shots) |
| 87 | + ax.set_xlabel("Number of labeled training examples per class") |
| 88 | + ax.set_ylabel("Score (%)") |
| 89 | + ax.grid(axis="x", color="white", linewidth=1) |
| 90 | + ax.axhline(zs, color="white", linewidth=1) |
| 91 | + ax.set_title(dataset) |
| 92 | + ax.set_ylim(val_bot, val_top) |
| 93 | + |
| 94 | + ax.plot( |
| 95 | + 0, zs, |
| 96 | + marker="*", |
| 97 | + markersize=MS*1.5, |
| 98 | + color=COLORS["zs"], |
| 99 | + alpha=ALPHA |
| 100 | + ) |
| 101 | + ax.plot( |
| 102 | + shots, ours_v16_end, |
| 103 | + marker="o", |
| 104 | + markersize=MS, |
| 105 | + color=COLORS["ours_v16_end"], |
| 106 | + label="CLIP + CoOp ($M\!=\!16$, end)", |
| 107 | + alpha=ALPHA |
| 108 | + ) |
| 109 | + ax.plot( |
| 110 | + shots, ours_v16_mid, |
| 111 | + marker="o", |
| 112 | + markersize=MS, |
| 113 | + color=COLORS["ours_v16_mid"], |
| 114 | + label="CLIP + CoOp ($M\!=\!16$, mid)", |
| 115 | + alpha=ALPHA |
| 116 | + ) |
| 117 | + ax.plot( |
| 118 | + shots, ours_v16_end_csc, |
| 119 | + marker="o", |
| 120 | + markersize=MS, |
| 121 | + color=COLORS["ours_v16_end_csc"], |
| 122 | + label="CLIP + CoOp ($M\!=\!16$, end, CSC)", |
| 123 | + alpha=ALPHA |
| 124 | + ) |
| 125 | + ax.plot( |
| 126 | + shots, ours_v16_mid_csc, |
| 127 | + marker="o", |
| 128 | + markersize=MS, |
| 129 | + color=COLORS["ours_v16_mid_csc"], |
| 130 | + label="CLIP + CoOp ($M\!=\!16$, mid, CSC)", |
| 131 | + alpha=ALPHA |
| 132 | + ) |
| 133 | + ax.plot( |
| 134 | + shots, linear, |
| 135 | + marker="o", |
| 136 | + markersize=MS, |
| 137 | + color=COLORS["linear"], |
| 138 | + label="Linear probe CLIP", |
| 139 | + linestyle="dotted", |
| 140 | + alpha=ALPHA |
| 141 | + ) |
| 142 | + |
| 143 | + ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) |
| 144 | + ax.legend(loc="lower right") |
| 145 | + |
| 146 | + fig.savefig(f"{save_dir}/{dataset}.pdf", bbox_inches="tight") |
| 147 | + |
| 148 | + |
| 149 | +# Plot |
| 150 | +average = {k: v/len(datasets) for k, v in average.items()} |
| 151 | +zs = average["zs"] |
| 152 | +linear = list(average["linear"]) |
| 153 | +ours_v16_end = list(average["ours_v16_end"]) |
| 154 | +ours_v16_mid = list(average["ours_v16_mid"]) |
| 155 | +ours_v16_end_csc = list(average["ours_v16_end_csc"]) |
| 156 | +ours_v16_mid_csc = list(average["ours_v16_mid_csc"]) |
| 157 | + |
| 158 | +values = [zs] |
| 159 | +values += linear |
| 160 | +values += ours_v16_end |
| 161 | +values += ours_v16_mid |
| 162 | +values += ours_v16_end_csc |
| 163 | +values += ours_v16_mid_csc |
| 164 | +val_min, val_max = min(values), max(values) |
| 165 | +diff = val_max - val_min |
| 166 | +val_bot = val_min - diff*0.05 |
| 167 | +val_top = val_max + diff*0.05 |
| 168 | + |
| 169 | +fig, ax = plt.subplots() |
| 170 | +ax.set_facecolor("#EBEBEB") |
| 171 | + |
| 172 | +ax.set_xticks([0] + shots) |
| 173 | +ax.set_xticklabels([0] + shots) |
| 174 | +ax.set_xlabel("Number of labeled training examples per class") |
| 175 | +ax.set_ylabel("Score (%)") |
| 176 | +ax.grid(axis="x", color="white", linewidth=1) |
| 177 | +ax.axhline(zs, color="white", linewidth=1) |
| 178 | +ax.set_title("Average over 11 datasets", fontweight="bold") |
| 179 | +ax.set_ylim(val_bot, val_top) |
| 180 | + |
| 181 | +ax.plot( |
| 182 | + 0, zs, |
| 183 | + marker="*", |
| 184 | + markersize=MS*1.5, |
| 185 | + color=COLORS["zs"], |
| 186 | + alpha=ALPHA |
| 187 | +) |
| 188 | +ax.plot( |
| 189 | + shots, ours_v16_end, |
| 190 | + marker="o", |
| 191 | + markersize=MS, |
| 192 | + color=COLORS["ours_v16_end"], |
| 193 | + label="CLIP + CoOp ($M\!=\!16$, end)", |
| 194 | + alpha=ALPHA |
| 195 | +) |
| 196 | +ax.plot( |
| 197 | + shots, ours_v16_mid, |
| 198 | + marker="o", |
| 199 | + markersize=MS, |
| 200 | + color=COLORS["ours_v16_mid"], |
| 201 | + label="CLIP + CoOp ($M\!=\!16$, mid)", |
| 202 | + alpha=ALPHA |
| 203 | +) |
| 204 | +ax.plot( |
| 205 | + shots, ours_v16_end_csc, |
| 206 | + marker="o", |
| 207 | + markersize=MS, |
| 208 | + color=COLORS["ours_v16_end_csc"], |
| 209 | + label="CLIP + CoOp ($M\!=\!16$, end, CSC)", |
| 210 | + alpha=ALPHA |
| 211 | +) |
| 212 | +ax.plot( |
| 213 | + shots, ours_v16_mid_csc, |
| 214 | + marker="o", |
| 215 | + markersize=MS, |
| 216 | + color=COLORS["ours_v16_mid_csc"], |
| 217 | + label="CLIP + CoOp ($M\!=\!16$, mid, CSC)", |
| 218 | + alpha=ALPHA |
| 219 | +) |
| 220 | +ax.plot( |
| 221 | + shots, linear, |
| 222 | + marker="o", |
| 223 | + markersize=MS, |
| 224 | + color=COLORS["linear"], |
| 225 | + label="Linear probe CLIP", |
| 226 | + linestyle="dotted", |
| 227 | + alpha=ALPHA |
| 228 | +) |
| 229 | + |
| 230 | +ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"]) |
| 231 | +ax.legend(loc="lower right") |
| 232 | + |
| 233 | +fig.savefig(f"{save_dir}/average.pdf", bbox_inches="tight") |
0 commit comments