temporal-fusion-transformer's Introduction

This is an implementation of the Temporal Fusion Transformer network architecture for predicting the future Bitcoin market price.
Running gen_csv_year(year, symbol, interval) in DownloadData.ipynb downloads past price data using the Binance API.

The training set consists of 2018 and 2019; 2020 is used for testing.
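
For example, downloading the three years used here might look like this (a usage sketch based on the signature above, run inside DownloadData.ipynb; each call writes one CSV of 5-minute candles for that year):

for year in [2018, 2019, 2020]:
    # fetch one calendar year of 5-minute BTCUSDT candles from Binance and save it as a CSV
    gen_csv_year(year, "BTCUSDT", interval = "5m")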

Let's start with the necessary imports, as well as matplotlib for visualisation purposes.

import torch
import torch.nn as nn
import numpy as np
from network import *
from data import *
import pandas as pd
%matplotlib notebook
import matplotlib.pyplot as plt
import math
# mpl_finance is deprecated; on newer setups install mplfinance and use:
# from mplfinance.original_flavor import candlestick_ohlc
from mpl_finance import candlestick_ohlc

Next we define which columns are used as continuous and discrete input, as well as prediction targets.

continuous_columns = ['Open', 'High', 'Low', 'Close']
discrete_columns = ['Hour']#, 'Day', 'Month']
target_columns = ['Close']
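
The 'Hour' column (and the commented-out 'Day' and 'Month') are discrete calendar features; their cardinalities reappear below as [24, 31, 12]. If they need to be built by hand from the raw Binance timestamps, it could look roughly like this (a hypothetical sketch; the repository's data.py may already handle it, and 'timestamp' is a placeholder column name):

def add_calendar_features(df, ts_column = 'timestamp'):
    # ts_column is a placeholder name for the millisecond open-time column in the CSV
    ts = pd.to_datetime(df[ts_column], unit = 'ms')
    df['Hour'] = ts.dt.hour         # 24 categories (0-23)
    df['Day'] = ts.dt.day - 1       # 31 categories (0-30)
    df['Month'] = ts.dt.month - 1   # 12 categories (0-11)
    return df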

Load the Bitcoin data into memory.

print("Loading : ")
btc_data = load_data(['2018', '2019'], 'BTCUSDT', continuous_columns, '5m')
btc_test_data = load_data(['2020'], 'BTCUSDT', continuous_columns, interval = '5m')
Loading : 
done
done

Next we define the hyperparameters; more details can be found in the Temporal Fusion Transformer paper.

#input data shape
n_variables_past_continuous = 4
n_variables_future_continuous = 0
n_variables_past_discrete = [24]#, 31, 12]
n_variables_future_discrete = [24]#, 31, 12]

#hyperparams
batch_size = 160
test_batch_size = 160
n_tests = 25
dim_model = 160
n_lstm_layers = 4
n_attention_layers = 3
n_heads = 6

quantiles = torch.tensor([0.1, 0.5, 0.9], dtype = torch.float32).cuda()  #prediction quantiles (10%, 50%, 90%)

past_seq_len = 80
future_seq_len = 15
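
The quantiles tensor defines the prediction intervals the network learns (10%, 50% and 90%). Following the Temporal Fusion Transformer paper, quantile outputs are trained with the quantile (pinball) loss, which forward_pass presumably applies internally; for reference, a minimal sketch of that loss (not the repository's exact implementation):

def quantile_loss(y_pred, y_true, quantiles):
    # y_pred: (batch, time, n_quantiles), y_true: (batch, time, 1)
    errors = y_true - y_pred  # positive when the model under-predicts
    return torch.max((quantiles - 1) * errors, quantiles * errors).mean()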

Either load the model from a checkpoint or initialise a new one.

load_model = True
path = "model_100000.pt"

#initialise
t = TFN(n_variables_past_continuous, n_variables_future_continuous, 
        n_variables_past_discrete, n_variables_future_discrete, dim_model,
        n_quantiles = quantiles.shape[0], dropout_r = 0.2,
        n_attention_layers = n_attention_layers, n_lstm_layers = n_lstm_layers, n_heads = n_heads).cuda()
optimizer = torch.optim.Adam(t.parameters(), lr = 0.0005)

#try to load from checkpoint
if load_model:
    checkpoint = torch.load(path)
    t = checkpoint['model_state']
    # rebuild the optimizer around the loaded model's parameters before restoring its state
    optimizer = torch.optim.Adam(t.parameters(), lr = 0.0005)
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    losses = checkpoint['losses']
    test_losses = checkpoint['test_losses']
    print("Loaded model from checkpoint")
else:
    losses = []
    test_losses = []
    print("No checkpoint loaded, initialising model")

Loaded model from checkpoint

Define generators for the training and test sets. Note that the test generator is passed the training data via norm, presumably so that test batches are normalised with the training-set statistics.

btc_gen = get_batches(btc_data, past_seq_len, 
                future_seq_len, continuous_columns, discrete_columns, 
                target_columns, batch_size = batch_size)

test_btc_gen = get_batches(btc_test_data, past_seq_len, 
            future_seq_len, continuous_columns, discrete_columns, 
            target_columns, batch_size = batch_size, norm = btc_data)

Now let's begin the training process. First we create a figure for data visualisation.

The network is saved periodically, so overtraining is not a concern: we can look back and pick the iteration with the best test set performance (see the sketch after the training loop).

fig = plt.figure()
ax = fig.add_subplot(411)
ax1 = fig.add_subplot(412)
ax2 = fig.add_subplot(413)
ax3 = fig.add_subplot(414)
plt.ion()


fig.canvas.draw()
fig.show()

steps = 200000
for e in range(steps):
    #run model against test set every 50 batches
    if(e % 50 == 0):
        t.eval()
        m_test_losses = []
        for i in range(n_tests):
            test_loss, _, _, _ = forward_pass(t, test_btc_gen, test_batch_size, quantiles)
            m_test_losses.append(test_loss.cpu().detach().numpy())
        t.train()

        test_losses.append(np.array(m_test_losses).mean())
        
    #save model every 400 batches
    if(e % 400 == 0):
        torch.save({'model_state' : t,
                    'optimizer_state': optimizer.state_dict(),
                   'losses' : losses, 'test_losses' : test_losses} , "model_{}.pt".format(len(losses)))
        
    #forward pass
    optimizer.zero_grad()
    loss, net_out, vs_weights, given_data = forward_pass(t, btc_gen, batch_size, quantiles)
    net_out = net_out.cpu().detach()[0]
    
    #backwards pass
    losses.append(loss.cpu().detach().numpy())
    loss.backward()
    optimizer.step()
    
    #loss graphs
    fig.tight_layout(pad = 0.1)
    ax.clear()
    ax.title.set_text("Training loss")
    ax.plot(losses[250:])
    
    ax1.clear()
    ax1.title.set_text("Test loss")
    ax1.plot(test_losses[5:]) 
    
    #compare network output and target data
    ax2.clear()
    ax2.title.set_text("Network output comparison")
    c = given_data[0][0].cpu()
    a = torch.arange(-past_seq_len, 0).unsqueeze(-1).unsqueeze(-1).float()
    c = torch.cat((a,c), dim = 1)
    candlestick_ohlc(ax2, c.squeeze(), colorup = "green", colordown = "red")

    ax2.plot(net_out[:,0], color = "red")
    ax2.plot(net_out[:,1], color = "blue")
    ax2.plot(net_out[:,2], color = "red")
    ax2.plot(given_data[3].cpu().detach().numpy()[0], label = "target", color = "orange")

    #visualise variable selection weights
    vs_weights = torch.mean(torch.mean(vs_weights, dim = 0), dim = 0).squeeze()
    vs_weights = vs_weights.cpu().detach().numpy()
    ax3.clear()
    ax3.title.set_text("Variable Selection Weights")
    plt.xticks(rotation=-30)
    x = ['Open', 'High', 'Low', 'Close', 'Hour']
    ax3.bar(x = x, height = vs_weights)
    fig.canvas.draw()
    
    del loss
    del net_out
    del vs_weights
    del given_data
    # break early for demonstration; remove this to train for the full number of steps
    if e >= 2:
        break
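
Because every checkpoint stores the full loss histories, picking the best iteration afterwards just means scanning the saved files. A minimal sketch, assuming the model_<iteration>.pt naming used in the loop above:

import glob

best_path, best_loss = None, float('inf')
for ckpt_path in glob.glob("model_*.pt"):
    ckpt = torch.load(ckpt_path)
    # compare the most recent test loss recorded up to each checkpoint
    if len(ckpt['test_losses']) > 0 and ckpt['test_losses'][-1] < best_loss:
        best_loss = ckpt['test_losses'][-1]
        best_path = ckpt_path
print("Best checkpoint:", best_path, "with test loss", best_loss)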


The first two graphs simply show the training and test losses respectively.
The third graph shows the given data in candlestick form, the target data in orange, and the network's median prediction in blue. The red lines represent the 10% and 90% quantiles.

The final graph shows the variable selection weights, a feature of Temporal Fusion Transformers that indicates how much importance the network attributes to each input.

Finally, let's put the network into evaluation mode and visualise some test set comparisons.

#Draw test cases
fig = plt.figure()
axes = []
batch_size_ = 4

for i in range(batch_size_):
    axes.append(fig.add_subplot(411 + i))
    
test_btc_gen = get_batches(btc_test_data, past_seq_len, 
            future_seq_len, continuous_columns, discrete_columns, 
            target_columns, batch_size = batch_size_, norm = btc_data)

t.eval()
loss, net_out, vs_weights, given_data = forward_pass(t, test_btc_gen, batch_size_, quantiles)
net_out = net_out.cpu().detach()
for idx, a in enumerate(axes):
    a.clear()
    
    c = given_data[0][idx].cpu()
    
    b = torch.arange(-past_seq_len, 0).unsqueeze(-1).unsqueeze(-1).float()
    c = torch.cat((b,c), dim = 1)
    candlestick_ohlc(a, c.squeeze(), colorup = "green", colordown = "red")

    a.plot(net_out[idx][:,0], color = "red")
    a.plot(net_out[idx][:,1], color = "blue")
    a.plot(net_out[idx][:,2], color = "red")
    a.plot(given_data[3].cpu().detach().numpy()[idx], label = "target", color = "orange")

t.train()    
plt.ion()

fig.show()
fig.canvas.draw()

Resources:

Bryan Lim, Sercan Ö. Arık, Nicolas Loeff, Tomas Pfister, "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting", arXiv:1912.09363


temporal-fusion-transformer's Issues

mpl_finance deprecated

mpl_finance is deprecated. Install mplfinance instead:

pip install mplfinance

And replace the import with:

from mplfinance.original_flavor import candlestick_ohlc

Error while running DownloadData.ipynb

When running gen_csv_year(2020, "ETHUSDT", interval = "5m"), after downloading the CSV it raises the following error:

Traceback (most recent call last):
  File "E:/DownloadData.py", line 76, in <module>
    gen_csv_year(2020, "ETHUSDT", interval = "5m")
  File "E:/DownloadData.py", line 71, in gen_csv_year
    gen_csv(datetime(year, 1, 1), "{}_{}".format(year, symbol), datetime(year + 1, 1, 1), symbol, interval)
  File "E:/DownloadData.py", line 65, in gen_csv
    for i in d:
  File "E:/DownloadData.py", line 50, in get_data
    yield pd.DataFrame.from_dict(data)[[0, 1, 2, 3, 4, 5]]
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\frame.py", line 2806, in __getitem__
    indexer = self.loc._get_listlike_indexer(key, axis=1, raise_missing=True)[1]
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 1552, in _get_listlike_indexer
    keyarr, indexer, o._get_axis_number(axis), raise_missing=raise_missing
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 1639, in _validate_read_indexer
    raise KeyError(f"None of [{key}] are in the [{axis_name}]")
KeyError: "None of [Int64Index([0, 1, 2, 3, 4, 5], dtype='int64')] are in the [columns]"

=========================================================================
I changed 'yield pd.DataFrame.from_dict(data)[[0, 1, 2, 3, 4, 5]]' to 'yield pd.DataFrame.from_dict(data).iloc[[0, 1, 2, 3, 4, 5]]', which fixed the error above, but it still can't run successfully; it fails with the error below and I don't know how to deal with it :(

Traceback (most recent call last):
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 2110, in _get_list_axis
    return self.obj._take_with_is_copy(key, axis=axis)
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\generic.py", line 3409, in _take_with_is_copy
    result = self.take(indices=indices, axis=axis, **kwargs)
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\generic.py", line 3395, in take
    indices, axis=self._get_block_manager_axis(axis), verify=True
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\internals\managers.py", line 1386, in take
    indexer = maybe_convert_indices(indexer, n)
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexers.py", line 212, in maybe_convert_indices
    raise IndexError("indices are out-of-bounds")
IndexError: indices are out-of-bounds

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "E:/DownloadData.py", line 86, in <module>
    gen_csv_year(2018, "ETHUSDT", interval = "5m")
  File "E:/DownloadData.py", line 81, in gen_csv_year
    gen_csv(datetime(year, 1, 1), "{}_{}".format(year, symbol), datetime(year + 1, 1, 1), symbol, interval)
  File "E:/DownloadData.py", line 75, in gen_csv
    for i in d:
  File "E:/DownloadData.py", line 59, in get_data
    print(pd.DataFrame.from_dict(data).iloc[[0, 1, 2, 3, 4, 5]])
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 1767, in __getitem__
    return self._getitem_axis(maybe_callable, axis=axis)
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 2128, in _getitem_axis
    return self._get_list_axis(key, axis=axis)
  File "E:\anaconda\envs\CondaEnvir\lib\site-packages\pandas\core\indexing.py", line 2113, in _get_list_axis
    raise IndexError("positional indexers are out-of-bounds")
IndexError: positional indexers are out-of-bounds

I would appreciate it a lot if someone could tell me how to fix this :)
