Commit 8f598cf

Merge pull request instructlab#2316 from alimaredia/display-eval-branch-variant-full-score
feat: add full scores to mt-bench/mmlu-branch output
2 parents caddd50 + d9b6510 commit 8f598cf

3 files changed: +146 -111 lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ datasets>=2.18.0
 gguf>=0.6.0
 GitPython>=3.1.42
 httpx>=0.25.0
-instructlab-eval>=0.2.1
+instructlab-eval>=0.3.1
 instructlab-quantize>=0.1.0
 instructlab-schema>=0.4.0
 instructlab-sdg>=0.3.0

src/instructlab/model/evaluate.py

Lines changed: 108 additions & 64 deletions

@@ -139,24 +139,39 @@ def validate_model(model: str, model_arg: str = "--model"):
         raise click.exceptions.Exit(1)
 
 
-def sort_score(pairing: tuple[str, float]) -> float:
+def sort_score(pairing: tuple[str, float, float, float]) -> float:
     """helper func for display_branch_eval_summary
     takes a tuple pairing and returns just the score
     """
     return pairing[1]
 
 
-def display_models(model, base_model) -> None:
+def get_benchmark_max_score(benchmark: Benchmark) -> str:
+    # total score for Benchmark.MT_BENCH_BRANCH or Benchmark.MT_Bench
+    max_score = "10.0"
+    if benchmark in (Benchmark.MMLU_BRANCH, Benchmark.MMLU):
+        max_score = "1.0"
+    return max_score
+
+
+def display_models_and_scores(
+    benchmark, model, base_model, model_score, base_model_score
+) -> None:
     """prints the base_model and model with a header"""
-    print("## BASE MODEL")
-    print(base_model)
-    display_model(model)
+    max_score = get_benchmark_max_score(benchmark)
 
+    base_model_score = round(base_model_score, 2)
+    model_score = round(model_score, 2)
+    print("## BASE MODEL (SCORE)")
+    display_model(base_model, base_model_score, max_score)
+    print("\n## MODEL (SCORE)")
+    display_model(model, model_score, max_score)
 
-def display_model(model) -> None:
+
+def display_model(model, model_score, max_score) -> None:
     """prints the given model with a header"""
-    print("\n## MODEL")
-    print(model)
+    model_score = round(model_score, 2)
+    print(f"{model} ({model_score}/{max_score})")
 
 
 def display_error_rate(error_rate) -> None:
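
For illustration, here is a minimal, self-contained sketch of how the new helpers compose. The Benchmark enum below is a stand-in for the real one used in evaluate.py, and the model path and score are hypothetical:

from enum import Enum


class Benchmark(Enum):
    # stand-in for instructlab's Benchmark enum, for illustration only
    MT_BENCH = "mt_bench"
    MT_BENCH_BRANCH = "mt_bench_branch"
    MMLU = "mmlu"
    MMLU_BRANCH = "mmlu_branch"


def get_benchmark_max_score(benchmark: Benchmark) -> str:
    # MT-Bench variants are judged on a 0-10 scale, MMLU variants on 0-1
    max_score = "10.0"
    if benchmark in (Benchmark.MMLU_BRANCH, Benchmark.MMLU):
        max_score = "1.0"
    return max_score


def display_model(model, model_score, max_score) -> None:
    # prints "<model> (<rounded score>/<max score>)", matching the diff above
    model_score = round(model_score, 2)
    print(f"{model} ({model_score}/{max_score})")


# hypothetical model path and score
display_model(
    "models/granite-7b-lab-Q4_K_M.gguf",
    6.782,
    get_benchmark_max_score(Benchmark.MT_BENCH),
)
# prints: models/granite-7b-lab-Q4_K_M.gguf (6.78/10.0)
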
@@ -167,37 +182,49 @@ def display_error_rate(error_rate) -> None:
 
 
 def display_branch_eval_summary(
-    improvements: list[tuple[str, float]],
-    regressions: list[tuple[str, float]],
-    no_changes: list[str],
+    benchmark: Benchmark,
+    improvements: list[tuple[str, float, float, float]],
+    regressions: list[tuple[str, float, float, float]],
+    no_changes: list[tuple[str, float]],
     new=None,
 ):
     """takes in results lists from mt_bench_branch benchmark evaluation
     prints out diff between the branches to the user
     """
+    # total score for MT-BENCH-BRANCH
+    max_score = get_benchmark_max_score(benchmark)
+
     if len(improvements) > 0:
         improvements.sort(key=sort_score, reverse=True)
-        print("\n### IMPROVEMENTS:")
+        print(f"\n### IMPROVEMENTS (0.0 to {max_score}):")
         for index, improvement in enumerate(improvements):
-            task, delta = improvement
-            print(f"{index+1}. {task} (+{delta})")
+            task, delta, base_score, new_score = improvement
+            base_score = round(base_score, 2)
+            new_score = round(new_score, 2)
+            print(f"{index+1}. {task}: {base_score} -> {new_score} (+{delta})")
 
     if len(regressions) > 0:
         regressions.sort(key=sort_score)
-        print("\n### REGRESSIONS:")
+        print(f"\n### REGRESSIONS (0.0 to {max_score}):")
         for index, regression in enumerate(regressions):
-            task, delta = regression
-            print(f"{index+1}. {task} ({delta})")
+            task, delta, base_score, new_score = regression
+            base_score = round(base_score, 2)
+            new_score = round(new_score, 2)
+            print(f"{index+1}. {task}: {base_score} -> {new_score} ({delta})")
 
     if len(no_changes) > 0:
-        print("\n### NO CHANGE:")
-        for index, task in enumerate(no_changes):
-            print(f"{index+1}. {task}")
+        print(f"\n### NO CHANGE (0.0 to {max_score}):")
+        for index, entry in enumerate(no_changes):
+            task, avg_score = entry
+            avg_score = round(avg_score, 2)
+            print(f"{index+1}. {task} ({avg_score})")
 
     if new is not None and len(new) > 0:
-        print("\n### NEW:")
-        for index, qna in enumerate(new):
-            print(f"{index+1}. {qna}")
+        print(f"\n### NEW (0.0 to {max_score}):")
+        for index, entry in enumerate(new):
+            qna, avg_score = entry
+            avg_score = round(avg_score, 2)
+            print(f"{index+1}. {qna} ({avg_score})")
 
 
 def qa_pairs_to_qna_to_avg_scores(qa_pairs: list[dict]) -> dict[str, float]:
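
To show the new per-entry format, a small sketch mirroring the improvements loop above, with a hypothetical qna path and hypothetical scores in the new (task, delta, base_score, new_score) tuple shape:

# hypothetical data in the new 4-tuple shape: (task, delta, base_score, new_score)
improvements = [
    ("compositional_skills/writing/freeform/example/qna.yaml", 1.5, 5.25, 6.75),
]
max_score = "10.0"  # mt_bench_branch scale, per get_benchmark_max_score

print(f"\n### IMPROVEMENTS (0.0 to {max_score}):")
for index, (task, delta, base_score, new_score) in enumerate(improvements):
    # same line format the updated display_branch_eval_summary prints
    print(f"{index+1}. {task}: {round(base_score, 2)} -> {round(new_score, 2)} (+{delta})")

# output:
# ### IMPROVEMENTS (0.0 to 10.0):
# 1. compositional_skills/writing/freeform/example/qna.yaml: 5.25 -> 6.75 (+1.5)
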
@@ -576,13 +603,13 @@ def evaluate(
                 if server is not None:
                     server.shutdown()
 
+            max_score = get_benchmark_max_score(Benchmark.MT_BENCH)
             print("# SKILL EVALUATION REPORT")
-            display_model(model)
-            print("\n### AVERAGE:")
-            print(f"{round(overall_score, 2)} (across {len(qa_pairs)})")
-            print("\n### TURN ONE:")
+            print("\n## MODEL (SCORE)")
+            display_model(model, overall_score, max_score)
+            print(f"\n### TURN ONE (0.0 to {max_score}):")
             print(round(turn_scores[0], 2))
-            print("\n### TURN TWO:")
+            print(f"\n### TURN TWO (0.0 to {max_score}):")
             turn2_score = turn_scores[1]
             if isinstance(turn2_score, float):
                 turn2_score = round(turn2_score, 2)
@@ -659,46 +686,64 @@ def evaluate(
                 for i, evaluator in enumerate(evaluators):
                     branch = branches[i]
                     print(f"Evaluating answers for branch {branch}...")
-                    judgement = evaluator.judge_answers(
+                    overall_score, qa_pairs, error_rate = evaluator.judge_answers(
                         api_base, max_workers=max_workers, serving_gpus=effective_gpus
                     )
-
-                    if len(judgement) == 3:
-                        qa_pairs = judgement[1]
-                        error_rate = judgement[2]
-                    else:
-                        qa_pairs = judgement[0]
-                        error_rate = judgement[1]
-
-                    qa_pairs_and_errors.append((qa_pairs, error_rate))
+                    qa_pairs_and_errors.append((overall_score, qa_pairs, error_rate))
             finally:
                 if server is not None:
                     server.shutdown()
 
-            qa_pairs, error_rate = qa_pairs_and_errors[0]
-            base_qa_pairs, base_error_rate = qa_pairs_and_errors[1]
+            overall_score, qa_pairs, error_rate = qa_pairs_and_errors[0]
+            base_overall_score, base_qa_pairs, base_error_rate = qa_pairs_and_errors[1]
 
             qna_to_avg_scores = qa_pairs_to_qna_to_avg_scores(qa_pairs)
             base_qna_to_avg_scores = qa_pairs_to_qna_to_avg_scores(base_qa_pairs)
 
             print("# SKILL EVALUATION REPORT\n")
-            display_models(model, base_model)
+            display_models_and_scores(
+                Benchmark.MT_BENCH_BRANCH,
+                model,
+                base_model,
+                overall_score,
+                base_overall_score,
+            )
 
             improvements, regressions, no_changes, new_qnas = [], [], [], []
             for qna, avg_score in qna_to_avg_scores.items():
                 base_avg_score = base_qna_to_avg_scores.get(qna)
                 if base_avg_score is not None:
                     if avg_score > base_avg_score:
-                        improvements.append((qna, round(avg_score - base_avg_score, 2)))
+                        improvements.append(
+                            (
+                                qna,
+                                round(avg_score - base_avg_score, 2),
+                                base_avg_score,
+                                avg_score,
+                            )
+                        )
                     elif avg_score == base_avg_score:
-                        no_changes.append(qna)
+                        no_changes.append((qna, avg_score))
                     else:
-                        regressions.append((qna, round(avg_score - base_avg_score, 2)))
+                        regressions.append(
+                            (
+                                qna,
+                                round(avg_score - base_avg_score, 2),
+                                base_avg_score,
+                                avg_score,
+                            )
+                        )
                 else:
-                    new_qnas.append((qna))
+                    new_qnas.append((qna, avg_score))
 
             # display summary of evaluation before exiting
-            display_branch_eval_summary(improvements, regressions, no_changes, new_qnas)
+            display_branch_eval_summary(
+                Benchmark.MT_BENCH_BRANCH,
+                improvements,
+                regressions,
+                no_changes,
+                new_qnas,
+            )
             display_error_rate((error_rate + base_error_rate) / 2)
 
         elif benchmark == Benchmark.MMLU:
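
A condensed sketch of the comparison bookkeeping above, with hypothetical per-qna average scores, showing the tuple shapes that now flow into display_branch_eval_summary:

# hypothetical per-qna average scores for the tested and base branches
qna_to_avg_scores = {"qna_a": 7.0, "qna_b": 4.0, "qna_c": 6.0, "qna_d": 8.0}
base_qna_to_avg_scores = {"qna_a": 6.0, "qna_b": 5.0, "qna_c": 6.0}

improvements, regressions, no_changes, new_qnas = [], [], [], []
for qna, avg_score in qna_to_avg_scores.items():
    base_avg_score = base_qna_to_avg_scores.get(qna)
    if base_avg_score is None:
        new_qnas.append((qna, avg_score))  # qna only exists on the new branch
    elif avg_score > base_avg_score:
        improvements.append(
            (qna, round(avg_score - base_avg_score, 2), base_avg_score, avg_score)
        )
    elif avg_score == base_avg_score:
        no_changes.append((qna, avg_score))
    else:
        regressions.append(
            (qna, round(avg_score - base_avg_score, 2), base_avg_score, avg_score)
        )

# improvements == [("qna_a", 1.0, 6.0, 7.0)]
# regressions  == [("qna_b", -1.0, 5.0, 4.0)]
# no_changes   == [("qna_c", 6.0)]
# new_qnas     == [("qna_d", 8.0)]
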
@@ -732,12 +777,12 @@ def evaluate(
                 if server is not None:
                     server.shutdown()
 
+            max_score = get_benchmark_max_score(Benchmark.MMLU)
             print("# KNOWLEDGE EVALUATION REPORT")
-            display_model(model)
-            print("\n### AVERAGE:")
-            print(f"{round(overall_score, 2)} (across {len(individual_scores)})\n")
+            print("\n## MODEL (SCORE)")
+            display_model(model, overall_score, max_score)
 
-            print("### SCORES:")
+            print(f"\n### SCORES (0.0 to {max_score}):")
             for task, score in individual_scores.items():
                 s = round(score["score"], 2)
                 print(f"{task} - {s}")
@@ -791,16 +836,13 @@ def evaluate(
             base_individual_scores = individual_scores_list[1]
 
             print("# KNOWLEDGE EVALUATION REPORT\n")
-            display_models(model, base_model)
-
-            print("\n### AVERAGE:")
-            delta = round(overall_score - base_overall_score, 2)
-            if delta >= 0:
-                delta_display = f"+{delta}"
-            else:
-                delta_display = delta
-
-            print(f"{delta_display} (across {len(individual_scores)})")
+            display_models_and_scores(
+                Benchmark.MMLU_BRANCH,
+                model,
+                base_model,
+                overall_score,
+                base_overall_score,
+            )
 
             improvements, regressions, no_changes = [], [], []
             for task, score in individual_scores.items():
@@ -809,14 +851,16 @@ def evaluate(
                 b_s = base_score["score"]
                 d = round(s - b_s, 2)
                 if s > b_s:
-                    improvements.append((task, d))
+                    improvements.append((task, d, b_s, s))
                 elif b_s > s:
-                    regressions.append((task, d))
+                    regressions.append((task, d, b_s, s))
                 else:
-                    no_changes.append(task)
+                    no_changes.append((task, s))
 
             # display summary of evaluation before exiting
-            display_branch_eval_summary(improvements, regressions, no_changes)
+            display_branch_eval_summary(
+                Benchmark.MMLU_BRANCH, improvements, regressions, no_changes
+            )
     except EvalError as ee:
         print(ee.message)
         raise click.exceptions.Exit(1)
