-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_effect_decomposition.py
More file actions
65 lines (51 loc) · 2 KB
/
plot_effect_decomposition.py
File metadata and controls
65 lines (51 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
# Primary settings only
labels = ["GPQA-Pair-1", "GPQA-Pair-2", "HLE-Pair-1", "HLE-Pair-2", "LCB-Pair-1", "LCB-Pair-2"]
# Effect decomposition (percentage points)
re_solving = np.array([14.6, 56.6, 4.9, 30.4, 10.5, 14.6])
scaffold = np.array([-1.5, 0.5, 0.9, -1.8, 25.9, 42.9])
content = np.array([4.5, -3.0, 1.8, -2.4, -3.1, -7.9])
# The series are edited by hand; fail with a readable message if one of them
# drifts out of sync with the tick labels.
for _name, _series in (("re_solving", re_solving), ("scaffold", scaffold), ("content", content)):
    if len(_series) != len(labels):
        raise ValueError(
            f"{_name} has {len(_series)} values but there are {len(labels)} labels"
        )
x = np.arange(len(labels))
w = 0.24  # width of one bar; the three bars of a group are centred on x
fig, ax = plt.subplots(figsize=(11, 5.8))
bars1 = ax.bar(x - w, re_solving, width=w, label="Re-solving")
bars2 = ax.bar(x, scaffold, width=w, label="Scaffold")
bars3 = ax.bar(x + w, content, width=w, label="Content")
# Zero reference line. A neutral colour is set explicitly: without one, Line2D
# falls back to rcParams["lines.color"] ("C0"), the same blue as the first bar
# series, and the baseline reads as part of the "Re-solving" data.
ax.axhline(0, color="black", linewidth=1)
# Axes and labels
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=20, ha="right")
ax.set_ylabel("Effect size (pp)")
ax.legend(frameon=False, loc="upper right")
# Value labels
def add_labels(bars, *, ax=None, offset_above=1.0, offset_below=1.2, fontsize=9):
    """Write each bar's height, formatted to one decimal place, next to the bar.

    Non-negative bars get their label ``offset_above`` data units above the bar
    end (anchored at the text's bottom). Negative bars get it ``offset_below``
    data units below the bar end (anchored at the text's top). This keeps the
    label outside the bar in both cases.

    Args:
        bars: Iterable of bar patches, such as the container returned by
            ``Axes.bar``. Each item must provide ``get_height()``, ``get_x()``
            and ``get_width()``, plus an ``axes`` attribute when ``ax`` is not
            given.
        ax: Axes to draw the text on. When ``None``, each bar's own ``axes`` is
            used, so the helper does not rely on a module-level ``ax``.
        offset_above: Vertical gap, in data units, for non-negative bars.
        offset_below: Vertical gap, in data units, for negative bars.
        fontsize: Font size of the value labels.
    """
    for bar in bars:
        target = ax if ax is not None else bar.axes
        height = bar.get_height()
        x_center = bar.get_x() + bar.get_width() / 2
        if height >= 0:
            target.text(x_center, height + offset_above, f"{height:.1f}",
                        ha="center", va="bottom", fontsize=fontsize)
        else:
            target.text(x_center, height - offset_below, f"{height:.1f}",
                        ha="center", va="top", fontsize=fontsize)
for _bars in (bars1, bars2, bars3):
    add_labels(_bars)
# Vertical separators between datasets (two consecutive pairs per dataset).
# A neutral colour is set explicitly; the Line2D default ("C0") would match the
# first bar series.
for sep in [1.5, 3.5]:
    ax.axvline(sep, color="gray", linestyle="--", linewidth=0.8, alpha=0.6)
# Combined extent of every plotted effect, used for label placement and limits.
all_effects = np.concatenate([re_solving, scaffold, content])
# Dataset group labels, placed above the tallest bar and its value label
y_top = all_effects.max() + 9
ax.text(0.5, y_top, "GPQA", ha="center", va="bottom", fontsize=11)
ax.text(2.5, y_top, "HLE", ha="center", va="bottom", fontsize=11)
ax.text(4.5, y_top, "LiveCodeBench", ha="center", va="bottom", fontsize=11)
# Y limits: room below for negative value labels, above for the group labels
ax.set_ylim(all_effects.min() - 10, all_effects.max() + 14)
plt.tight_layout()
# Optional raster export:
# plt.savefig("primary_decomposition_bar_chart.png", dpi=300, bbox_inches="tight")
# savefig does not create missing directories; make sure the target exists so a
# fresh checkout does not fail with FileNotFoundError.
out_dir = Path("figures")
out_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(out_dir / "primary_decomposition_bar_chart.pdf", bbox_inches="tight")
plt.show()