Skip to content

Commit 465a468

Browse files
authoredJun 8, 2022
Create draw_curves.py
1 parent f50b1f4 commit 465a468

File tree

1 file changed

+233
-0
lines changed

1 file changed

+233
-0
lines changed
 

‎draw_curves.py

+233
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
import os
2+
import pandas as pd
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
6+
7+
save_dir = "main_curves"
8+
if not os.path.exists(save_dir):
9+
os.makedirs(save_dir)
10+
11+
path = "Results.xlsx" # this is the excel file containing the results (like the one we released)
12+
file = pd.read_excel(path, sheet_name="imcls_fewshot")
13+
14+
datasets = [
15+
"OxfordPets", "Flowers102", "FGVCAircraft", "DTD",
16+
"EuroSAT", "StanfordCars", "Food101", "SUN397",
17+
"Caltech101", "UCF101", "ImageNet"
18+
]
19+
20+
shots = [1, 2, 4, 8, 16]
21+
22+
COLORS = {
23+
"zs": "C4",
24+
"linear": "C4",
25+
"ours_v16_end": "C0",
26+
"ours_v16_mid": "C2",
27+
"ours_v16_end_csc": "C1",
28+
"ours_v16_mid_csc": "C3"
29+
}
30+
MS = 3
31+
ALPHA = 1
32+
plt.rcParams.update({"font.size": 12})
33+
34+
average = {
35+
"zs": 0.,
36+
"ours_v16_end": np.array([0., 0., 0., 0., 0.]),
37+
"ours_v16_mid": np.array([0., 0., 0., 0., 0.]),
38+
"ours_v16_end_csc": np.array([0., 0., 0., 0., 0.]),
39+
"ours_v16_mid_csc": np.array([0., 0., 0., 0., 0.]),
40+
"linear": np.array([0., 0., 0., 0., 0.])
41+
}
42+
43+
for dataset in datasets:
44+
print(f"Processing {dataset} ...")
45+
46+
zs = file[dataset][0]
47+
48+
ours_v16_end = file[dataset][2:7]
49+
ours_v16_end = [float(num) for num in ours_v16_end]
50+
51+
ours_v16_mid = file[dataset][7:12]
52+
ours_v16_mid = [float(num) for num in ours_v16_mid]
53+
54+
ours_v16_end_csc = file[dataset][12:17]
55+
ours_v16_end_csc = [float(num) for num in ours_v16_end_csc]
56+
57+
ours_v16_mid_csc = file[dataset][17:22]
58+
ours_v16_mid_csc = [float(num) for num in ours_v16_mid_csc]
59+
60+
linear = file[dataset][22:27]
61+
linear = [float(num) for num in linear]
62+
63+
average["zs"] += zs
64+
average["ours_v16_end"] += np.array(ours_v16_end)
65+
average["ours_v16_mid"] += np.array(ours_v16_mid)
66+
average["ours_v16_end_csc"] += np.array(ours_v16_end_csc)
67+
average["ours_v16_mid_csc"] += np.array(ours_v16_mid_csc)
68+
average["linear"] += np.array(linear)
69+
70+
# Plot
71+
values = [zs]
72+
values += linear
73+
values += ours_v16_end
74+
values += ours_v16_mid
75+
values += ours_v16_end_csc
76+
values += ours_v16_mid_csc
77+
val_min, val_max = min(values), max(values)
78+
diff = val_max - val_min
79+
val_bot = val_min - diff*0.05
80+
val_top = val_max + diff*0.05
81+
82+
fig, ax = plt.subplots()
83+
ax.set_facecolor("#EBEBEB")
84+
85+
ax.set_xticks([0] + shots)
86+
ax.set_xticklabels([0] + shots)
87+
ax.set_xlabel("Number of labeled training examples per class")
88+
ax.set_ylabel("Score (%)")
89+
ax.grid(axis="x", color="white", linewidth=1)
90+
ax.axhline(zs, color="white", linewidth=1)
91+
ax.set_title(dataset)
92+
ax.set_ylim(val_bot, val_top)
93+
94+
ax.plot(
95+
0, zs,
96+
marker="*",
97+
markersize=MS*1.5,
98+
color=COLORS["zs"],
99+
alpha=ALPHA
100+
)
101+
ax.plot(
102+
shots, ours_v16_end,
103+
marker="o",
104+
markersize=MS,
105+
color=COLORS["ours_v16_end"],
106+
label="CLIP + CoOp ($M\!=\!16$, end)",
107+
alpha=ALPHA
108+
)
109+
ax.plot(
110+
shots, ours_v16_mid,
111+
marker="o",
112+
markersize=MS,
113+
color=COLORS["ours_v16_mid"],
114+
label="CLIP + CoOp ($M\!=\!16$, mid)",
115+
alpha=ALPHA
116+
)
117+
ax.plot(
118+
shots, ours_v16_end_csc,
119+
marker="o",
120+
markersize=MS,
121+
color=COLORS["ours_v16_end_csc"],
122+
label="CLIP + CoOp ($M\!=\!16$, end, CSC)",
123+
alpha=ALPHA
124+
)
125+
ax.plot(
126+
shots, ours_v16_mid_csc,
127+
marker="o",
128+
markersize=MS,
129+
color=COLORS["ours_v16_mid_csc"],
130+
label="CLIP + CoOp ($M\!=\!16$, mid, CSC)",
131+
alpha=ALPHA
132+
)
133+
ax.plot(
134+
shots, linear,
135+
marker="o",
136+
markersize=MS,
137+
color=COLORS["linear"],
138+
label="Linear probe CLIP",
139+
linestyle="dotted",
140+
alpha=ALPHA
141+
)
142+
143+
ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"])
144+
ax.legend(loc="lower right")
145+
146+
fig.savefig(f"{save_dir}/{dataset}.pdf", bbox_inches="tight")
147+
148+
149+
# Plot
150+
average = {k: v/len(datasets) for k, v in average.items()}
151+
zs = average["zs"]
152+
linear = list(average["linear"])
153+
ours_v16_end = list(average["ours_v16_end"])
154+
ours_v16_mid = list(average["ours_v16_mid"])
155+
ours_v16_end_csc = list(average["ours_v16_end_csc"])
156+
ours_v16_mid_csc = list(average["ours_v16_mid_csc"])
157+
158+
values = [zs]
159+
values += linear
160+
values += ours_v16_end
161+
values += ours_v16_mid
162+
values += ours_v16_end_csc
163+
values += ours_v16_mid_csc
164+
val_min, val_max = min(values), max(values)
165+
diff = val_max - val_min
166+
val_bot = val_min - diff*0.05
167+
val_top = val_max + diff*0.05
168+
169+
fig, ax = plt.subplots()
170+
ax.set_facecolor("#EBEBEB")
171+
172+
ax.set_xticks([0] + shots)
173+
ax.set_xticklabels([0] + shots)
174+
ax.set_xlabel("Number of labeled training examples per class")
175+
ax.set_ylabel("Score (%)")
176+
ax.grid(axis="x", color="white", linewidth=1)
177+
ax.axhline(zs, color="white", linewidth=1)
178+
ax.set_title("Average over 11 datasets", fontweight="bold")
179+
ax.set_ylim(val_bot, val_top)
180+
181+
ax.plot(
182+
0, zs,
183+
marker="*",
184+
markersize=MS*1.5,
185+
color=COLORS["zs"],
186+
alpha=ALPHA
187+
)
188+
ax.plot(
189+
shots, ours_v16_end,
190+
marker="o",
191+
markersize=MS,
192+
color=COLORS["ours_v16_end"],
193+
label="CLIP + CoOp ($M\!=\!16$, end)",
194+
alpha=ALPHA
195+
)
196+
ax.plot(
197+
shots, ours_v16_mid,
198+
marker="o",
199+
markersize=MS,
200+
color=COLORS["ours_v16_mid"],
201+
label="CLIP + CoOp ($M\!=\!16$, mid)",
202+
alpha=ALPHA
203+
)
204+
ax.plot(
205+
shots, ours_v16_end_csc,
206+
marker="o",
207+
markersize=MS,
208+
color=COLORS["ours_v16_end_csc"],
209+
label="CLIP + CoOp ($M\!=\!16$, end, CSC)",
210+
alpha=ALPHA
211+
)
212+
ax.plot(
213+
shots, ours_v16_mid_csc,
214+
marker="o",
215+
markersize=MS,
216+
color=COLORS["ours_v16_mid_csc"],
217+
label="CLIP + CoOp ($M\!=\!16$, mid, CSC)",
218+
alpha=ALPHA
219+
)
220+
ax.plot(
221+
shots, linear,
222+
marker="o",
223+
markersize=MS,
224+
color=COLORS["linear"],
225+
label="Linear probe CLIP",
226+
linestyle="dotted",
227+
alpha=ALPHA
228+
)
229+
230+
ax.text(-0.5, zs-diff*0.11, "Zero-shot\nCLIP", color=COLORS["zs"])
231+
ax.legend(loc="lower right")
232+
233+
fig.savefig(f"{save_dir}/average.pdf", bbox_inches="tight")

0 commit comments

Comments
 (0)
Please sign in to comment.