forked from ZhuZhouFan/AlphaQCM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_drl_csi300.py
More file actions
93 lines (82 loc) · 3.88 KB
/
train_drl_csi300.py
File metadata and controls
93 lines (82 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import yaml
import argparse
import torch
from datetime import datetime
from fqf_iqn_qrdqn.agent import QRDQNAgent, IQNAgent, FQFAgent
from alphagen.data.expression import Feature, FeatureType, Ref, StockData
from alphagen_qlib.calculator import QLibStockDataCalculator
from alphagen.models.alpha_pool import AlphaPool
from alphagen.rl.env.wrapper import AlphaEnv
# Per-model config keys interpolated into the run's log-directory name, as
# (value shown after the "N" label, value shown after the "lr" label).
# This reproduces the original naming scheme exactly -- e.g. IQN logs its
# 'K' entry under the "N" label and FQF logs 'quantile_lr' under "lr".
_LOG_NAME_KEYS = {
    'qrdqn': ('N', 'lr'),
    'iqn': ('K', 'lr'),
    'fqf': ('N', 'quantile_lr'),
}


def _make_log_dir(name, config, seed, pool, timestamp):
    """Return the logging directory path for one training run.

    The path is ``AlphaQCM_data/csi300_logs/pool_<pool>/<run-name>``, where the
    run name encodes the model, seed, timestamp and selected hyper-parameters
    read from ``config``.
    """
    n_key, lr_key = _LOG_NAME_KEYS[name]
    run_name = (
        f"{name}-seed{seed}-{timestamp}-N{config[n_key]}-lr{config[lr_key]}"
        f"-per{config['use_per']}-gamma{config['gamma']}-step{config['multi_step']}"
    )
    return os.path.join('AlphaQCM_data/csi300_logs', f"pool_{pool}", run_name)


def run(args):
    """Train one DRL agent on CSI300 data and run it.

    Args:
        args: Parsed CLI namespace with attributes ``model`` (one of
            ``'qrdqn'``, ``'iqn'``, ``'fqf'``), ``seed`` (int, forwarded to the
            agent) and ``pool`` (int, capacity of the training ``AlphaPool``).
            The model's hyper-parameters are read from ``config/<model>.yaml``
            relative to the current working directory and forwarded to the
            agent as keyword arguments.

    Raises:
        ValueError: If ``args.model`` is not a supported model name. This is
            checked before any file or data is touched.
    """
    name = args.model
    if name not in _LOG_NAME_KEYS:
        raise ValueError(
            f"Unknown model {name!r}; expected one of {sorted(_LOG_NAME_KEYS)}"
        )

    config_path = os.path.join('config', f'{name}.yaml')
    with open(config_path, encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Use the GPU when present; otherwise fall back to CPU instead of crashing
    # on the first CUDA allocation. The same flag is handed to the agent so the
    # environment and the agent agree on where tensors live.
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        print('CUDA is not available; falling back to CPU.')
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Prediction target built from the close price and a 20-step Ref offset;
    # presumably the 20-day forward return (negative Ref offset looking ahead)
    # -- NOTE(review): confirm against alphagen's Ref semantics.
    close = Feature(FeatureType.CLOSE)
    target = Ref(close, -20) / close - 1

    # Chronological splits: train 2010-2019, validation 2020, test 2021-2022.
    instruments: str = 'csi300'
    data_train = StockData(instrument=instruments,
                           start_time='2010-01-01',
                           end_time='2019-12-31')
    data_valid = StockData(instrument=instruments,
                           start_time='2020-01-01',
                           end_time='2020-12-31')
    data_test = StockData(instrument=instruments,
                          start_time='2021-01-01',
                          end_time='2022-12-31')
    train_calculator = QLibStockDataCalculator(data_train, target)
    valid_calculator = QLibStockDataCalculator(data_valid, target)
    test_calculator = QLibStockDataCalculator(data_test, target)

    # Only the training split feeds the pool/environment; the validation and
    # test calculators are handed to the agent separately.
    train_pool = AlphaPool(capacity=args.pool,
                           calculator=train_calculator,
                           ic_lower_bound=None,
                           l1_alpha=5e-3)
    train_env = AlphaEnv(pool=train_pool, device=device, print_expr=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = _make_log_dir(name, config, args.seed, args.pool, timestamp)

    # Project agent classes are looked up at call time (not at module level)
    # so the dispatch table only exists once the model name is validated.
    agent_classes = {'qrdqn': QRDQNAgent, 'iqn': IQNAgent, 'fqf': FQFAgent}
    agent = agent_classes[name](env=train_env,
                                valid_calculator=valid_calculator,
                                test_calculator=test_calculator,
                                log_dir=log_dir,
                                seed=args.seed,
                                cuda=use_cuda,
                                **config)
    agent.run()
if __name__ == '__main__':
    # Command-line entry point: parse options and start a single training run.
    parser = argparse.ArgumentParser(
        description='Train a DRL agent (qrdqn, iqn or fqf) on CSI300 data.')
    # Restrict --model to the supported agents so a typo fails immediately
    # with a clear argparse error instead of deep inside run().
    parser.add_argument('--model', type=str, default='qrdqn',
                        choices=['qrdqn', 'iqn', 'fqf'],
                        help='Agent to train; also selects config/<model>.yaml.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed forwarded to the agent.')
    parser.add_argument('--pool', type=int, default=20,
                        help='Capacity of the training alpha pool.')
    args = parser.parse_args()
    run(args)