-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
88 lines (82 loc) · 3.58 KB
/
train.py
File metadata and controls
88 lines (82 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Train all v5 models: tpgsr bicubic + real-degraded pairs."""
import sys
import os
import traceback
# Put the directory one level above this script's folder (presumably the
# repository root) first on the import path so the project-local `models`
# package resolves when this file is run directly as a script. This must run
# before the `from models...` import below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# In-process training entry point; the runner below calls it once per job
# after replacing sys.argv with that job's argument list.
from models.training.finetune_education import main
# Hyperparameters identical across every run. Only the dataset manifest,
# output directory, scale factor and warm-start checkpoint vary per job.
_SHARED_HPARAMS = (
    "--use_span", "--channels", "48", "--num_blocks", "6",
    "--epochs", "50", "--batch_size", "8", "--lr", "1e-4",
    "--lr_crop_size", "64", "--tpgsr_gp_weight", "0.1", "--text_bias",
)
# Relative roots; resolved against the current working directory at run time.
_MANIFEST_ROOT = "datasets/manifests"
_MODEL_ROOT = "models/span/education-finetuned"


def _job(name, manifest, scale, init_from):
    """Build one JOBS entry.

    Args:
        name: job label; also used as the output sub-directory under
            ``_MODEL_ROOT``.
        manifest: sub-directory of ``_MANIFEST_ROOT`` holding
            ``train_pairs.csv`` and ``val_pairs.csv``.
        scale: integer upscaling factor, passed as ``--scale``.
        init_from: sub-directory of ``_MODEL_ROOT`` whose ``best_model.pt``
            is used as ``--init_checkpoint``.

    Returns:
        A dict with ``"name"`` and ``"args"`` (a fresh argv-style list whose
        first element stands in for the program name).
    """
    data_dir = f"{_MANIFEST_ROOT}/{manifest}"
    return {
        "name": name,
        "args": [
            "finetune_education",
            "--train_csv", f"{data_dir}/train_pairs.csv",
            "--val_csv", f"{data_dir}/val_pairs.csv",
            "--output_dir", f"{_MODEL_ROOT}/{name}",
            "--scale", str(scale),
            *_SHARED_HPARAMS,
            "--init_checkpoint", f"{_MODEL_ROOT}/{init_from}/best_model.pt",
        ],
    }


# Runs in this order: TPGSR fine-tunes on the bicubic v5/v3 manifests, then the
# real-degraded pairs. x3 jobs warm-start from x3-v4; x2 jobs warm-start from x2-v2.
JOBS = [
    _job("x3-v5-tpgsr", "educational_x3_v5", 3, "x3-v4"),
    _job("x2-v3-tpgsr", "educational_x2_v3", 2, "x2-v2"),
    _job("x3-real", "educational_real_x3", 3, "x3-v4"),
    _job("x2-real", "educational_real_x2", 2, "x2-v2"),
]
if __name__ == "__main__":
os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
results = {}
for job in JOBS:
print(f"\n{'='*60}")
print(f"TRAINING: {job['name']}")
print(f"{'='*60}")
sys.argv = job["args"]
try:
ret = main()
results[job["name"]] = ret
print(f"\n>>> {job['name']} finished with code {ret}")
except Exception:
traceback.print_exc()
results[job["name"]] = -1
print(f"\n>>> {job['name']} FAILED")
print(f"\n{'='*60}")
print("ALL TRAINING COMPLETE")
print(f"{'='*60}")
for name, code in results.items():
status = "OK" if code == 0 else "FAILED"
print(f" {name}: {status} (code {code})")