JAX | NGC Catalog
Description: JAX is a framework for high-performance numerical computing and machine learning research. It includes NumPy-like APIs, automatic differentiation, XLA acceleration, and simple primitives for scaling across GPUs, and it supports an ecosystem of libraries.
Publisher: Google
Latest Tag: 23.10-paxml-py3
Modified: April 21, 2024
Compressed Size: 7.14 GB
Multinode Support: Yes
Multi-Arch Support: Yes
Security scan results for 23.10-paxml-py3 (Latest) are available for linux/amd64 and linux/arm64 on the Security Scanning tab.

JAX is a framework for high-performance numerical computing and machine learning research. It includes NumPy-like APIs, automatic differentiation, XLA acceleration, and simple primitives for scaling across GPUs.
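
As a quick illustration of these core APIs (a minimal sketch; the loss function and array shapes are arbitrary examples):

import jax
import jax.numpy as jnp

# NumPy-like array API
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# Automatic differentiation, compiled with XLA
grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones((4, 2))
x = jnp.ones((8, 4))
print(grad_fn(w, x).shape)  # (4, 2): gradient of the loss with respect to w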

The JAX NGC Container comes with all dependencies included, providing an easy place to start developing applications in areas such as NLP, computer vision, multimodal learning, physics-based simulation, reinforcement learning, drug discovery, and neural rendering.

The JAX NGC Container is optimized for GPU acceleration, and contains a validated set of libraries that enable and optimize GPU performance. This container may also include modifications to the JAX source code in order to maximize performance and compatibility. This container also includes software for accelerating ETL (DALI) and training (cuDNN, NCCL).

For building neural networks, the JAX NGC Container includes Flax, a neural network library with support for common deep learning models, layers, and optimizers. We also include a container for Paxml, a framework for training LLMs such as GPT, and a container for T5x, a framework for training T5 and other Flax-based models. You can use the JAX, Paxml, or T5x containers for your deep learning workloads or install your own favorite libraries on top of them.
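
As a minimal sketch of what building a model with Flax looks like (the MLP architecture and layer widths are arbitrary examples, not a model shipped with the container):

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    # A small multilayer perceptron with arbitrary layer widths
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)

model = MLP()
x = jnp.ones((1, 784))                         # dummy batch of flattened inputs
params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters
logits = model.apply(params, x)                # forward pass
print(logits.shape)  # (1, 10)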

Prerequisites

Using the JAX NGC Container requires the host system to have the following installed:

- Docker Engine
- NVIDIA GPU Drivers
- NVIDIA Container Toolkit

For supported versions, see the Framework Containers Support Matrix and the NVIDIA Container Toolkit Documentation.

No other installation, compilation, or dependency management is required. It is not necessary to install the NVIDIA CUDA Toolkit.

Running JAX

To run a container, issue the appropriate command as explained in the Running A Container chapter in the NVIDIA Containers For Deep Learning Frameworks User's Guide and specify the registry, repository, and tags. For more information about using NGC, refer to the NGC Container User Guide.

The following command assumes that you want to run the JAX container interactively, where 23.08 is the container version; for example, 23.08 corresponds to the August 2023 release:

docker run --gpus all -it --rm nvcr.io/nvidia/jax:23.08-py3

To pull data and model descriptions from locations outside the container for use by JAX or save results to locations outside the container, mount one or more host directories as Docker data volumes.

See /workspace/README.md for information on getting started and customizing your JAX image.

If you use multiprocessing for multi-threaded data loaders, the default shared memory segment size that the container runs with might not be enough. To increase the shared memory size, pass one of the following flags to the docker run command:

--shm-size=1g
--ulimit memlock=-1

For example: docker run --gpus all -it --rm --shm-size=1g nvcr.io/nvidia/jax:23.08-py3

Note: In order to share data between ranks, NCCL may require shared system memory for IPC and pinned (page-locked) system memory resources. The operating system's limits on these resources may need to be increased accordingly. In particular, Docker containers default to limited shared and pinned memory resources. When using NCCL inside a container, it is recommended that you increase these resources.

Running JAX in multi-node, multi-GPU mode

One of the key features of JAX is its simple scaling primitives for running JAX processes across multiple accelerators. The JAX distributed system lets JAX processes discover each other, share topology information, perform health checks, and ensure that all processes shut down if any process dies; it can also be used for distributed checkpointing.

For information on how to set up a cluster and launch JAX processes, refer to the JAX documentation on Read the Docs. For HPC cluster environments with a Slurm or OpenMPI scheduler, the jax.distributed.initialize() API automatically detects all available JAX processes.
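
As a minimal sketch of initializing the JAX distributed system by hand (the coordinator address, process count, and process ID below are hypothetical examples; under Slurm or OpenMPI they can be omitted and detected automatically):

import jax

# Run once per process, before any other JAX call.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # hypothetical coordinator host:port
    num_processes=2,                      # total number of JAX processes
    process_id=0,                         # 0 on the first host, 1 on the second
)

print(jax.device_count())        # global device count across all processes
print(jax.local_device_count())  # devices visible to this process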

What Is In This Container?

This container image contains the complete source of the NVIDIA version of JAX in /opt/jax. It is prebuilt and installed as a system Python module. Visit the JAX documentation on Read the Docs to learn more about JAX.

The NVIDIA JAX Container is optimized for use with NVIDIA GPUs, and contains the following software for GPU acceleration:

- NVIDIA Data Loading Library (DALI)
- NVIDIA CUDA Deep Neural Network Library (cuDNN)
- NVIDIA Collective Communications Library (NCCL)

The software stack in this container has been validated for compatibility, and does not require any additional installation or compilation from the end user. This container can help accelerate your deep learning workflow from end to end.

ETL

NVIDIA Data Loading Library (DALI) is designed to accelerate data loading and preprocessing pipelines for deep learning applications by offloading them to the GPU. DALI primarily focuses on building data preprocessing pipelines for image, video, and audio data. These pipelines are typically complex and include multiple stages, leading to bottlenecks when run on CPU. Use this container to get started on accelerating data loading with DALI.
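
As a minimal sketch of a DALI pipeline (the file_root directory, batch size, and image size are illustrative placeholders):

import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def image_pipeline():
    # /data/images is a hypothetical directory of class-labeled images
    jpegs, labels = fn.readers.file(file_root="/data/images", random_shuffle=True)
    images = fn.decoders.image(jpegs, device="mixed")  # decode on the GPU
    images = fn.resize(images, resize_x=224, resize_y=224)
    return images, labels

pipe = image_pipeline()
pipe.build()
images, labels = pipe.run()  # each call yields one preprocessed batch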

Training

NVIDIA CUDA Deep Neural Network Library (cuDNN) is a GPU-accelerated library of primitives for deep neural networks. cuDNN provides highly tuned implementations for standard routines such as forward and backward convolution, pooling, normalization, and activation layers. The version of JAX in this container is precompiled with cuDNN support, and does not require any additional configuration.
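
For instance, a convolution written with JAX's lax API is compiled by XLA, which on NVIDIA GPUs typically lowers it to cuDNN kernels (a minimal sketch with arbitrary shapes):

import jax
import jax.numpy as jnp

x = jnp.ones((1, 3, 32, 32))  # batch of images, NCHW layout
k = jnp.ones((16, 3, 3, 3))   # 16 output channels, OIHW layout

y = jax.lax.conv_general_dilated(x, k, window_strides=(1, 1), padding="SAME")
print(y.shape)  # (1, 16, 32, 32)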

NVIDIA Collective Communications Library (NCCL) implements multi-GPU and multi-node communication primitives for NVIDIA GPUs and Networking that take into account system and network topology. NCCL is integrated with JAX through XLA to accelerate training on multi-GPU and multi-node systems. In particular, JAX collective operations such as all-reduce are lowered to NCCL primitives on CUDA devices.
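
As a minimal sketch of a JAX collective that exercises NCCL on a multi-GPU system (psum is standard JAX; the values are arbitrary):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # one value per local GPU

# psum is an all-reduce across devices; on multi-GPU CUDA systems,
# XLA lowers it to an NCCL all-reduce.
summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")(xs)
print(summed)  # every device holds the same total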

Suggested Reading

For the latest Release Notes, see the JAX Release Notes.

For a full list of the supported software and specific versions that come packaged with this framework based on the container image, see the Frameworks Support Matrix.

For more information about JAX, including tutorials, documentation, and examples, see:

JAX on Public Clouds

Security CVEs

To review known CVEs on this image, refer to the Security Scanning tab on this page.

License

By pulling and using the container, you accept the terms and conditions of this End User License Agreement.