gsyyysg / stockformer
PyTorch implementation for Paper "StockFormer: Learning Hybrid Trading Machines with Predictive Coding".
The resulting error is as follows:
Traceback (most recent call last):
File "/home/wenjh/StockFormer/Transformer/main.py", line 108, in <module>
exp.train(setting)
File "/home/wenjh/StockFormer/Transformer/exp/exp_mae.py", line 164, in train
_,_, output = self.model(enc_inp, enc_inp)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wenjh/StockFormer/Transformer/models/transformer.py", line 66, in forward
enc_out = self.enc_embedding(x_enc)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 54, in forward
a = self.value_embedding(x)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 40, in forward
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 302, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 295, in _conv_forward
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
RuntimeError: Given groups=1, weight of size [128, 96, 3], expected input[32, 10, 4] to have 96 channels, but got 10 channels instead
Should enc_in in train_mae.sh be changed from 96 to 10?
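One way to reason about this error is to follow the tensor shape through the permute call. The sketch below is plain Python with a hypothetical helper name (conv1d_channel_check is not part of the repository). It assumes the tensor reaching tokenConv had shape [32, 4, 10] before x.permute(0, 2, 1), which is what the reported input[32, 10, 4] implies. Whether 10 is the intended feature count, or the data was loaded with an unexpected layout, cannot be determined from the traceback alone, so setting enc_in to 10 is only one possible fix.

```python
def conv1d_channel_check(x_shape, enc_in):
    """Hypothetical helper: mimic the shape logic of
    tokenConv(x.permute(0, 2, 1)) without needing torch.

    x_shape is (batch, length, channels) BEFORE the permute.
    Returns the shape Conv1d receives and whether its channel
    dimension matches enc_in (the Conv1d in_channels).
    """
    batch, length, channels = x_shape
    conv_input = (batch, channels, length)  # result of permute(0, 2, 1)
    return conv_input, channels == enc_in


# Shapes taken from the error message above.
reported = conv1d_channel_check((32, 4, 10), enc_in=96)
adjusted = conv1d_channel_check((32, 4, 10), enc_in=10)

print(reported)  # shape seen by Conv1d, and whether channels match enc_in=96
print(adjusted)  # same check with enc_in=10
```

If the last dimension of the input is really meant to hold 96 features, the mismatch would instead point at the data loading step rather than at the script argument.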
absl-py==2.1.0
certifi==2024.6.2
charset-normalizer==3.3.2
cloudpickle==1.6.0
contourpy==1.2.1
cycler==0.12.1
fonttools==4.53.0
greenlet==3.0.3
grpcio==1.64.1
gym==0.19.0
idna==3.7
importlib_metadata==7.1.0
importlib_resources==6.4.0
int-date==0.1.8
joblib==1.4.2
kiwisolver==1.4.5
lxml==5.2.2
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.9.0
multitasking==0.0.11
numpy==1.26.4
packaging==24.0
pandas==1.3.5
patsy==0.5.6
pillow==10.3.0
protobuf==4.25.3
psycopg2-binary==2.9.9
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
requests==2.32.3
scikit-learn==1.0.1
scipy==1.10.1
six==1.16.0
SQLAlchemy==1.4.52
statsmodels==0.13.2
stockstats==0.3.2
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorboardX==2.5
threadpoolctl==3.5.0
torch==1.12.0+cu113
torchaudio==0.12.0+cu113
torchvision==0.13.0+cu113
tqdm==4.19.9
typing_extensions==4.12.2
urllib3==2.2.1
Werkzeug==3.0.3
wrapt==1.14.1
wrds==3.1.6
xlrd==2.0.1
yfinance==0.1.67
zipp==3.19.2
/StockFormer/code$ python train_rl.py
The Zen of Python, by Tim Peters
Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than right now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!
generate technical indicator...
Successfully added technical indicators
Stock Dimension: 88, State Space: 88
Initial Env...
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
{'batch_size': 32, 'buffer_size': 100000, 'learning_rate': 0.0001, 'learning_starts': 100, 'ent_coef': 'auto_0.1', 'enc_in': 96, 'dec_in': 96, 'c_out_construction': 96, 'd_model': 128, 'd_ff': 256, 'n_heads': 4, 'e_layers': 2, 'd_layers': 1, 'dropout': 0.05, 'transformer_path': 'Transformer/pretrained/csi/mae/checkpoint.pth'}
Using cuda device
Traceback (most recent call last):
File "/StockFormer/code/train_rl.py", line 205, in <module>
model_sac = agent.get_model("maesac",model_kwargs = MAESAC_PARAMS,tensorboard_log=tensorboard_log_dir, seed=fix_seed)
File "/StockFormer/code/MySAC/models/DRLAgent.py", line 135, in get_model
model = MODELS[model_name](
File "/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 163, in __init__
self._setup_model()
File "/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 195, in _setup_model
super(SAC, self)._setup_model()
File "/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 179, in _setup_model
self.set_random_seed(self.seed)
File "/StockFormer/code/stable_baselines3/common/base_class.py", line 567, in set_random_seed
self.env.seed(seed)
File "/StockFormer/code/stable_baselines3/common/vec_env/base_vec_env.py", line 278, in seed
return self.venv.seed(seed)
File "/StockFormer/code/stable_baselines3/common/vec_env/dummy_vec_env.py", line 56, in seed
seeds.append(env.seed(seed + idx))
AttributeError: 'StockTradingEnv' object has no attribute 'seed'. Did you mean: '_seed'?
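The hint in the message ("Did you mean: '_seed'?") suggests the environment defines _seed but not the public seed method that the vectorized wrapper calls. A minimal sketch of one possible workaround follows. LegacyEnv and SeedableEnv are hypothetical stand-ins, not classes from the repository, and the body of _seed is an assumption made only so the example runs. Another possible cause, also an assumption, is an installed gym version whose base Env no longer provides seed.

```python
import random


class LegacyEnv:
    """Hypothetical stand-in for an environment that only defines _seed."""

    def _seed(self, seed=None):
        # Assumed behavior: store a seeded RNG and return the seed in a list.
        self.rng = random.Random(seed)
        return [seed]


class SeedableEnv(LegacyEnv):
    def seed(self, seed=None):
        # Public alias so that callers using env.seed(...) keep working.
        return self._seed(seed)


env = SeedableEnv()
result = env.seed(42)
print(result)
```

The same alias could be added directly to the environment class instead of subclassing. Pinning the dependency versions the repository was developed against is an alternative worth checking first.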
Traceback (most recent call last):
File "D:\Code\StockFormer-main\code\Transformer\main.py", line 88, in <module>
data = data_type_dict[args.data_type](
File "D:\Code\StockFormer-main\code\Transformer\data\stock_data_handle.py", line 35, in __init__
self.read_data()
File "D:\Code\StockFormer-main\code\Transformer\data\stock_data_handle.py", line 62, in read_data
df = fe.preprocess_data(df)
File "D:\Code\StockFormer-main\code\Transformer\utils\preprocess.py", line 86, in preprocess_data
df = self.add_technical_indicator(df)
File "D:\Code\StockFormer-main\code\Transformer\utils\preprocess.py", line 212, in add_technical_indicator
indicator_df[["tic", "date", indicator]], on=["tic", "date"], how="left"
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\frame.py", line 3511, in __getitem__
indexer = self.columns._get_indexer_strict(key, "columns")[1]
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 5782, in _get_indexer_strict
self._raise_if_missing(keyarr, indexer, axis_name)
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 5842, in _raise_if_missing
raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Index(['tic', 'date', 'cci_30'], dtype='object')] are in the [columns]"
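Note that the KeyError lists all three names, including tic and date, as missing. That suggests indicator_df had none of the expected columns when it was indexed (for example because it was empty), rather than only lacking cci_30. This is an inference from the message, not something the traceback confirms. A small pre-check like the sketch below can make the failure easier to diagnose. missing_columns is a hypothetical helper, and it accepts any iterable of column names, such as a DataFrame's columns.

```python
def missing_columns(columns, required):
    """Hypothetical helper: list the required names absent from columns."""
    present = set(columns)
    return [name for name in required if name not in present]


required = ["tic", "date", "cci_30"]

# Case 1: a frame with no columns at all (matches "None of [...] are in the [columns]").
all_missing = missing_columns([], required)

# Case 2: only the indicator column is absent.
one_missing = missing_columns(["tic", "date", "close"], required)

print(all_missing)
print(one_missing)
```

Calling such a check just before the merge in add_technical_indicator would show whether the indicator computation produced any rows for the ticker in question.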
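Where is backpropagation for the TRM model defined in the code? If it relies on torch's built-in backpropagation, does anything need to be changed correspondingly after the forward pass is optimized?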
savefig has no index parameter at all. Did you copy the preceding code and forget to change it?
/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object action_space. Consider using custom_objects argument to replace this object.
warnings.warn(
Successfully initialize transformer model...
/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/buffers.py:219: UserWarning: This system does not have apparently enough memory to store the complete replay buffer 37.28GB > 1.61GB
warnings.warn(
Traceback (most recent call last):
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/MySAC/models/DRLAgent.py", line 185, in DRL_prediction_load_from_file
model = MODELS[model_name].load(cwd)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/base_class.py", line 721, in load
model.set_parameters(params, exact_match=True, device=device)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/base_class.py", line 630, in set_parameters
attr.load_state_dict(params[name], strict=exact_match)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/train_rl.py", line 224, in <module>
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/MySAC/models/DRLAgent.py", line 188, in DRL_prediction_load_from_file
raise ValueError("Fail to load agent!")
ValueError: Fail to load agent!
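Error report: after finishing the first-stage training and selecting a model for use in train_rl.py, the error above is raised. With the author's pretrained models, train_rl runs normally without any error.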
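The shapes in this error can be decomposed with a little arithmetic. The sketch below assumes, as in the standard Stable-Baselines3 SAC critic, that the first critic layer takes the concatenation of observation features and action. The customized SAC in this repository may differ, so treat this as a diagnostic aid rather than an explanation of the root cause. The helper name feature_dim is hypothetical.

```python
# Numbers copied from the size-mismatch messages above.
ckpt_action_dim, ckpt_critic_in = 88, 11440         # shapes stored in the checkpoint
model_action_dim, model_critic_in = 31064, 42416    # shapes of the freshly built model


def feature_dim(critic_in, action_dim):
    """Assumed layout: critic input = observation features + action."""
    return critic_in - action_dim


ckpt_features = feature_dim(ckpt_critic_in, ckpt_action_dim)
model_features = feature_dim(model_critic_in, model_action_dim)

# How many times larger the new action dimension is, and whether it divides evenly.
factor, remainder = divmod(model_action_dim, ckpt_action_dim)

print(ckpt_features, model_features)  # 11352 11352
print(factor, remainder)              # 353 0
```

Under that assumption both models see the same 11352 observation features, and only the action dimension differs: 88 in the checkpoint versus 31064 (exactly 353 times 88) in the model built at load time. Comparing the action_space of the training and test environments may therefore be a good place to start. The earlier warnings about action_space not being deserialized and about the 37.28GB replay buffer may be related, though the log does not confirm the cause.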
The code has some dependency issues, which I have fixed. Please check my fork:
https://github.com/xbkaishui/StockFormer
Traceback (most recent call last):
File "D:\downloads\StockFormer-main\StockFormer-main\code\train_rl.py", line 210, in <module>
trained_sac = agent.train_model(model=model_sac,
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\models\DRLAgent.py", line 151, in train_model
model = model.learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\MAE_SAC.py", line 364, in learn
return super(SAC, self).learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 352, in learn
rollout = self.collect_rollouts(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 584, in collect_rollouts
if callback.on_step() is False:
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 192, in _on_step
continue_training = callback.on_step() and continue_training
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 379, in _on_step
episode_rewards, episode_lengths = evaluate_policy(
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\evaluation.py", line 86, in evaluate_policy
observations, rewards, dones, infos = env.step(actions)
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\base_vec_env.py", line 163, in step
return self.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\vec_monitor.py", line 76, in step_wait
obs, rewards, dones, infos = self.venv.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\dummy_vec_env.py", line 43, in step_wait
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
File "D:\downloads\StockFormer-main\StockFormer-main\code\envs\env_stocktrading_hybrid_control.py", line 279, in step
plt.savefig(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\pyplot.py", line 1119, in savefig
res = fig.savefig(*args, **kwargs) # type: ignore[func-returns-value]
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\figure.py", line 3390, in savefig
self.canvas.print_figure(fname, **kwargs)
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2193, in print_figure
result = print_method(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2043, in <lambda>
print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
TypeError: print_png() got an unexpected keyword argument 'index'
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/mnt/sda/DL/AI4F/StockFormer/code/train_rl.py", line 226, in <module>
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
File "/mnt/sda/DL/AI4F/StockFormer/code/MySAC/models/DRLAgent.py", line 188, in DRL_prediction_load_from_file
raise ValueError("Fail to load agent!")
ValueError: Fail to load agent!
Although I obtained the data provided by the authors, which value does the price column hold? And have open, high, low, and close been normalized?
Hello, when I checked the code of stock_data_handle.py, I got confused by the __getitem__ function of class DatasetStock_PRED, because the code indexes self.label with index instead of position, and I don't know why. I think this may cause the return ratio computed from the first and second (or fifth) day in seq_x to be mistaken for the return ratio of the first and second (or fifth) day in the future.

To verify this, I ran sh train_pred_long.sh as an example and added some code to the train function of Exp_pred in exp_pred.py, as shown in the image attached to the issue.

It seems that the variable c, a return ratio calculated directly from batch_x1, is exactly the same as batch_y. So I think there is an alignment problem between the training features and the labels in the code.

Is this a mistake? If what I found is true, it makes the training process unreliable and may affect the conclusions. Could anyone help? Thanks!
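The concern can be illustrated on a toy price series. The sketch below is not the repository's code: return_ratio, seq_len, horizon, and the prices are all hypothetical, and it assumes the label is a simple return ratio over a fixed horizon. It shows that a label anchored at the start of the input window can be reproduced from the window itself, while a label anchored at the end of the window needs a price the model has not seen.

```python
def return_ratio(prices, t, horizon):
    """Return ratio from day t to day t + horizon."""
    return prices[t + horizon] / prices[t] - 1.0


prices = [100.0, 102.0, 101.0, 105.0, 110.0, 108.0]  # toy data
seq_len = 4   # length of the input window (assumed)
horizon = 1   # label horizon in days (assumed)
start = 0     # first day of the window

window = prices[start:start + seq_len]

# Return computed purely from data inside the input window.
in_window = window[horizon] / window[0] - 1.0

# Label anchored at the window start: identical to in_window, so it leaks.
label_at_start = return_ratio(prices, start, horizon)

# Label anchored at the last day of the window: uses a future price.
label_at_end = return_ratio(prices, start + seq_len - 1, horizon)

print(in_window, label_at_start, label_at_end)
```

If a check like the one described in the issue shows the label matching a quantity derived from the input window, as label_at_start does here, that would be consistent with the reported misalignment. Only the repository's label construction can settle which anchoring is actually used.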
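Testing works fine, but even after studying the code for a long time I cannot find the CSV processing code, and I do not know how the CSI 300 data CSV was produced. All I can tell is that price is the closing price. Does anyone know how the open, close, and other values in the CSV are derived from price?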
When executing:
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
the following error is raised:
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
What is the cause of this?
Although the data has been provided, could you share ideas or a script showing how the raw OHLC data is processed into the network input data?