NGC | Catalog

TFT PyT checkpoint (Base, AMP, Traffic)

Description

TFT Base PyTorch checkpoint trained with Automatic Mixed Precision (AMP) on the Traffic dataset.

Publisher

NVIDIA Deep Learning Examples

Latest Version

22.11.0_amp

Modified

April 6, 2023

Size

6.73 MB

Model Overview

The Temporal Fusion Transformer (TFT) is a state-of-the-art architecture for interpretable, multi-horizon time-series prediction.

Model Architecture

The TFT model is a hybrid architecture that combines LSTM encoding of time series with the interpretability of transformer attention layers. Prediction is based on three types of variables: static (constant for a given time series), known (known in advance for both the history and the future), and observed (known only for historical data). All of these variables come in two flavors: categorical and continuous. In addition to these variables, the model is fed the historical values of the time series itself. Every variable is embedded in a high-dimensional space by learning an embedding vector. Categorical variable embeddings are learned in the classical sense of embedding discrete values. For each continuous variable, the model learns a single vector, which is then scaled by the variable's value for further processing.

The next step is to filter the variables through the Variable Selection Network (VSN), which assigns weights to the inputs according to their relevance to the prediction. Static variables are used as context for the variable selection of the other variables and as the initial state of the LSTM encoders. After encoding, the variables are passed to multi-head attention layers (the decoder), which produce the final prediction. The whole architecture is interwoven with gated residual connections, which let the model adapt to different problems by skipping parts of the network.

For the sake of explainability, the heads of the self-attention layers share value matrices. This allows self-attention to be interpreted as an ensemble of models predicting different temporal patterns over the same feature set. Another aid to understanding the model is the set of VSN activations, which indicate how relevant each feature is to the prediction.

Image source: https://arxiv.org/abs/1912.09363
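
A minimal NumPy sketch of the input embedding and variable-selection steps follows. It is not the implementation used to train this checkpoint: all parameters are random stand-ins, `HIDDEN` and every function name are hypothetical, and the selection weights come from a single linear layer plus softmax rather than the gated residual networks used in the actual TFT. The `gated_skip` helper is likewise a simplified gated residual connection, included only to show how a gate near zero lets a layer be skipped.

```python
import numpy as np

HIDDEN = 8  # hypothetical embedding width

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def embed_categorical(table, index):
    # Classical embedding: one learned row per discrete value.
    return table[index]

def embed_continuous(vector, value):
    # One learned vector per continuous variable, scaled by the variable's value.
    return vector * value

def gated_skip(x, w, b, w_gate, b_gate):
    # Simplified gated residual: a sigmoid gate controls how much of the
    # transformed signal is added; a gate near 0 skips the layer entirely.
    gate = 1.0 / (1.0 + np.exp(-(x @ w_gate + b_gate)))
    return x + gate * (x @ w + b)

def select_variables(embeddings, w_select, context=None):
    # embeddings: (num_vars, HIDDEN). Score all variables jointly,
    # optionally conditioned on a static context vector.
    features = embeddings.reshape(-1)
    if context is not None:
        features = np.concatenate([features, context])
    weights = softmax(features @ w_select)  # one weight per variable, sums to 1
    return weights @ embeddings, weights    # weighted combination: (HIDDEN,)

# Usage with random stand-ins for learned parameters (all hypothetical).
rng = np.random.default_rng(0)
cat_table = rng.normal(size=(10, HIDDEN))  # 10 possible values of one categorical input
cont_vector = rng.normal(size=HIDDEN)      # learned vector for one continuous input
embeddings = np.stack([
    embed_categorical(cat_table, 3),
    embed_continuous(cont_vector, 0.42),
])
static_context = rng.normal(size=HIDDEN)
w_select = rng.normal(size=(embeddings.size + HIDDEN, len(embeddings)))
selected, weights = select_variables(embeddings, w_select, context=static_context)

# A strongly negative gate bias closes the gate, so the layer is effectively skipped.
w = rng.normal(size=(HIDDEN, HIDDEN))
w_gate = rng.normal(size=(HIDDEN, HIDDEN)) * 0.01
closed = gated_skip(selected, w, np.zeros(HIDDEN), w_gate, np.full(HIDDEN, -50.0))
```

The returned `weights` play the role of the VSN activations: one non-negative score per input variable, summing to 1, that can be inspected to judge each variable's relevance.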

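The effect of sharing value matrices across attention heads can be sketched in NumPy as below. The sketch assumes a single input sequence, causal masking (each step attends only to itself and earlier steps), and averaging of head outputs in the style of the TFT paper; the function and weight names are hypothetical. Because every head multiplies the same values `v`, the averaged attention matrix alone reproduces the layer's output, which is what makes the attention weights interpretable.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def interpretable_mha(x, w_q, w_k, w_v, w_o, causal=True):
    """x: (T, d_model); w_q, w_k: (H, d_model, d_head); w_v: (d_model, d_head), shared."""
    num_heads, _, d_head = w_q.shape
    T = x.shape[0]
    v = x @ w_v  # one value projection shared by every head
    if causal:
        mask = np.triu(np.ones((T, T), dtype=bool), k=1)  # hide future steps
    else:
        mask = np.zeros((T, T), dtype=bool)
    head_outputs, head_weights = [], []
    for h in range(num_heads):
        q = x @ w_q[h]
        k = x @ w_k[h]
        scores = np.where(mask, -np.inf, q @ k.T / np.sqrt(d_head))
        attn = softmax(scores, axis=-1)  # (T, T), rows sum to 1
        head_weights.append(attn)
        head_outputs.append(attn @ v)
    # Averaging (instead of concatenating) heads yields one inspectable attention pattern.
    avg_out = np.mean(head_outputs, axis=0)
    avg_attn = np.mean(head_weights, axis=0)
    return avg_out @ w_o, avg_attn

# Usage with random stand-ins for learned parameters (all hypothetical).
rng = np.random.default_rng(0)
T, d_model, d_head, H = 5, 8, 4, 3
x = rng.normal(size=(T, d_model))
w_q = rng.normal(size=(H, d_model, d_head))
w_k = rng.normal(size=(H, d_model, d_head))
w_v = rng.normal(size=(d_model, d_head))
w_o = rng.normal(size=(d_head, d_model))
out, attn = interpretable_mha(x, w_q, w_k, w_v, w_o)
```

Concatenating head outputs, as a standard transformer does, would mix head-specific value projections and leave no single attention matrix that explains the output.
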
Training

This model was trained using the script available on NGC and in the GitHub repository.

Dataset

The following datasets were used to train this model:

  • PEMS-SF: 15 months' worth of daily data (440 daily records) describing the occupancy rate, between 0 and 1, of different car lanes on San Francisco Bay Area freeways over time.

Performance

Performance numbers for this model are available in NGC.

License

This model was trained using open-source software available in the Deep Learning Examples repository. For terms of use, please refer to the licenses of the script and of the datasets from which the model was derived.