
TFT Base PyTorch checkpoint trained with AMP on Electricity dataset

NVIDIA Deep Learning Examples

22.11.0_amp

April 6, 2023

6.44 MB

Temporal Fusion Transformer is a state-of-the-art architecture for interpretable, multi-horizon time-series prediction.

The TFT model is a hybrid architecture that combines LSTM encoding of time series with the interpretability of transformer attention layers. Prediction is based on three types of variables: static (constant for a given time series), known (available in advance for both the history and the future), and observed (available only for historical data). Each of these variables comes in one of two flavors: categorical or continuous. In addition to these variables, the model is fed the historical values of the time series itself. All variables are embedded in a high-dimensional space by learning an embedding vector. Embeddings of categorical variables are learned in the classical sense of embedding discrete values. For each continuous variable, the model learns a single vector, which is then scaled by the variable's value for further processing. The next step is to filter the variables through the Variable Selection Network (VSN), which assigns weights to the inputs according to their relevance to the prediction. Static variables are used as context for the variable selection of the other variables and as the initial state of the LSTM encoders.
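The embedding and selection steps described above can be sketched in a few lines of plain Python. This is a simplified illustration, not the code from the training script: the function names, the example vectors, and the relevance scores are all hypothetical, and the scores that a real VSN computes with learned layers (conditioned on the static context) are passed in here as fixed numbers.

```python
import math

def embed_continuous(value, learned_vector):
    """Scale the single learned vector of a continuous variable by its value."""
    return [value * w for w in learned_vector]

def softmax(scores):
    """Turn arbitrary scores into positive weights that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def select_variables(embeddings, relevance_scores):
    """Combine per-variable embeddings into one vector using softmax weights.

    Assumption: relevance_scores stand in for the output of the learned
    selection network, which is not modeled here.
    """
    weights = softmax(relevance_scores)
    dim = len(embeddings[0])
    combined = [
        sum(w * emb[i] for w, emb in zip(weights, embeddings))
        for i in range(dim)
    ]
    return combined, weights

# Two hypothetical continuous variables with made-up learned vectors.
temperature = embed_continuous(0.5, [0.2, -0.4, 1.0])
load = embed_continuous(2.0, [0.1, 0.3, -0.5])

# Equal scores give equal weights; a higher score would favor that variable.
combined, weights = select_variables([temperature, load], [0.0, 0.0])
```

The `weights` list is the quantity that makes the selection interpretable: each entry says how much one input contributed to the combined representation.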
After encoding, the variables are passed to multi-head attention layers (the decoder), which produce the final prediction. The whole architecture is interwoven with residual connections with gating mechanisms, which allow the model to adapt to various problems by skipping parts of the network.
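To make the skipping behavior concrete, here is a minimal sketch of a gated residual connection. It assumes a sigmoid gate applied element-wise to a block's output before it is added back to the block's input; the learned linear layers that produce the gate and the layer normalization used in the TFT paper are left out, and the names `gated_skip` and `gate_logits` are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_skip(block_input, block_output, gate_logits):
    """Add a gated version of a block's output to its input.

    A strongly negative gate logit closes the gate, so the block is
    effectively skipped and the input passes through unchanged.
    """
    return [
        x + sigmoid(g) * y
        for x, y, g in zip(block_input, block_output, gate_logits)
    ]

block_input = [1.0, 2.0]
block_output = [10.0, 20.0]

skipped = gated_skip(block_input, block_output, [-50.0, -50.0])  # gate closed
applied = gated_skip(block_input, block_output, [50.0, 50.0])    # gate open
```

With the gate closed, `skipped` is numerically indistinguishable from `block_input`; with the gate open, `applied` is close to the plain residual sum of input and output.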
For the sake of explainability, the heads of the self-attention layers share value matrices. This allows self-attention to be interpreted as an ensemble of models predicting different temporal patterns over the same feature set. The other feature that helps us understand the model is the VSN activations, which tell us how relevant a given feature is to the prediction.
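The effect of sharing value matrices can be checked with a small numeric sketch. It assumes that the head outputs are combined by averaging, as in the TFT paper; the attention weights and values below are made-up numbers, and the helper names are hypothetical. Because every head multiplies the same values, averaging the head outputs gives the same result as applying one averaged attention matrix, and that single matrix can be inspected directly.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

def mean_matrices(mats):
    """Element-wise mean of equally shaped matrices."""
    n = len(mats)
    rows, cols = len(mats[0]), len(mats[0][0])
    return [[sum(m[i][j] for m in mats) / n for j in range(cols)] for i in range(rows)]

# Hypothetical attention weights for two heads over two time steps
# (each row sums to 1), and one value matrix shared by both heads.
head_attention = [
    [[1.0, 0.0], [0.5, 0.5]],
    [[0.0, 1.0], [0.25, 0.75]],
]
shared_values = [[1.0, 2.0], [3.0, 4.0]]

# Ensemble view: every head attends over the same values, outputs are averaged.
avg_of_outputs = mean_matrices([matmul(a, shared_values) for a in head_attention])

# Interpretable view: one averaged attention pattern applied to the same values.
avg_attention = mean_matrices(head_attention)
output_from_avg = matmul(avg_attention, shared_values)
```

Both views produce the same output, so `avg_attention` can be read as the temporal importance pattern of the whole layer.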
*image source: https://arxiv.org/abs/1912.09363*

This model was trained using the script available on NGC and in the GitHub repo.

The following datasets were used to train this model:

- Electricity - Dataset containing electricity consumption of 370 points/clients.

Performance numbers for this model are available in NGC.

This model was trained using open-source software available in the Deep Learning Examples repository. For terms of use, please refer to the license of the script and of the datasets the model was derived from.