-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbenchmark.py
More file actions
145 lines (128 loc) · 5.63 KB
/
benchmark.py
File metadata and controls
145 lines (128 loc) · 5.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""OCR benchmark v5 wrapper for held-out SR model comparisons."""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
# Repository root: two directory levels above this file's resolved location.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Prepend the root to sys.path so the ``models`` package import that follows
# is resolved from the repository root first. This must stay above that
# import statement.
sys.path.insert(0, str(REPO_ROOT))
from models.training.ocr_benchmark import run_heldout_benchmark
def collect_benchmarks(
    repo_root: str | os.PathLike[str] | None = None,
) -> list[tuple[str, str, str, str, str, str]]:
    """Return the default benchmark matrix used in the project.

    Each entry is a ``(checkpoint, pairs_csv, output_dir, results_csv,
    results_json, label)`` tuple: five repository-relative paths followed by
    a human-readable label.

    The four bicubic / bicubic+TPGSR runs are always returned. The legacy
    ``*-real`` and ``*-synth-plus-ppt`` runs are appended only when their
    checkpoint file exists under ``repo_root``.

    Args:
        repo_root: Directory that the relative checkpoint paths are resolved
            against when probing the optional runs. Defaults to ``REPO_ROOT``,
            so the result does not depend on the current working directory.

    Returns:
        The benchmark tuples, mandatory runs first, then any optional runs
        whose checkpoint was found.
    """
    root = Path(REPO_ROOT if repo_root is None else repo_root)
    model_root = "models/span/education-finetuned"
    manifest_root = "datasets/manifests/_archive/march9_heldout_v2"

    def _run(
        model: str, manifest: str, slug: str, label: str
    ) -> tuple[str, str, str, str, str, str]:
        # A single slug names every artifact of a run. The result files use it
        # verbatim (hyphenated), and the output directory uses its underscore
        # form.
        return (
            f"{model_root}/{model}/best_model.pt",
            f"{manifest_root}/{manifest}",
            f"outputs/benchmarks/ocr_v5_{slug.replace('-', '_')}",
            f"benchmarks/results/ocr-v5-{slug}-comparison.csv",
            f"benchmarks/results/ocr-v5-{slug}-summary.json",
            label,
        )

    benchmarks = [
        _run("x3-v4", "heldout_bicx3.csv", "x3v4", "x3-v4 (bicubic, baseline)"),
        _run(
            "x3-v5-tpgsr",
            "heldout_bicx3.csv",
            "x3v5tpgsr",
            "x3-v5-tpgsr (bicubic+TPGSR)",
        ),
        _run("x2-v2", "heldout_bicx2.csv", "x2v2", "x2-v2 (bicubic, baseline)"),
        _run(
            "x2-v3-tpgsr",
            "heldout_bicx2.csv",
            "x2v3tpgsr",
            "x2-v3-tpgsr (bicubic+TPGSR)",
        ),
    ]
    optional_runs = [
        _run(
            "x3-real",
            "heldout_real_x3.csv",
            "x3real",
            "x3-real (legacy name, synthetic LR_240)",
        ),
        _run(
            "x3-synth-plus-ppt",
            "heldout_real_x3.csv",
            "x3synth-plus-ppt",
            "x3-synth-plus-ppt (legacy synthetic LR_240 + PPT slides)",
        ),
        _run(
            "x2-real",
            "heldout_real_x2.csv",
            "x2real",
            "x2-real (legacy name, synthetic LR_360)",
        ),
        _run(
            "x2-synth-plus-ppt",
            "heldout_real_x2.csv",
            "x2synth-plus-ppt",
            "x2-synth-plus-ppt (legacy synthetic LR_360 + PPT slides)",
        ),
    ]
    # Resolve against the repository root rather than the CWD, so callers do
    # not have to chdir first.
    benchmarks.extend(
        run for run in optional_runs if os.path.exists(root / run[0])
    )
    return benchmarks
def print_summary_table(summaries: list[dict[str, object]]) -> None:
    """Print a compact, column-aligned comparison table.

    Args:
        summaries: One dict per benchmark run. ``model``, ``sr_conf`` and
            ``gain_vs_baseline`` are required. ``sr_lpips`` and ``sr_cer``
            are optional and are shown as ``N/A`` when missing or ``None``.
    """

    def _optional_metric(summary: dict[str, object], key: str) -> str:
        # Treat an absent metric and an explicit None the same way, instead
        # of crashing in float(None).
        value = summary.get(key)
        return "N/A" if value is None else f"{float(value):.4f}"

    # Size the model column to the longest label, never below the original
    # 32 characters. Labels such as the legacy "*-synth-plus-ppt" runs exceed
    # 32 characters and would otherwise push every later column out of line.
    longest_label = max((len(str(s["model"])) for s in summaries), default=0)
    model_width = max(32, longest_label)
    header = (
        f"{'Model':<{model_width}} {'OCR Conf':>9} {'Gain':>7} "
        f"{'LPIPS':>7} {'CER':>7}"
    )
    rule_width = max(70, len(header))

    print("")
    print("=" * rule_width)
    print("FINAL COMPARISON - ALL MODELS (March 9 heldout)")
    print("=" * rule_width)
    print(header)
    print("-" * rule_width)
    for summary in summaries:
        print(
            f"{str(summary['model']):<{model_width}} "
            f"{float(summary['sr_conf']):>9.2f} "
            f"{float(summary['gain_vs_baseline']):>+7.2f} "
            f"{_optional_metric(summary, 'sr_lpips'):>7} "
            f"{_optional_metric(summary, 'sr_cer'):>7}"
        )
def main() -> int:
    """Execute every configured held-out benchmark and save a combined summary.

    Switches the working directory to ``REPO_ROOT`` so the relative paths from
    ``collect_benchmarks`` resolve. For each entry it then either runs
    ``run_heldout_benchmark`` or, when the checkpoint file is missing, prints
    a skip notice. Finally it prints the comparison table and writes every
    collected summary to ``benchmarks/results/ocr-v5-combined-summary.json``.

    Returns:
        The process exit status, which is always 0.
    """
    os.chdir(REPO_ROOT)

    collected: list[dict[str, object]] = []
    for run in collect_benchmarks():
        checkpoint, pairs_csv, output_dir, results_csv, results_json, label = run
        if os.path.exists(checkpoint):
            collected.append(
                run_heldout_benchmark(
                    checkpoint,
                    pairs_csv,
                    output_dir,
                    label,
                    results_csv=results_csv,
                    results_json=results_json,
                )
            )
        else:
            print(f"\nSKIPPING {label} - checkpoint not found: {checkpoint}")

    print_summary_table(collected)

    results_dir = REPO_ROOT / "benchmarks" / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    combined_path = results_dir / "ocr-v5-combined-summary.json"
    with open(combined_path, "w", encoding="utf-8") as handle:
        json.dump(collected, handle, indent=2)
    print(f"\nCombined results saved to {combined_path}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())