NVIDIA
SE(3)-Transformers for PyTorch

A Graph Neural Network that uses a variant of self-attention to process 3D points and graphs.

The following sections provide more detail on the dataset, on running training and inference, and on the training results.

Scripts and sample code

In the root directory, the most important files are:

  • Dockerfile: container with the basic set of dependencies to run SE(3)-Transformers
  • requirements.txt: set of extra requirements to run SE(3)-Transformers
  • se3_transformer/data_loading/qm9.py: QM9 data loading and preprocessing, as well as precomputation of the bases
  • se3_transformer/model/layers/: directory containing model architecture layers
  • se3_transformer/model/transformer.py: main Transformer module
  • se3_transformer/model/basis.py: logic for computing the basis matrices
  • se3_transformer/runtime/training.py: training script, to be run as a python module
  • se3_transformer/runtime/inference.py: inference script, to be run as a python module
  • se3_transformer/runtime/metrics.py: MAE metric with support for multi-GPU synchronization
  • se3_transformer/runtime/loggers.py: DLLogger and W&B loggers

Parameters

The complete list of parameters available for the training.py script is:

General

  • --epochs: Number of training epochs (default: 100 for single-GPU)
  • --batch_size: Batch size (default: 240)
  • --seed: Set a seed globally (default: None)
  • --num_workers: Number of dataloading workers (default: 8)
  • --amp: Use Automatic Mixed Precision (default: false)
  • --gradient_clip: Clipping of the gradient norms (default: None)
  • --accumulate_grad_batches: Gradient accumulation (default: 1)
  • --ckpt_interval: Save a checkpoint every N epochs (default: -1)
  • --eval_interval: Do an evaluation round every N epochs (default: 20)
  • --silent: Minimize stdout output (default: false)
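With --accumulate_grad_batches, gradients are accumulated over several batches before each optimizer step, and in multi-GPU training every process loads its own batch. Assuming --batch_size is the per-process batch size (the usual convention with PyTorch DistributedDataParallel; check training.py to confirm), the number of samples behind one optimizer step can be sketched as follows. The function name is hypothetical.

```python
def effective_batch_size(batch_size: int, accumulate_grad_batches: int = 1, world_size: int = 1) -> int:
    """Samples contributing to one optimizer step (assumes batch_size is per process)."""
    if batch_size < 1 or accumulate_grad_batches < 1 or world_size < 1:
        raise ValueError("all arguments must be positive integers")
    return batch_size * accumulate_grad_batches * world_size


# Documented single-GPU default: 240 samples per optimizer step
print(effective_batch_size(240))          # 240
# Hypothetical run: 8 GPUs, accumulating 2 batches per optimizer step
print(effective_batch_size(240, 2, 8))    # 3840
```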

Paths

  • --data_dir: Directory where the data is located or should be downloaded (default: ./data)
  • --log_dir: Directory where the result logs should be saved (default: /results)
  • --save_ckpt_path: File where the checkpoint should be saved (default: None)
  • --load_ckpt_path: File of the checkpoint to be loaded (default: None)

Optimizer

  • --optimizer: Optimizer to use (default: adam)
  • --learning_rate: Learning rate to use (default: 0.002 for single-GPU)
  • --momentum: Momentum to use (default: 0.9)
  • --weight_decay: Weight decay to use (default: 0.1)
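To illustrate the types and defaults of the optimizer flags, here is a hypothetical argparse parser that mirrors only this group. It is a sketch, not the parser defined in the repository, which may differ (for example in which --optimizer values it accepts).

```python
import argparse


def build_optimizer_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented optimizer flags and defaults
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_argument_group("Optimizer")
    group.add_argument("--optimizer", type=str, default="adam")
    group.add_argument("--learning_rate", type=float, default=0.002)
    group.add_argument("--momentum", type=float, default=0.9)
    group.add_argument("--weight_decay", type=float, default=0.1)
    return parser


# Override one flag; the others keep their documented defaults
args = build_optimizer_parser().parse_args(["--learning_rate", "0.001"])
print(args.optimizer, args.learning_rate, args.momentum, args.weight_decay)  # adam 0.001 0.9 0.1
```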

QM9 dataset

  • --task: Regression task to train on (default: homo)
  • --precompute_bases: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: false)

Model architecture

  • --num_layers: Number of stacked Transformer layers (default: 7)
  • --num_heads: Number of heads in self-attention (default: 8)
  • --channels_div: Channel division factor applied before feeding features to the attention layer (default: 2)
  • --pooling: Type of graph pooling (default: max)
  • --norm: Apply a normalization layer after each attention block (default: false)
  • --use_layer_norm: Apply layer normalization between MLP layers (default: false)
  • --low_memory: If true, use ops that are slower but use less memory (default: false)
  • --num_degrees: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: 4)
  • --num_channels: Number of channels for the hidden features (default: 32)
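In SE(3)-equivariant networks, a feature of type (degree) l has 2l + 1 components per channel: type-0 features are scalars, type-1 features are 3D vectors, and so on. Assuming every degree in [0, ..., num_degrees - 1] gets --num_channels channels, the size of a hidden representation per node works out as below; with the defaults it is 32 * (1 + 3 + 5 + 7) = 512 values. The helper name is hypothetical.

```python
from typing import Dict


def fiber_dims(num_degrees: int, num_channels: int) -> Dict[int, int]:
    """Values per node for each feature type: num_channels * (2l + 1) for type l."""
    return {degree: num_channels * (2 * degree + 1) for degree in range(num_degrees)}


dims = fiber_dims(num_degrees=4, num_channels=32)
print(dims)                 # {0: 32, 1: 96, 2: 160, 3: 224}
print(sum(dims.values()))   # 512, i.e. 32 * 4**2
```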

Command-line options

To show the full list of available options and their descriptions, use the -h or --help command-line option, for example: python -m se3_transformer.runtime.training --help.

Dataset guidelines

Demo dataset

The SE(3)-Transformer was trained on the QM9 dataset.

The QM9 dataset is hosted on DGL servers and is downloaded automatically (38 MB) when needed. By default, it is stored in the ./data directory; this location can be changed with the --data_dir argument.

The dataset is saved as a qm9_edge.npz file and converted to DGL graphs at runtime.

As input features, we use:

  • Node features (6D):
    • One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
    • Number of protons of each atom (1D)
  • Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
  • The relative positions between adjacent nodes (atoms)
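As a concrete illustration of these input features, the sketch below encodes a carbon atom and a double bond in plain Python. The ordering of the one-hot entries (following the lists above), the use of the raw proton count without normalization, and the destination-minus-source sign of the relative position are assumptions; the actual encoding lives in se3_transformer/data_loading/qm9.py.

```python
from typing import List

ATOM_TYPES = ["H", "C", "N", "O", "F"]
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}  # number of protons
BOND_TYPES = ["single", "double", "triple", "aromatic"]


def one_hot(index: int, size: int) -> List[float]:
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def node_features(atom: str) -> List[float]:
    # 5D one-hot atom type followed by the number of protons (1D) -> 6D
    return one_hot(ATOM_TYPES.index(atom), len(ATOM_TYPES)) + [float(ATOMIC_NUMBERS[atom])]


def edge_features(bond: str) -> List[float]:
    # 4D one-hot bond type
    return one_hot(BOND_TYPES.index(bond), len(BOND_TYPES))


def relative_position(pos_src: List[float], pos_dst: List[float]) -> List[float]:
    # Assumed convention: destination minus source
    return [d - s for s, d in zip(pos_src, pos_dst)]


print(node_features("C"))                                    # [0.0, 1.0, 0.0, 0.0, 0.0, 6.0]
print(edge_features("double"))                               # [0.0, 1.0, 0.0, 0.0]
print(relative_position([0.0, 0.0, 0.0], [1.0, 0.5, -0.5]))  # [1.0, 0.5, -0.5]
```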

Custom datasets

To use this network on a new dataset, you can extend the DataModule class defined in se3_transformer/data_loading/data_module.py.

Your custom collate function should return a tuple with:

  • A (batched) DGLGraph object
  • A dictionary of node features ({'{degree}': tensor})
  • A dictionary of edge features ({'{degree}': tensor})
  • (Optional) Precomputed bases as a dictionary
  • Labels as a tensor

You can then modify the training.py and inference.py scripts to use your new data module.
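In a real data module, dgl.batch merges the per-molecule graphs into one DGLGraph and the feature tensors are concatenated along the node or edge dimension. The plain-Python sketch below (Graph, Sample and collate are hypothetical names, and lists stand in for tensors) shows what that batching amounts to: each molecule's node indices are shifted by the number of nodes preceding it, so the batch becomes one disjoint graph, and the function returns the tuple described above, without the optional bases.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

Feats = Dict[str, List[List[float]]]  # {'{degree}': one row per node or edge}


@dataclass
class Graph:
    num_nodes: int
    src: List[int]
    dst: List[int]


@dataclass
class Sample:
    graph: Graph
    node_feats: Feats
    edge_feats: Feats
    label: float


def collate(samples: List[Sample]) -> Tuple[Graph, Feats, Feats, List[float]]:
    src: List[int] = []
    dst: List[int] = []
    node_feats: Feats = {}
    edge_feats: Feats = {}
    offset = 0
    for s in samples:
        # Shift node indices so each molecule occupies its own index range
        src += [i + offset for i in s.graph.src]
        dst += [i + offset for i in s.graph.dst]
        offset += s.graph.num_nodes
        # Concatenate features degree by degree
        for degree, rows in s.node_feats.items():
            node_feats.setdefault(degree, []).extend(rows)
        for degree, rows in s.edge_feats.items():
            edge_feats.setdefault(degree, []).extend(rows)
    labels = [s.label for s in samples]
    return Graph(offset, src, dst), node_feats, edge_feats, labels


a = Sample(Graph(2, [0, 1], [1, 0]), {"0": [[1.0], [2.0]]}, {"0": [[0.5], [0.5]]}, 0.1)
b = Sample(Graph(3, [0, 2], [2, 0]), {"0": [[3.0], [4.0], [5.0]]}, {"0": [[0.7], [0.7]]}, 0.2)
graph, node_feats, edge_feats, labels = collate([a, b])
print(graph)  # Graph(num_nodes=5, src=[0, 1, 2, 4], dst=[1, 0, 4, 2])
```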

Training process

The training script is se3_transformer/runtime/training.py, to be run as a module: python -m se3_transformer.runtime.training.

Logs

By default, the resulting logs are stored in /results/. This can be changed with --log_dir.

You can connect your existing Weights & Biases account by setting the WANDB_API_KEY environment variable and enabling the --wandb flag. If no API key is set, --wandb will log the run anonymously to Weights & Biases.

Checkpoints

Set --save_ckpt_path to the path of the file where checkpoints should be saved, and --ckpt_interval to the number of epochs between checkpoints.

Evaluation

The evaluation metric is the Mean Absolute Error (MAE).
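When evaluation runs on several GPUs, each process sees only its shard of the validation set. A minimal sketch of one way to synchronize the MAE, assuming the partial results would be combined with an all-reduce (simulated here by summing a list), is to accumulate the sum of absolute errors and the sample count per process and divide only at the end. Averaging per-process MAEs instead is biased when shards differ in size: in the example below it gives 1.25 rather than 0.875. The function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def local_mae_state(preds: Sequence[float], targets: Sequence[float]) -> Tuple[float, int]:
    """Per-process partial state: (sum of absolute errors, number of samples)."""
    return sum(abs(p - t) for p, t in zip(preds, targets)), len(preds)


def global_mae(states: List[Tuple[float, int]]) -> float:
    # Stand-in for an all-reduce: sum error totals and counts across processes
    total_error = sum(err for err, _ in states)
    total_count = sum(n for _, n in states)
    return total_error / total_count


rank0 = local_mae_state([1.0, 2.0, 3.0], [1.5, 2.0, 2.0])  # (1.5, 3)
rank1 = local_mae_state([0.0], [2.0])                      # (2.0, 1)
print(global_mae([rank0, rank1]))                          # 0.875
```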

--eval_interval sets the interval (in number of epochs) between evaluation rounds. By default, an evaluation round is performed every 20 epochs.
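A sketch of how an interval flag such as --eval_interval or --ckpt_interval can map to epochs, assuming the action fires when the 1-indexed epoch number is a multiple of the interval and that a non-positive value (such as the --ckpt_interval default of -1) disables it. The exact condition lives in training.py and may differ; scheduled_epochs is a hypothetical name.

```python
from typing import List


def scheduled_epochs(num_epochs: int, interval: int) -> List[int]:
    """1-indexed epochs after which a periodic action runs; interval <= 0 disables it."""
    if interval <= 0:
        return []
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % interval == 0]


print(scheduled_epochs(100, 20))  # [20, 40, 60, 80, 100]
print(scheduled_epochs(100, -1))  # []
```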

Automatic Mixed Precision

To enable Mixed Precision training, add the --amp flag.

Multi-GPU and multi-node

The training script supports the PyTorch elastic launcher (torchrun) to run on multiple GPUs or nodes. Refer to the official PyTorch documentation for torchrun.

For example, to train on all available GPUs with AMP:

python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
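With --nproc_per_node=gpu, torch.distributed.run (torchrun) starts one worker process per available GPU and exports environment variables such as RANK, LOCAL_RANK and WORLD_SIZE to each worker. The hypothetical helper below reads them, falling back to single-process values when they are not set; it illustrates the launcher contract and is not code taken from training.py.

```python
import os
from typing import Dict


def distributed_context() -> Dict[str, int]:
    # torchrun sets these for every worker; default to a single process otherwise
    return {
        "rank": int(os.environ.get("RANK", 0)),
        "local_rank": int(os.environ.get("LOCAL_RANK", 0)),
        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
    }


# Without torchrun (and with none of these variables set) this prints
# {'rank': 0, 'local_rank': 0, 'world_size': 1}
print(distributed_context())
```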

Inference process

Inference is run with the se3_transformer.runtime.inference Python module (source: se3_transformer/runtime/inference.py): python -m se3_transformer.runtime.inference. It requires a pre-trained model checkpoint, passed with --load_ckpt_path.
