stockformer's People

Contributors

gsyyysg

stockformer's Issues

enc_in in the test script train_mae.sh is wrong

It causes the following error:

Traceback (most recent call last):
  File "/home/wenjh/StockFormer/Transformer/main.py", line 108, in <module>
    exp.train(setting)
  File "/home/wenjh/StockFormer/Transformer/exp/exp_mae.py", line 164, in train
    _,_, output = self.model(enc_inp, enc_inp)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/transformer.py", line 66, in forward
    enc_out = self.enc_embedding(x_enc)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 54, in forward
    a = self.value_embedding(x)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 40, in forward
    x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 302, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 295, in _conv_forward
    return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
RuntimeError: Given groups=1, weight of size [128, 96, 3], expected input[32, 10, 4] to have 96 channels, but got 10 channels instead

Should enc_in in train_mae.sh be changed from 96 to 10?
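
The message comes from the channel check of a 1-D convolution: the weight [128, 96, 3] was built for 96 input channels (enc_in), while the permuted batch [32, 10, 4] carries 10. Below is a minimal pure-Python sketch of that rule. `check_conv1d_channels` is a hypothetical helper written for this note, not part of PyTorch or StockFormer, and it assumes groups=1 as in the traceback.

```python
def check_conv1d_channels(weight_shape, input_shape, groups=1):
    """Mirror the channel check a 1-D convolution performs on its input.

    weight_shape: (out_channels, in_channels // groups, kernel_size)
    input_shape:  (batch, channels, length)
    """
    out_channels, in_per_group, _kernel = weight_shape
    _batch, channels, _length = input_shape
    expected = in_per_group * groups
    if channels != expected:
        raise ValueError(
            f"expected input{list(input_shape)} to have {expected} channels, "
            f"but got {channels} channels instead"
        )
    return out_channels


# Shapes taken from the traceback: enc_in=96 vs. a batch with 10 channels.
try:
    check_conv1d_channels((128, 96, 3), (32, 10, 4))
except ValueError as err:
    mismatch_message = str(err)

# With enc_in equal to the channel count of the batch, the check passes.
out_channels = check_conv1d_channels((128, 10, 3), (32, 10, 4))
```

Whether 10 is the right value depends on how many feature columns the MAE data loader actually emits. The sketch only shows that enc_in has to equal that number.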

Saving my environment for reference

absl-py==2.1.0
certifi==2024.6.2
charset-normalizer==3.3.2
cloudpickle==1.6.0
contourpy==1.2.1
cycler==0.12.1
fonttools==4.53.0
greenlet==3.0.3
grpcio==1.64.1
gym==0.19.0
idna==3.7
importlib_metadata==7.1.0
importlib_resources==6.4.0
int-date==0.1.8
joblib==1.4.2
kiwisolver==1.4.5
lxml==5.2.2
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.9.0
multitasking==0.0.11
numpy==1.26.4
packaging==24.0
pandas==1.3.5
patsy==0.5.6
pillow==10.3.0
protobuf==4.25.3
psycopg2-binary==2.9.9
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
requests==2.32.3
scikit-learn==1.0.1
scipy==1.10.1
six==1.16.0
SQLAlchemy==1.4.52
statsmodels==0.13.2
stockstats==0.3.2
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorboardX==2.5
threadpoolctl==3.5.0
torch==1.12.0+cu113
torchaudio==0.12.0+cu113
torchvision==0.13.0+cu113
tqdm==4.19.9
typing_extensions==4.12.2
urllib3==2.2.1
Werkzeug==3.0.3
wrapt==1.14.1
wrds==3.1.6
xlrd==2.0.1
yfinance==0.1.67
zipp==3.19.2

python train_rl.py fails with an error

/StockFormer/code$ python train_rl.py
The Zen of Python, by Tim Peters

Beautiful is better than ugly.
Explicit is better than implicit.
Simple is better than complex.
Complex is better than complicated.
Flat is better than nested.
Sparse is better than dense.
Readability counts.
Special cases aren't special enough to break the rules.
Although practicality beats purity.
Errors should never pass silently.
Unless explicitly silenced.
In the face of ambiguity, refuse the temptation to guess.
There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.
Now is better than never.
Although never is often better than right now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!
generate technical indicator...
Successfully added technical indicators
Stock Dimension: 88, State Space: 88
Initial Env...
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Short/checkpoint.pth
Successfully load prediction mode... Transformer/pretrained/csi/Long/checkpoint.pth
{'batch_size': 32, 'buffer_size': 100000, 'learning_rate': 0.0001, 'learning_starts': 100, 'ent_coef': 'auto_0.1', 'enc_in': 96, 'dec_in': 96, 'c_out_construction': 96, 'd_model': 128, 'd_ff': 256, 'n_heads': 4, 'e_layers': 2, 'd_layers': 1, 'dropout': 0.05, 'transformer_path': 'Transformer/pretrained/csi/mae/checkpoint.pth'}
Using cuda device
Traceback (most recent call last):
File "/StockFormer/code/train_rl.py", line 205, in <module>
model_sac = agent.get_model("maesac",model_kwargs = MAESAC_PARAMS,tensorboard_log=tensorboard_log_dir, seed=fix_seed)
File "/StockFormer/code/MySAC/models/DRLAgent.py", line 135, in get_model
model = MODELS[model_name](
File "/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 163, in __init__
self._setup_model()
File "/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 195, in _setup_model
super(SAC, self)._setup_model()
File "/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 179, in _setup_model
self.set_random_seed(self.seed)
File "/StockFormer/code/stable_baselines3/common/base_class.py", line 567, in set_random_seed
self.env.seed(seed)
File "/StockFormer/code/stable_baselines3/common/vec_env/base_vec_env.py", line 278, in seed
return self.venv.seed(seed)
File "/StockFormer/code/stable_baselines3/common/vec_env/dummy_vec_env.py", line 56, in seed
seeds.append(env.seed(seed + idx))
AttributeError: 'StockTradingEnv' object has no attribute 'seed'. Did you mean: '_seed'?
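
The traceback shows the vectorized wrapper calling `env.seed(...)` on an environment that only defines `_seed`. One possible cause is an installed gym version that no longer supplies a default `Env.seed`. A minimal sketch of a workaround is a public `seed` method that delegates to `_seed`. `StockTradingEnvStub` is a hypothetical stand-in for the repository's `StockTradingEnv`, and the body of `_seed` is assumed, not copied from the project.

```python
import random


class StockTradingEnvStub:
    """Hypothetical stand-in for the repository's StockTradingEnv."""

    def _seed(self, seed=None):
        # Assumed behaviour: create a seeded RNG and report the seed used.
        self.rng = random.Random(seed)
        return [seed]

    def seed(self, seed=None):
        # Public alias so wrappers that call env.seed(...) keep working.
        return self._seed(seed)


env = StockTradingEnvStub()
# Mirrors the wrapper loop in the traceback: env.seed(seed + idx).
seeds = [env.seed(100 + idx)[0] for idx in range(2)]
```

In the real class, the same two-line `seed` method would sit next to the existing `_seed` definition.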

When running main.py, add_technical_indicator raises an exception while computing the cci_30 indicator: float division by zero

Traceback (most recent call last):
File "D:\Code\StockFormer-main\code\Transformer\main.py", line 88, in <module>
data = data_type_dict[args.data_type](
File "D:\Code\StockFormer-main\code\Transformer\data\stock_data_handle.py", line 35, in __init__
self.read_data()
File "D:\Code\StockFormer-main\code\Transformer\data\stock_data_handle.py", line 62, in read_data
df = fe.preprocess_data(df)
File "D:\Code\StockFormer-main\code\Transformer\utils\preprocess.py", line 86, in preprocess_data
df = self.add_technical_indicator(df)
File "D:\Code\StockFormer-main\code\Transformer\utils\preprocess.py", line 212, in add_technical_indicator
indicator_df[["tic", "date", indicator]], on=["tic", "date"], how="left"
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\frame.py", line 3511, in __getitem__
indexer = self.columns._get_indexer_strict(key, "columns")[1]
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 5782, in _get_indexer_strict
self._raise_if_missing(keyarr, indexer, axis_name)
File "C:\Users\lenovo\anaconda3\lib\site-packages\pandas\core\indexes\base.py", line 5842, in _raise_if_missing
raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Index(['tic', 'date', 'cci_30'], dtype='object')] are in the [columns]"
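
The title mentions a float division by zero, while the traceback ends in a KeyError. That combination is consistent with the cci_30 column never being created after the earlier exception, though the log does not confirm it. The textbook CCI divides by the mean absolute deviation of the typical price, which is zero when prices are flat over the whole window (for example, a suspended stock). The sketch below guards that divisor. It uses the textbook formula, not necessarily what stockstats does internally, and the `fallback` parameter is an assumption.

```python
from statistics import fmean


def cci(highs, lows, closes, window=30, fallback=0.0):
    """Textbook CCI over the last `window` bars, guarded against a zero divisor."""
    typical = [(h + l + c) / 3.0 for h, l, c in zip(highs, lows, closes)]
    recent = typical[-window:]
    sma = fmean(recent)
    mean_dev = fmean([abs(tp - sma) for tp in recent])
    if mean_dev == 0.0:
        # Flat prices over the whole window: the ratio is undefined.
        return fallback
    return (recent[-1] - sma) / (0.015 * mean_dev)


flat = [10.0] * 30
flat_cci = cci(flat, flat, flat)            # guard triggers, no ZeroDivisionError
ramp = [1.0, 2.0, 3.0]
ramp_cci = cci(ramp, ramp, ramp, window=3)  # ordinary case with varying prices
```

Checking the input data for tickers with constant prices over 30 or more consecutive days would show whether this is what happens here.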

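How is backpropagation done?

Where is backpropagation for the TRM model defined in the code? If it relies on torch's built-in backpropagation, do I need to make corresponding changes after optimizing the forward pass?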

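Loading the model in the reinforcement learning stage fails with "Fail to load agent!". Is this because of insufficient memory, or does the code need to be modified?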

/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object action_space. Consider using custom_objects argument to replace this object.
warnings.warn(
Successfully initialize transformer model...
/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/buffers.py:219: UserWarning: This system does not have apparently enough memory to store the complete replay buffer 37.28GB > 1.61GB
warnings.warn(
Traceback (most recent call last):
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/MySAC/models/DRLAgent.py", line 185, in DRL_prediction_load_from_file
model = MODELS[model_name].load(cwd)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/base_class.py", line 721, in load
model.set_parameters(params, exact_match=True, device=device)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/stable_baselines3/common/base_class.py", line 630, in set_parameters
attr.load_state_dict(params[name], strict=exact_match)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/train_rl.py", line 224, in <module>
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
File "/content/gdrive/MyDrive/Colab/StockFormer-main/code/MySAC/models/DRLAgent.py", line 188, in DRL_prediction_load_from_file
raise ValueError("Fail to load agent!")
ValueError: Fail to load agent!
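
The exception is raised by the size-mismatch lines, not by the replay-buffer memory warning. Some arithmetic on the reported shapes narrows the problem down. The sketch assumes, as in stock Stable-Baselines3 SAC, that the critic's first layer takes the extracted features concatenated with the action. The vendored copy in this repository may differ.

```python
# Shapes copied from the error message above.
ckpt_actor_out, model_actor_out = 88, 31064
ckpt_critic_in, model_critic_in = 11440, 42416

# Assumption: critic input width = feature width + action width.
ckpt_features = ckpt_critic_in - ckpt_actor_out
model_features = model_critic_in - model_actor_out

same_features = ckpt_features == model_features
action_ratio, action_rem = divmod(model_actor_out, ckpt_actor_out)
features_per_action, features_rem = divmod(ckpt_features, ckpt_actor_out)

print(f"feature width: checkpoint={ckpt_features}, current={model_features}")
print(f"action width:  checkpoint={ckpt_actor_out}, current={model_actor_out} "
      f"(= {ckpt_actor_out} x {action_ratio}, remainder {action_rem})")
```

Under that assumption the feature width is the same on both sides and only the action width differs: 88 in the checkpoint versus 88 x 353 in the freshly built model. This points at how the test environment's action space is constructed, which fits the warning that action_space could not be deserialized. It is a hypothesis, not a confirmed diagnosis.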

When running train_rl.py, an error is raised at train_model

Traceback (most recent call last):
File "D:\downloads\StockFormer-main\StockFormer-main\code\train_rl.py", line 210, in <module>
trained_sac = agent.train_model(model=model_sac,
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\models\DRLAgent.py", line 151, in train_model
model = model.learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\MAE_SAC.py", line 364, in learn
return super(SAC, self).learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 352, in learn
rollout = self.collect_rollouts(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 584, in collect_rollouts
if callback.on_step() is False:
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 192, in _on_step
continue_training = callback.on_step() and continue_training
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 379, in _on_step
episode_rewards, episode_lengths = evaluate_policy(
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\evaluation.py", line 86, in evaluate_policy
observations, rewards, dones, infos = env.step(actions)
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\base_vec_env.py", line 163, in step
return self.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\vec_monitor.py", line 76, in step_wait
obs, rewards, dones, infos = self.venv.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\dummy_vec_env.py", line 43, in step_wait
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
File "D:\downloads\StockFormer-main\StockFormer-main\code\envs\env_stocktrading_hybrid_control.py", line 279, in step
plt.savefig(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\pyplot.py", line 1119, in savefig
res = fig.savefig(*args, **kwargs) # type: ignore[func-returns-value]
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\figure.py", line 3390, in savefig
self.canvas.print_figure(fname, **kwargs)
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2193, in print_figure
result = print_method(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2043, in <lambda>
print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
TypeError: print_png() got an unexpected keyword argument 'index'

An error occurred while the program was loading the zip file. It is weird.

RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/mnt/sda/DL/AI4F/StockFormer/code/train_rl.py", line 226, in <module>
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
File "/mnt/sda/DL/AI4F/StockFormer/code/MySAC/models/DRLAgent.py", line 188, in DRL_prediction_load_from_file
raise ValueError("Fail to load agent!")
ValueError: Fail to load agent!

How is the data processed?

I obtained the data provided by the author, but which value does the price column hold? And have open, high, low, and close been normalized?

A mistake (maybe) in stock_data_handle.py

Hello, when I checked the code of stock_data_handle.py, I got confused by the __getitem__ function of class DatasetStock_PRED, because the code indexes self.label by index rather than by position, and I don't know why. I think this may cause the return ratio computed from the first and second/fifth day in seq_x to be mistaken for the return ratio of the first and second/fifth day in the future.
[image: 7c584df9226f3e7ebe609e64bedec6f]
To verify this, I ran sh train_pred_long.sh as an example and added some code to the train function of Exp_pred in exp_pred.py, as shown in the image below:
[image: ac55d99d03b00204c4903d6a7a43bcf]
The variable c, a return ratio computed directly from batch_x1, turns out to be exactly the same as batch_y. So I think there is an alignment problem in the code between the training features and the labels.
[image: a1cc03f1e3767d51c995b2410bc41f5]
Is this a mistake? If what I found is true, it makes the training process unreliable and may affect the conclusions. Could anyone help me? Thanks!
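
The alignment concern can be illustrated without the repository's code. The sketch below builds sliding windows over a toy price list and attaches a one-day forward return as the label, once aligned to the last day of the window and once indexed at the window's start. All names are hypothetical, and the one-day horizon is an assumption chosen for brevity. It does not reproduce DatasetStock_PRED.

```python
def one_day_returns(prices):
    """labels[t] is the return from day t to day t + 1."""
    return [prices[t + 1] / prices[t] - 1.0 for t in range(len(prices) - 1)]


def make_samples(prices, seq_len, aligned=True):
    """Pair each window of seq_len prices with a one-day return label."""
    labels = one_day_returns(prices)
    samples = []
    for start in range(len(prices) - seq_len):
        window = prices[start:start + seq_len]
        # Aligned: return after the LAST day of the window (a true forecast target).
        # Misaligned: return after the FIRST day, which lies inside the window.
        label_pos = start + seq_len - 1 if aligned else start
        samples.append((window, labels[label_pos]))
    return samples


prices = [100.0, 120.0, 90.0, 99.0, 49.5]
good_window, good_label = make_samples(prices, seq_len=3, aligned=True)[0]
bad_window, bad_label = make_samples(prices, seq_len=3, aligned=False)[0]

# The misaligned label can be recomputed from the input window alone.
leaked = bad_label == bad_window[1] / bad_window[0] - 1.0
```

If the check described in the issue (c equal to batch_y) holds for every batch, the labels behave like the misaligned case here.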

During backtesting, the loaded model checkpoint and the initialized actor/critic have different dimensions

When executing:
results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac',environment=test_trade_gym, cwd=model_path)
the following error is raised:
RuntimeError: Error(s) in loading state_dict for SACPolicy:
size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
What is the cause of this?
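
A quick way to inspect such a failure before calling load_state_dict is to compare parameter shapes by name. The helper below is a hypothetical sketch that works on plain name-to-shape dictionaries. With PyTorch, these could be built with `{k: tuple(v.shape) for k, v in module.state_dict().items()}`. The example values are a subset of the shapes quoted in the error above.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for every differing entry."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)  # None if the model lacks the entry
        if model_shape != ckpt_shape:
            mismatches[name] = (ckpt_shape, model_shape)
    return mismatches


checkpoint_shapes = {
    "actor.mu.weight": (88, 256),
    "actor.mu.bias": (88,),
    "critic.qf0.0.weight": (256, 11440),
}
model_shapes = {
    "actor.mu.weight": (31064, 256),
    "actor.mu.bias": (31064,),
    "critic.qf0.0.weight": (256, 42416),
}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, (ckpt, current) in mismatches.items():
    print(f"{name}: checkpoint {ckpt} vs current {current}")
```

Every mismatched dimension involves the action width (88 in the checkpoint), which suggests comparing the action space of the training environment with that of the backtest environment.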

Data preprocessing script (Data PreProcess Method)

Although the data has been provided, could you share ideas or a script for processing the raw OHLC data into the network's input data?
