Geneformer generates a dense representation of a sc-RNA cell by learning co-expression patterns within single cells. Geneformer is a tabular count model trained on sc-RNA from the Chan Zuckerberg Cell x Gene census. Geneformer computes a complete embedding for each cell over the top 1024 expressed genes. The embeddings are used as features for a variety of predictive tasks. This model is ready for both commercial and academic use.
Architecture Type: Bidirectional Encoder Representations from Transformers (BERT)
Network Architecture: Geneformer
Input Type(s): Number (Row represents cell, containing gene names and single cell expression counts)
Input Format(s): Array AnnData
Input Parameters: 1D
Output Type(s): Vector (Dense Embedding Predictions)
Output Format: NumPy
Output Parameters: 1D
Other Properties Related to Output: Numeric floating point vector (fp16, bf16, or fp32); geneformer-10M-240530 outputs 256 dimensional embeddings; geneformer-106M-240530 outputs 768 dimensional embeddings
Runtime Engine(s):
Supported Hardware Microarchitecture Compatibility:
[Preferred/Supported] Operating System(s):
Single cell expression counts were downloaded directly from the CELLxGENE Census, using criteria similar to those described in the Geneformer publication. Cell data were limited to organism=“Homo sapiens”, with a non-“na” suspension_type, is_primary_data=True, and disease=“normal”. These filters restrict the download to non-diseased tissues and to the primary data source for each cell, so that each cell is included only once. We tracked metadata including “assay”, “sex”, “development_stage”, “tissue_general”, “dataset_id” and “self_reported_ethnicity”. The metadata “assay”, “tissue_general”, and “dataset_id” were used to construct dataset splits into train, validation, and test sets.
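The selection criteria above can be written as a Census query. The following is a minimal sketch, not the pipeline actually used: it assumes the `cellxgene_census` Python package, the census version is a placeholder because the source does not state one, and `download_cells` is a hypothetical helper name.

```python
# Filter expression mirroring the criteria listed in the text above.
OBS_FILTER = (
    "is_primary_data == True "
    "and disease == 'normal' "
    "and suspension_type != 'na'"
)

# Metadata columns the text says were tracked.
TRACKED_METADATA = [
    "assay",
    "sex",
    "development_stage",
    "tissue_general",
    "dataset_id",
    "self_reported_ethnicity",
]


def download_cells(census_version="stable"):
    """Hypothetical helper: fetch matching human cells as an AnnData object.

    Assumption: the census version is not given in the source, so "stable"
    is only a placeholder. A query this broad returns tens of millions of
    cells, so in practice it would need to be issued in smaller chunks.
    """
    import cellxgene_census  # third-party; imported lazily on purpose

    with cellxgene_census.open_soma(census_version=census_version) as census:
        return cellxgene_census.get_anndata(
            census,
            organism="Homo sapiens",
            obs_value_filter=OBS_FILTER,
        )
```

Keeping the filter as a single string makes it easy to log alongside the downloaded data, so the selection can be reproduced later.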
We partitioned the data by dataset_id into a train set (99% of the downloaded cells) and a hold-out set (1%), so that the hold-out datasets were independently collected single cell experiments. This helps evaluate generalizability to new datasets.
In this training split, we made sure that all “assay” and “tissue_general” labels were present in the training set so that our model would have maximal visibility into different tissues and assay biases.
The 1% hold-out evaluation set was split further into a validation and test set. This final split was mostly done randomly by cell; however, we set aside one full dataset for the test split so that, after training, we could evaluate performance on a dataset that was completely unseen, including during validation-loss monitoring.
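The dataset-level split described above can be sketched as follows. This is an illustrative approximation rather than the procedure actually used: the greedy selection, the function names, and the representation of cells as dictionaries are all assumptions. It holds out whole datasets until roughly the target fraction of cells is reached, and skips any dataset whose removal would leave an “assay” or “tissue_general” label missing from the training set.

```python
import random
from collections import defaultdict

LABEL_KEYS = ("assay", "tissue_general")


def labels_of(cells):
    """Set of (key, value) label pairs present in a collection of cells."""
    return {(key, cell[key]) for cell in cells for key in LABEL_KEYS}


def split_by_dataset(cells, holdout_fraction=0.01, seed=0):
    """Split cells into train and hold-out lists without sharing a dataset_id."""
    cells_per_dataset = defaultdict(int)
    for cell in cells:
        cells_per_dataset[cell["dataset_id"]] += 1

    all_labels = labels_of(cells)
    dataset_ids = sorted(cells_per_dataset)
    random.Random(seed).shuffle(dataset_ids)

    target = holdout_fraction * len(cells)
    holdout_ids, holdout_count = set(), 0
    for dataset_id in dataset_ids:
        if holdout_count >= target:
            break
        candidate = holdout_ids | {dataset_id}
        remaining = [c for c in cells if c["dataset_id"] not in candidate]
        # Only hold out a dataset if training still covers every label.
        if labels_of(remaining) == all_labels:
            holdout_ids = candidate
            holdout_count += cells_per_dataset[dataset_id]

    train = [c for c in cells if c["dataset_id"] not in holdout_ids]
    holdout = [c for c in cells if c["dataset_id"] in holdout_ids]
    return train, holdout
```

Because whole datasets are moved at once, the hold-out set can overshoot the target fraction by up to the size of the last dataset added.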
Link: Datasets downloaded from CZ CELLxGENE Discover - Cellular Visualization Tool (cziscience.com)
Data Collection Method by dataset
Labeling Method by dataset
Properties (Quantity, Dataset Descriptions, Sensor(s)):
23.64 million non-diseased and human-derived single cells were chosen from the CZI CELLxGENE census, which is characterized as follows:
Dataset was derived from a limited number of public sources where methods and protocols may not represent sufficiently diverse sources to capture the full scope of gene expression.
Adamson et al. 2016 PERTURB-seq dataset, accessed via Harvard Dataverse.
Link: adamson.zip - Harvard Dataverse
Data Collection Method by dataset
Labeling Method by dataset
Properties (Quantity, Dataset Descriptions, Sensor(s)): There are ~20k single cells. Half represent unperturbed control samples; the other half have an additional data table containing the CRISPR knock-out targets for each cell.
Link: CZ CELLxGENE Discover - Cellular Visualization Tool (cziscience.com)
Data Collection Method by dataset
Labeling Method by dataset
Properties (Quantity, Dataset Descriptions, Sensor(s)):
dataset_id
with any cell in the training data described previously.
Engine: BioNeMo, NeMo
Test Hardware:
The geneformer-10M-240530 checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 8 servers with 8 A100 GPUs each for a total of 115430 steps, with a per-GPU micro batch size of 32 and a global batch size of 2048. Training took a total of 1 day, 20 hours and 19 minutes of wall-clock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact, validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
The geneformer-106M-240530 checkpoint was trained for approximately 11 epochs through the CELLxGENE split. Training was performed on 16 servers with 8 A100 GPUs each for a total of 115430 steps, with a per-GPU micro batch size of 16 and a global batch size of 2048. Training took a total of 3 days, 18 hours and 55 minutes of wall-clock time. As can be seen in the following image, training and validation curves both decreased fairly smoothly throughout the course of training. In fact, validation (blue) and training (orange) loss were both still decreasing at the end of 11 epochs through the dataset. The model could likely be trained for more epochs without overfitting.
Additionally, validation loss in the 106M parameter model (red) decreased faster than in the 10M parameter model (blue), and it continued to decrease at that improved rate throughout training. It would be interesting to test even larger models to see whether the improvement continues.
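The batch-size figures quoted for the two training runs can be checked with simple arithmetic. The sketch below assumes pure data parallelism with no gradient accumulation, which the source does not state explicitly; under that assumption the global batch size is the product of servers, GPUs per server, and per-GPU micro batch size.

```python
def global_batch_size(servers, gpus_per_server, micro_batch_size):
    """Global batch size, assuming data parallelism without gradient accumulation."""
    return servers * gpus_per_server * micro_batch_size


# Run on 8 servers with a per-GPU micro batch size of 32.
batch_8_servers = global_batch_size(servers=8, gpus_per_server=8, micro_batch_size=32)

# Run on 16 servers with a per-GPU micro batch size of 16.
batch_16_servers = global_batch_size(servers=16, gpus_per_server=8, micro_batch_size=16)

# Both runs took 115430 steps at the same global batch size,
# so they processed the same number of cell samples in total.
samples_processed = 115430 * batch_8_servers
```

Both configurations give a global batch size of 2048, so the two runs differ in hardware layout but not in the number of samples seen per optimizer step.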
The following describes the BERT MLM token loss. As in the original BERT paper and the Geneformer paper, 15% of all tokens are included in the loss. Of the included tokens, 80% are replaced with the "[MASK]" token, 10% are replaced with a random gene token, and 10% are left as the correct output token. The token loss in the following table is the mean cross entropy loss of the 15% of tokens included in the loss mask, averaged across cells. As a baseline, Geneformer was downloaded from the ctheodoris/Geneformer page on Hugging Face on 2024/05/13 and applied to the same masking/unmasking problem on this dataset. The held-out test dataset from our training splits described previously was used, and it should be noted that some of these cells may have been involved in training the baseline Geneformer. Since the baseline performed slightly worse than our new checkpoints, and our goal was an equivalent or better model checkpoint, this possibility was not explored further.
Model Description | Token Loss (lower is better) |
---|---|
Baseline geneformer | 3.35 |
geneformer-10M-240530 | 2.79 |
geneformer-106M-240530 | 2.50 |
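The masking scheme behind these token losses can be sketched in NumPy. This is an illustration of the 15% selection and the 80/10/10 replacement rule described above, not the BioNeMo implementation. The token ID layout is an assumption: ID 0 stands in for "[MASK]" and gene tokens are assumed to occupy IDs 1 to `vocab_size - 1`.

```python
import numpy as np

MASK_TOKEN_ID = 0  # assumed ID for "[MASK]"; the real vocabulary may differ


def mask_tokens(token_ids, vocab_size, rng, loss_prob=0.15):
    """Return (model_inputs, in_loss) following the BERT-style masking rule."""
    token_ids = np.asarray(token_ids)

    # Select about 15% of positions to contribute to the loss.
    in_loss = rng.random(token_ids.shape) < loss_prob

    # A second draw decides how each selected position is altered.
    roll = rng.random(token_ids.shape)
    use_mask = in_loss & (roll < 0.8)                    # 80%: "[MASK]"
    use_random = in_loss & (roll >= 0.8) & (roll < 0.9)  # 10%: random gene
    # The remaining 10% of selected positions keep the correct token.

    model_inputs = token_ids.copy()
    model_inputs[use_mask] = MASK_TOKEN_ID
    model_inputs[use_random] = rng.integers(
        1, vocab_size, size=int(use_random.sum())
    )
    return model_inputs, in_loss


rng = np.random.default_rng(0)
example_tokens = rng.integers(1, 1000, size=(64, 1024))  # 64 cells, 1024 genes
example_inputs, example_in_loss = mask_tokens(example_tokens, 1000, rng)
```

The original `token_ids` serve as the labels, and the cross entropy is evaluated only at positions where `in_loss` is true.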
The 106M parameter variant of Geneformer achieves over 50 TFLOPS per GPU during training. This is consistent whether trained with 1 or 8 A100s.
Geneformer is available under the Apache 2.0 license.
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns here.