Predicting cellular responses to perturbation across diverse contexts with STATE
TL;DR Summary
STATE, a transformer model trained on over 100 million cells, predicts perturbation responses across diverse cellular contexts with 30% improved effect discrimination and accurate identification of differentially expressed genes, enabling generalization to unseen backgrounds and advancing scalable virtual cell modeling.
Abstract
Cellular responses to perturbations are a cornerstone for understanding biological mechanisms and selecting drug targets. While machine learning models offer tremendous potential for predicting perturbation effects, they currently struggle to generalize to unobserved cellular contexts. Here, we introduce State, a transformer model that predicts perturbation effects while accounting for cellular heterogeneity within and across experiments. State predicts perturbation effects across sets of cells and is trained using gene expression data from over 100 million perturbed cells. State improved discrimination of effects on large datasets by more than 30% and identified differentially expressed genes across genetic, signaling and chemical perturbations with significantly improved accuracy. Using its cell embedding trained on observational data from 167 million cells, State identified strong perturbations in novel cellular contexts where no perturbations were observed during training. Overall, the performance and flexibility of State pave the way for scalable virtual cell models.
In-depth Reading
English Analysis
1. Bibliographic Information
1.1. Title
Predicting cellular responses to perturbation across diverse contexts with STATE
1.2. Authors
The authors are listed as "Anonymous Author(s)" in the provided paper draft, which is common for submissions to conferences with a double-blind review process.
1.3. Journal/Conference
The paper is "Submitted to 39th Conference on Neural Information Processing Systems (NeurIPS 2025)." NeurIPS is one of the most prestigious and influential conferences in the field of artificial intelligence and machine learning, indicating a high-profile venue for research dissemination.
1.4. Publication Year
The paper was submitted to NeurIPS 2025, suggesting a target publication year of 2025. However, it is currently a preprint as indicated by its submission status and the anonymous authors.
1.5. Abstract
Cellular responses to perturbations are critical for understanding biological mechanisms and identifying drug targets. Existing machine learning models for predicting these responses struggle to generalize across diverse cellular contexts. This paper introduces STATE, a transformer-based model designed to predict perturbation effects while explicitly accounting for cellular heterogeneity both within and across experiments. STATE operates by predicting perturbation effects across sets of cells and was trained using gene expression data from over 100 million perturbed cells. The model demonstrated significant improvements, increasing effect discrimination by more than 30% on large datasets and achieving substantially higher accuracy in identifying differentially expressed genes across genetic, signaling, and chemical perturbations. Leveraging cell embeddings trained on a vast observational dataset of 167 million cells, STATE successfully predicts strong perturbations in novel cellular contexts that were not observed during its training. The authors suggest that STATE's performance and flexibility pave the way for the development of scalable virtual cell models for various biological applications.
1.6. Original Source Link
The original source link provided is /files/papers/690878ad1ccaadf40a4344dd/paper.pdf. Given the context of "Submitted to 39th Conference on Neural Information Processing Systems (NeurIPS 2025)" and anonymous authors, this link points to a preprint version of the paper.
2. Executive Summary
2.1. Background & Motivation
The core problem STATE aims to solve is the accurate prediction of cellular responses to various perturbations (e.g., genetic, chemical, signaling) across diverse and often unseen cellular contexts. Understanding these responses is fundamental for basic biological research, disease modeling, and crucial for rational drug target selection in pharmaceutical development.
Despite the rapid growth in the size and scope of perturbation datasets, current machine learning models face significant challenges:
- Generalization to Unobserved Contexts: Many models struggle to accurately predict how cells will respond in cellular environments or to perturbations not encountered during training.
- Destructive Nature of Single-Cell Assays: Single-cell RNA sequencing (scRNA-seq) measurements destroy the cell, preventing observation of its pre-perturbation state. This makes it difficult to infer individual cell responses directly.
- Cellular Heterogeneity: Cells within a population, even under identical conditions, exhibit considerable biological variation. Existing models often oversimplify this within-population heterogeneity or fail to distinguish it from true perturbation signals, especially when perturbation effects are subtle.
- Data Scale Limitations: While data is growing, many models remain in a data-poor regime for effectively learning complex perturbation effects that generalize.

The paper's entry point is to explicitly account for cellular heterogeneity and cross-dataset variability using a novel transformer-based architecture and a vast amount of single-cell data. It proposes to move beyond single-cell prediction towards modeling distributions of cell responses.
2.2. Main Contributions / Findings
The paper makes several significant contributions:
- Introduction of STATE: A novel multi-scale, transformer-based machine learning architecture designed for predicting cellular responses to perturbations. STATE consists of two complementary modules:
  - State Transition model (ST): A transformer that models perturbation-induced transformations across sets of cells, explicitly capturing biological heterogeneity and avoiding reliance on explicit distributional assumptions.
  - State Embedding model (SE): A pre-trained encoder-decoder transformer that generates robust cell embeddings by learning gene expression variation across diverse observational datasets, optimized for detecting perturbation effects and reducing technical variation.
- Scalable Training with Large Datasets: STATE leverages an unprecedented scale of data, utilizing 167 million cells of observational data to train SE and over 100 million perturbed cells to train ST.
- Superior Performance and Generalization:
  - STATE achieved absolute improvements of 54% and 29% in perturbation discrimination on the Tahoe-100M and Parse-PBMC datasets, respectively.
  - It significantly improved differential gene expression (DEG) identification, being twice as good as the next best baseline on Tahoe-100M and 43% better on Parse-PBMC.
  - The model accurately ranked perturbations by their relative effect sizes, showing Spearman correlations 53% higher on Parse-PBMC, 22% higher on Replogle-Nadig, and 70% higher on Tahoe-100M compared to baselines.
  - STATE demonstrated strong context generalization and zero-shot prediction capabilities, successfully predicting strong perturbations in novel cellular contexts not seen during training.
- Leveraging Data Scale: The performance gains of STATE were most pronounced with increasing data scale, suggesting its architecture is better equipped to utilize large datasets than existing models.
- Theoretical Foundations for Optimal Transport: The paper provides theoretical analysis suggesting that ST's solution family asymptotically covers the optimal transport map between control and perturbed cell distributions, hinting at its capacity to learn fundamental cellular transformations.

These findings position STATE as a crucial step towards developing scalable virtual cell models, which could accelerate causal target discovery and deepen the understanding of cellular function and disease.
3. Prerequisite Knowledge & Related Work
3.1. Foundational Concepts
To understand the STATE model and its significance, a foundational understanding of several biological and machine learning concepts is essential.
- Cellular Perturbations: In biology, a perturbation refers to any intentional modification or intervention applied to a biological system (such as a cell) to observe its response. This can include:
  - Genetic perturbations: Manipulating gene expression (e.g., using CRISPR to knock out or activate specific genes).
  - Chemical perturbations: Treating cells with drugs or small molecules.
  - Signaling perturbations: Modulating cellular signaling pathways (e.g., with cytokines).
  Studying these responses is crucial for understanding biological mechanisms and identifying drug targets.
- Gene Expression Data / Single-Cell RNA Sequencing (scRNA-seq): Gene expression is the process by which information from a gene is used in the synthesis of a functional gene product, such as a protein. Cells vary widely in which genes they express and at what levels.
  - scRNA-seq is a high-throughput technology that measures the gene expression profile of individual cells. This generates data where each cell is represented by a vector of gene counts or expression levels (e.g., the log-normalized counts used in the paper, which apply a logarithmic transformation to stabilize variance and make distributions more symmetric).
  - A key challenge with scRNA-seq is its destructive nature: measuring a cell's transcriptome destroys it, meaning we cannot observe the same cell's state before and after a perturbation.
- Machine Learning Models for Perturbation Prediction: The goal is to develop models that can predict how the gene expression profile of a cell (or a population of cells) will change in response to a given perturbation. This involves learning complex relationships between basal cell states, perturbation types, and the resulting transcriptomic changes.
Transformers and Self-Attention:
Transformersare a neural network architecture, originally developed for natural language processing, that have become highly influential due to theirself-attention mechanism.Self-attention(orintra-attention) allows the model to weigh the importance of different parts of an input sequence (or set) relative to each other when processing each part. For a set of cell representations, self-attention means that when processing one cell's data, the model can consider and weigh the information from all other cells in the set. This is particularly powerful for capturing relationships and heterogeneity within aset of cells. The generalAttentionformula is: $ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $ where:- (Query), (Key), (Value) are matrices derived from the input embeddings.
- calculates the
attention scores(how much each element relates to others). - is a scaling factor, typically the square root of the dimension of the keys, to prevent very large attention scores.
softmaxnormalizes these scores into a probability distribution.- The result is a weighted sum of the Value vectors, allowing the model to focus on relevant information.
-
Cell Embeddings: An
embeddingis a lower-dimensional, dense vector representation of an entity (in this case, a cell) that captures its essential features.Cell embeddingsaim to represent a cell'stranscriptomic state(its unique gene expression profile) in a way that is robust to technical noise and highlights biological similarities or differences. These embeddings can then be used as input for downstream tasks or as a more abstract space to model cellular processes. -
Differential Expression Analysis (DEG): This statistical analysis identifies genes whose expression levels change significantly between two or more biological conditions (e.g., perturbed vs. unperturbed cells).
Differentially Expressed Genesare crucial indicators of a cell's response to a perturbation. -
Maximum Mean Discrepancy (MMD):
MMDis a statistical test used to determine if two samples are drawn from the same distribution. It measures the distance between themean embeddingsof two distributions in aReproducing Kernel Hilbert Space (RKHS). Intuitively, if the average function value (embedding) for samples from distribution A is very different from the average function value for samples from distribution B, then the distributions are likely different. The paper uses theenergy distance kernelwhich makesMMDequivalent to theenergy distance. A key property of theenergy distanceis that it is zero if and only if the two distributions are identical. -
Optimal Transport (OT):
Optimal Transportis a mathematical framework for comparing and transforming probability distributions. It seeks the most efficient way to "move" mass from one distribution to another, minimizing a given "cost" (e.g., Euclidean distance between points). In the context of cells,Optimal Transportcould theoretically map a distribution of unperturbed cell states to a distribution of perturbed cell states.
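To make the self-attention concept concrete, the following is a minimal numerical sketch (not the paper's implementation) of scaled dot-product attention applied to a small set of cells, where each row of the input is one cell's embedding; all dimensions and random weights are illustrative.

```python
# Minimal illustration of scaled dot-product self-attention over a cell set.
# All sizes and weights are arbitrary; this is not the STATE implementation.
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """X: (n_cells, d_model); Wq, Wk, Wv: (d_model, d_k) projection matrices."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                 # queries, keys, values per cell
    scores = Q @ K.T / np.sqrt(K.shape[-1])          # (n_cells, n_cells) attention scores
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the cell set
    return weights @ V                               # each cell attends to every cell

rng = np.random.default_rng(0)
n_cells, d_model = 8, 16
X = rng.normal(size=(n_cells, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_model, d_model)) for _ in range(3))
print(self_attention(X, Wq, Wk, Wv).shape)           # (8, 16): one output row per cell
```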
3.2. Previous Works
The paper discusses several categories of prior approaches and specific models for perturbation prediction:
- Mapping Perturbed to Unperturbed Cells: Some earlier methods (e.g., [10-12]) assumed within-population heterogeneity was negligible and tried to map perturbed cells to randomly selected unperturbed cells with shared covariates (such as cell type).
  - Limitation: These models often fail when perturbation effects are subtle, as unperturbed population heterogeneity can be larger than the perturbation signal itself, making accurate mapping difficult.
- Distribution-Based Models: Other models treat cell populations as distributions, aiming to learn data-generating distributions or to disentangle labeled and unlabeled sources of variation (e.g., [8, 19-26]).
  - Limitation: In practice, the paper notes that these models often do not significantly outperform methods that don't explicitly model distributional structure [14]. Specific examples mentioned:
    - scVI [25]: A deep generative model that aims to model gene expression distributions while accounting for technical noise and batch effects.
    - CPA (Compositional Perturbation Autoencoder) [8, 47]: An autoencoder-based model that learns a compositional latent space to capture additive effects of perturbation, dosage, and cell type.
- Optimal Transport-based Methods: These methods attempt to map unperturbed populations to perturbed populations (e.g., [9, 27-29]).
  - Limitation: Their applicability has been limited by strong assumptions and poor scalability. For instance, CellOT [9] uses Input Convex Neural Networks (ICNNs) to parameterize convex potentials for learning optimal transport maps.
- Foundation Models for Single-Cell Omics: Recent advancements include large-scale transformer-based models trained on vast amounts of single-cell data, aiming for broad generalization:
  - scGPT [11]: A transformer-based foundation model that leverages generative pretraining on over 33 million cells for tasks such as perturbation prediction.
  - GEARS [10]: Another model for predicting transcriptional outcomes of novel multi-gene perturbations.
3.3. Technological Evolution
The field of perturbation response prediction has evolved alongside advancements in single-cell technologies and deep learning.
- Early Statistical Models: Initially, methods relied on simpler statistical models or linear regressions to identify differentially expressed genes.
- Increased Data Scale: The advent of high-throughput screening technologies such as pooled CRISPR perturbations (e.g., Perturb-seq [1-6]) has dramatically increased the scale and complexity of available perturbation datasets.
- Autoencoder-based Models: As deep learning gained traction, autoencoders (such as CPA and scVI) were adapted to learn latent representations of cell states, aiming to capture underlying biological variation and predict perturbation effects in a lower-dimensional space.
- Foundation Models and Transformers: The most recent wave involves transformer architectures and foundation models (scGPT, GEARS). These models are trained on massive datasets (millions to hundreds of millions of cells) to learn general-purpose cell embeddings and biological rules, with the hope of achieving zero-shot generalization to unseen contexts.

This paper's work, STATE, represents the cutting edge of this evolution by:

- Utilizing a transformer architecture for its self-attention capabilities to model cellular heterogeneity at the population level.
- Integrating pre-trained cell embeddings (SE) from an even larger observational dataset (167 million cells) to enhance robustness and transferability.
- Employing Maximum Mean Discrepancy (MMD) to directly compare distributions of predicted and observed cell states, moving beyond single-cell reconstruction errors.
- Providing a theoretical link to Optimal Transport, suggesting a principled way to learn cellular transformations.
3.4. Differentiation Analysis
STATE differentiates itself from previous approaches primarily through its multi-scale architecture, explicit handling of heterogeneity, and massive data leverage:
- Explicitly Modeling Heterogeneity (vs. Neglecting or Averaging):
  - Unlike methods that assume within-population heterogeneity is negligible or simply map to average unperturbed cells (Perturbation Mean Baseline, Context Mean Baseline), ST uses self-attention over sets of cells. This allows it to explicitly model and learn residual, unannotated heterogeneity within cell populations, which is crucial for subtle perturbation effects.
  - Compared to pseudobulk models (which average responses), ST retains more information about cell-to-cell variability.
- Robust Cell Embeddings for Transfer (vs. Context-Specific Representations):
  - Many models operate directly on gene expression profiles or learn embeddings specific to a dataset or context. STATE introduces SE, a pre-trained cell embedding model trained on 167 million observational cells, leveraging protein language models for robust gene representations.
  - These SE embeddings are designed to unify cell representation across diverse datasets, making STATE's predictions more robust to technical variation and significantly improving cross-dataset transfer learning and zero-shot generalization compared to models trained on highly variable genes (HVGs) or other foundation models such as scGPT.
- Distributional Learning with MMD (vs. Pointwise Prediction or Autoencoder Reconstruction):
  - Instead of predicting individual cell responses or relying solely on autoencoder reconstruction losses (CPA, scVI), ST is trained using Maximum Mean Discrepancy (MMD). MMD directly minimizes the distance between the distributions of predicted and observed cell sets. This distributional perspective is more biologically relevant for population-level responses and handles the destructive nature of scRNA-seq more naturally.
- Scalability and Data Leverage:
  - STATE is explicitly designed to leverage giga-scale datasets (100M perturbed cells, 167M observational cells). The paper demonstrates that STATE's performance benefits significantly from increased data scale, outperforming other models that appear to be in a data-poor regime. This suggests STATE's architecture is inherently more capable of extracting insights from massive biological datasets.

In essence, STATE's innovations lie in its integrated approach to handling complex cellular heterogeneity, generating universally applicable cell embeddings, learning distributional transformations, and scaling efficiently to giga-cell datasets.
4. Methodology
4.1. Principles
The core idea behind STATE is to develop a multi-scale machine learning architecture that can predict how cells respond to perturbations. It achieves this by explicitly accounting for cellular heterogeneity—the natural variations among cells—both within a single experiment and across different experiments. The model is built on two complementary modules:
- State Transition model (ST): This module is responsible for learning the actual perturbation effects. Instead of trying to predict the response of a single cell (which is problematic given that scRNA-seq assays destroy cells), ST operates on sets of cells. It uses a transformer with self-attention to model how a distribution of control cells transforms into a distribution of perturbed cells. This set-based approach allows it to capture and predict changes in cellular heterogeneity beyond what can be explained by simple covariates such as cell type.
- State Embedding model (SE): This module is designed to create robust and generalizable cell embeddings. SE is pre-trained on a massive amount of observational single-cell data (data from cells that were not intentionally perturbed) to learn representations that are less sensitive to technical noise and better capture underlying biological signals. These cell embeddings then serve as high-quality inputs for the ST module, enabling STATE to transfer knowledge across different datasets and experimental conditions more effectively.

The multi-scale aspect refers to STATE's ability to operate on gene-level features, individual cell representations (via SE embeddings), and cell population transformations (via ST on sets of cells). By combining these, STATE aims to create a virtual cell model that is flexible, scalable, and capable of generalizing to novel contexts.
4.2. Core Methodology In-depth (Layer by Layer)
The STATE model is composed of the State Transition (ST) model and the State Embedding (SE) model. We will first describe the conceptual Data Generation Process and then detail each module.
4.2.1. Data Generation Process
The paper begins by framing the problem with a conceptual data generation process.
The observed log-normalized perturbed expression state of a cell, denoted $x^{\mathrm{pert}}$, is treated as a random variable generated from an unobservable unperturbed state $x^{\mathrm{basal}}$. The unperturbed state $x^{\mathrm{basal}}$ is itself a random variable representing the cell's underlying expression state, drawn from a basal cell distribution $\mathcal{B}$ specific to a given set of covariates (e.g., cell line, batch condition).

The perturbation effect is modeled as:

$$x^{\mathrm{pert}} = \Delta\big(x^{\mathrm{basal}}, p\big) + \varepsilon$$

where:

- $\Delta(x^{\mathrm{basal}}, p)$: represents the true effect caused by perturbation $p$ on the unperturbed cell state $x^{\mathrm{basal}}$.
- $\varepsilon$: denotes experiment-specific technical noise, assumed to be independent of $x^{\mathrm{basal}}$.

Since single-cell transcriptomic measurements destroy the cell, $x^{\mathrm{basal}}$ is unobservable, making direct modeling of this equation infeasible. STATE therefore shifts to a distributional view, operating on the observable basal cell distribution $\mathcal{B}$ to predict the perturbed state.

This forms the basis of STATE's model:

$$x^{\mathrm{pert}} \approx \tilde{\Delta}\big(\mathcal{B}, p\big) + \eta + \varepsilon$$

In this approximation:

- $\tilde{\Delta}(\mathcal{B}, p)$: the true effect of the perturbation, now considered in the context of the entire basal population $\mathcal{B}$.
- $\eta$: explicitly represents the biological heterogeneity inherent in the baseline population. This heterogeneity was implicitly absorbed when sampling $x^{\mathrm{basal}}$ in the first equation, but is made explicit here to reflect the distributional view.
- $\tilde{\Delta}$: a distributional analogue of $\Delta$, allowing the model to predict the perturbed state based on observable population characteristics rather than unobservable individual cell states.
4.2.2. State Transition Model (ST)
ST is the core model for learning perturbation effects across populations of cells. It uses a transformer architecture with self-attention across cells to predict perturbation effects between control and perturbed cell distributions.
A dataset of single-cell RNA-sequencing measurements is defined as:

$$\mathcal{D} = \big\{ (x_i, p_i, \ell_i, b_i) \big\}_{i=1}^{N}$$

Here:

- $x_i \in \mathbb{R}^{G}$: the log-normalized expression vector for cell $i$, where $G$ is the number of genes.
- $p_i$: the perturbation label for cell $i$, which can be a specific perturbation ID or a control state ($p_i = \mathrm{ctrl}$).
- $\ell_i$: the biological context or cell line label for cell $i$.
- $b_i$: an optional batch effect label for cell $i$.
4.2.2.1. Formation of Cell Sets
Cells are grouped based on their biological context ($\ell$), perturbation ($p$), and batch ($b$) labels:

$$\mathcal{G}_{\ell, p, b} = \{\, i : \ell_i = \ell,\; p_i = p,\; b_i = b \,\}$$

From these groups, fixed-size cell sets of size $S$ are formed. If a group contains $M$ cells, it is chunked into cell sets $X^{(k)}$, where $k$ indexes the sets. If $M$ is not divisible by $S$, the remaining cells form a smaller set, which is then padded to size $S$ by sampling additional cells with replacement from itself.
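As a rough illustration of this chunking and padding step, the sketch below (assumed logic, not the released implementation) splits one (context, perturbation, batch) group of cell indices into fixed-size sets of size S and pads the final remainder set by resampling from itself with replacement.

```python
# Minimal sketch: form fixed-size cell sets from one group of cell indices.
import numpy as np

def make_cell_sets(indices, set_size, rng):
    """indices: 1-D array of cell indices sharing (context, perturbation, batch)."""
    indices = rng.permutation(indices)
    sets = [indices[i:i + set_size] for i in range(0, len(indices), set_size)]
    if len(sets[-1]) < set_size:                       # pad the remainder set
        pad = rng.choice(sets[-1], size=set_size - len(sets[-1]), replace=True)
        sets[-1] = np.concatenate([sets[-1], pad])
    return np.stack(sets)                              # (n_sets, set_size)

rng = np.random.default_rng(0)
group = np.arange(70)                                  # a group with 70 cells
cell_sets = make_cell_sets(group, set_size=32, rng=rng)
print(cell_sets.shape)                                 # (3, 32); last set padded by resampling
```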
4.2.2.2. Training on Cell Sets
During training, each perturbed cell set $X^{(k)}_{\ell, p, b}$ (where $p \neq \mathrm{ctrl}$) is paired with a corresponding control cell set. The control set is generated by a mapping function that samples control cells from the same cell line $\ell$ and batch $b$.

This mapping function is crucial, as it controls which sources of variation are explicitly accounted for. By conditioning on specific covariates, it reduces known sources of heterogeneity.

For a mini-batch of $B$ such set pairs, the following tensors are constructed:

- $X^{\mathrm{pert}} \in \mathbb{R}^{B \times S \times G}$ (observed perturbed cell sets).
- $X^{\mathrm{ctrl}} \in \mathbb{R}^{B \times S \times G}$ (corresponding control cell sets).
- $P \in \mathbb{R}^{B \times S \times d_{\mathrm{pert}}}$ (perturbation embeddings).
- $C \in \mathbb{R}^{B \times S \times d_{\mathrm{batch}}}$ (optional batch covariates).

Here, $d_{\mathrm{pert}}$ and $d_{\mathrm{batch}}$ are the dimensionalities of the perturbation and batch embeddings, respectively. For STATE, one-hot encodings are used, so $d_{\mathrm{pert}}$ is the number of unique perturbations and $d_{\mathrm{batch}}$ is the number of unique batch labels.

The ST model takes $X^{\mathrm{ctrl}}$ (control cell sets) and $P$ (perturbation embeddings) as input and learns to predict $X^{\mathrm{pert}}$ (the corresponding perturbed states). This means ST learns to transform control cell populations into their predicted perturbed states.
4.2.2.3. Neural Network Modules
ST uses specialized encoders to map cellular expression profiles, perturbation labels, and batch labels into a shared hidden dimension $h$, which serves as the input to the transformer.

- Control Cell Encoder: Each log-normalized expression vector $x^{\mathrm{ctrl}}_i$ is mapped to an embedding via a 4-layer MLP (Multi-Layer Perceptron) with GELU (Gaussian Error Linear Unit) activations:
  $$h^{\mathrm{ctrl}}_i = f_{\mathrm{cell}}\big(x^{\mathrm{ctrl}}_i\big)$$
  This MLP, $f_{\mathrm{cell}}$, is applied to each cell independently across the entire control tensor, transforming the input shape from $(B, S, G)$ to $(B, S, h)$.
- Perturbation Encoder: Perturbation labels are encoded into the same embedding dimension $h$. For one-hot encoded perturbations, the input vector is passed through a 4-layer MLP with GELU activations:
  $$h^{\mathrm{pert}} = f_{\mathrm{pert}}(p)$$
  This transforms the input shape from $(B, S, d_{\mathrm{pert}})$ to $(B, S, h)$. Note that the perturbation embedding is the same for all cells within the same set of a given batch. If perturbations are represented by continuous features (e.g., molecular descriptors), these feature vectors are used directly as the perturbation input and $d_{\mathrm{pert}}$ is set to their dimensionality.
- Batch Encoder: To account for technical batch effects, batch labels are encoded into embeddings of dimension $h$:
  $$h^{\mathrm{batch}} = f_{\mathrm{batch}}(b)$$
  Here, $f_{\mathrm{batch}}$ is an embedding layer, transforming the input shape from $(B, S, d_{\mathrm{batch}})$ to $(B, S, h)$.
- Transformer Inputs and Outputs: The final input to ST is constructed by summing the control cell embeddings with the perturbation and batch embeddings:
  $$Z = H^{\mathrm{ctrl}} + H^{\mathrm{pert}} + H^{\mathrm{batch}}$$
  This composite representation is passed to the transformer backbone $f_{\mathrm{ST}}$ to model perturbation effects across the cell set. The output is computed as:
  $$\hat{H} = Z + f_{\mathrm{ST}}(Z)$$
  where $\hat{H}$ represents the final output. This residual connection encourages the transformer to learn perturbation effects as differences, or residuals, relative to the input representation $Z$.
- Gene Reconstruction Head: When ST operates directly in gene expression space, a gene reconstruction head maps the transformer's output back to gene expression space. This is achieved using a linear projection layer, applied independently to the $h$-dimensional hidden representation of each cell. For cell $i$ in a given set, where $\hat{h}_i$ is the transformer output, the predicted gene expression is given by:
  $$\hat{x}_i = W_{\mathrm{recon}}\, \hat{h}_i + b_{\mathrm{recon}}$$
  Here, $W_{\mathrm{recon}}$ and $b_{\mathrm{recon}}$ are learnable parameters. This operation transforms the hidden representations, yielding reconstructed log-transformed gene expression values for each cell in the batch.
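The sketch below illustrates the shape flow of a single ST-style forward pass under the assumptions above; the MLPs are shallower than the 4-layer encoders described in the text, a generic PyTorch transformer encoder layer stands in for the LLaMA/GPT2 backbone, and all sizes and data are placeholders.

```python
# Minimal sketch of an ST-style forward pass (hypothetical sizes, simplified modules).
import torch
import torch.nn as nn

G, h, d_pert, n_batches, S = 2000, 128, 50, 4, 32      # illustrative dimensions

f_cell = nn.Sequential(nn.Linear(G, h), nn.GELU(), nn.Linear(h, h), nn.GELU())
f_pert = nn.Sequential(nn.Linear(d_pert, h), nn.GELU(), nn.Linear(h, h), nn.GELU())
f_batch = nn.Embedding(n_batches, h)
f_st = nn.TransformerEncoderLayer(d_model=h, nhead=8, batch_first=True)
recon = nn.Linear(h, G)                                 # gene reconstruction head

x_ctrl = torch.randn(1, S, G)                           # one control cell set
p_onehot = torch.zeros(1, S, d_pert)
p_onehot[..., 7] = 1.0                                  # same perturbation for every cell in the set
b_idx = torch.zeros(1, S, dtype=torch.long)             # batch label per cell

z = f_cell(x_ctrl) + f_pert(p_onehot) + f_batch(b_idx)  # summed input representation
z_out = z + f_st(z)                                     # residual: effects learned as shifts
x_pred = recon(z_out)                                   # predicted perturbed expression
print(x_pred.shape)                                     # torch.Size([1, 32, 2000])
```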
4.2.2.4. Learning Perturbation Effects with Maximum Mean Discrepancy
ST is trained to minimize the discrepancy between predicted ($\hat{X}^{\mathrm{pert}}$) and observed ($X^{\mathrm{pert}}$) transcriptomic responses. This discrepancy is quantified using the Maximum Mean Discrepancy (MMD) [57].

For each mini-batch element $m$, ST considers the set of predicted cell expression vectors $\{\hat{x}_i\}_{i=1}^{S}$ and observed cell expression vectors $\{x_j\}_{j=1}^{S}$. The squared MMD between these two sets is computed as:

$$\widehat{\mathrm{MMD}}^2 = \frac{1}{S^2}\sum_{i=1}^{S}\sum_{i'=1}^{S} k\big(\hat{x}_i, \hat{x}_{i'}\big) + \frac{1}{S^2}\sum_{j=1}^{S}\sum_{j'=1}^{S} k\big(x_j, x_{j'}\big) - \frac{2}{S^2}\sum_{i=1}^{S}\sum_{j=1}^{S} k\big(\hat{x}_i, x_j\big)$$

where:

- $\{\hat{x}_i\}$: predicted cell expression vectors from $\hat{X}^{\mathrm{pert}}_m$.
- $\{x_j\}$: observed cell expression vectors from $X^{\mathrm{pert}}_m$.
- $k(\cdot, \cdot)$: the kernel function. The three terms represent (1) similarity within the predicted set, (2) similarity within the observed set, and (3) cross-similarity between predicted and observed sets.

The paper uses the energy distance kernel:

$$k(x, y) = -\lVert x - y \rVert_2$$

This kernel is implemented via the geomloss library [59].

For a training mini-batch of $B$ cell sets, the batch-averaged MMD loss is:

$$\mathcal{L}_{\mathrm{MMD}} = \frac{1}{B}\sum_{m=1}^{B} \widehat{\mathrm{MMD}}^2\big(\hat{X}^{\mathrm{pert}}_m, X^{\mathrm{pert}}_m\big)$$

Minimizing this loss encourages the model to generate sets of perturbed cell expression vectors whose overall statistical properties match those of the observed cell sets, as captured by the MMD.

The total loss for ST is this batch-averaged MMD loss, $\mathcal{L}_{\mathrm{ST}} = \mathcal{L}_{\mathrm{MMD}}$.
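The following sketch shows one plausible way to compute the squared MMD with the energy-distance kernel for a single predicted/observed cell-set pair; the paper uses the geomloss library, so this is an assumed, simplified re-implementation for illustration only.

```python
# Minimal sketch: squared MMD with the energy-distance kernel k(x, y) = -||x - y||_2.
import torch

def energy_mmd2(pred, obs):
    """pred, obs: (S, G) sets of expression vectors; returns a scalar MMD^2."""
    k_pp = -torch.cdist(pred, pred).mean()     # mean kernel within the predicted set
    k_oo = -torch.cdist(obs, obs).mean()       # mean kernel within the observed set
    k_po = -torch.cdist(pred, obs).mean()      # mean kernel across the two sets
    return k_pp + k_oo - 2.0 * k_po

pred = torch.randn(64, 200, requires_grad=True)
obs = torch.randn(64, 200) + 0.5               # observed set with a shifted mean
loss = energy_mmd2(pred, obs)
loss.backward()                                 # differentiable, usable as a training loss
print(float(loss))
```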
4.2.2.5. Training ST in Embedding Spaces
ST can be trained either directly in gene expression space (using the top 2,000 highly variable genes (HVGs)) or in a specified embedding space.

When trained in an embedding space:

- The dimensionality of the embedding space is $E$, where typically $E \ll G$.
- The input tensors $X^{\mathrm{ctrl}}$ and $X^{\mathrm{pert}}$ become embedding tensors in $\mathbb{R}^{B \times S \times E}$.
- The control cell encoder is modified to transform from $(B, S, E)$ to $(B, S, h)$.
- The gene reconstruction head is modified to transform from $(B, S, h)$ to $(B, S, E)$, producing predicted perturbed embeddings $\hat{Z}^{\mathrm{pert}}$.
- All other encoders and the transformer remain unchanged.

To recover the original gene expression from the embedding space, an additional decoder head $f_{\mathrm{decode}}$ is trained. This is a multi-layer MLP with dropout that maps from the embedding space back to the full gene expression space:

$$\hat{X}^{\mathrm{expr}} = f_{\mathrm{decode}}\big(\hat{Z}^{\mathrm{pert}}\big)$$

This transforms the predicted embeddings from $(B, S, E)$ to $(B, S, G)$.

The total loss when training in embedding space is a weighted combination of MMD losses in both the embedding and gene expression spaces:

$$\mathcal{L} = \mathcal{L}_{\mathrm{MMD}}^{\mathrm{emb}} + 0.1\, \mathcal{L}_{\mathrm{MMD}}^{\mathrm{expr}}$$

The expression loss is down-weighted by a factor of 0.1 to balance the terms and prioritize learning perturbation effects in the embedding space.
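A compact sketch of this weighted loss is shown below, assuming a hypothetical decoder head and the 0.1 down-weighting described above; tensor sizes are illustrative and the MMD helper mirrors the earlier energy-distance sketch.

```python
# Minimal sketch: combining embedding-space and (decoded) gene-space MMD losses.
import torch
import torch.nn as nn

def energy_mmd2(a, b):
    # squared MMD with the energy-distance kernel, as in the previous sketch
    return (-torch.cdist(a, a).mean() - torch.cdist(b, b).mean()
            + 2.0 * torch.cdist(a, b).mean())

E, G = 128, 2000                                   # hypothetical embedding / gene dims
decoder = nn.Sequential(nn.Linear(E, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, G))

pred_emb = torch.randn(64, E, requires_grad=True)  # predicted perturbed cells, embedding space
obs_emb = torch.randn(64, E)                       # observed cells, embedding space
obs_expr = torch.randn(64, G)                      # matched observed expression

loss_emb = energy_mmd2(pred_emb, obs_emb)          # MMD in embedding space
loss_expr = energy_mmd2(decoder(pred_emb), obs_expr)
total_loss = loss_emb + 0.1 * loss_expr            # expression term down-weighted by 0.1
```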
4.2.3. State Embedding Model (SE)
SE is a self-supervised model pre-trained to learn high-quality cell representations from single-cell RNA sequencing data. These embeddings then serve as inputs to ST, enabling more robust transfer across datasets.
4.2.3.1. Gene Representation via Protein Language Models
The SE model incorporates gene representations obtained from pretrained protein language models.

- Gene embeddings are computed using ESM-2 (esm2_t48_15B_UR50D [60]). This involves averaging per-amino-acid embeddings for each protein-coding transcript of a gene, and then averaging across all transcripts for that gene. This captures evolutionary and functional relationships between genes.
- These gene embeddings $g_j$ are then projected into the model's embedding dimension via a learnable encoder:
  $$\tilde{g}_j = \mathrm{LayerNorm}\big(\mathrm{SiLU}(W g_j + b)\big)$$
  where:
  - $W$ and $b$ are learnable parameters.
  - SiLU (Sigmoid Linear Unit) is the activation function [63].
  - LayerNorm is Layer Normalization.
4.2.3.2. Cell Representation
Each cell is represented as a sequence of its most highly expressed genes.

- An "expression set" is constructed by selecting the top $L$ genes ranked by log fold expression level.
- This set is augmented with two special tokens:
  $$\big[\,\mathrm{CLS},\; \tilde{g}_{(1)},\; \ldots,\; \tilde{g}_{(L)},\; \mathrm{DS}\,\big]$$
  where:
  - $\mathrm{CLS}$: a learnable classification token used to aggregate cell-level information.
  - $\mathrm{DS}$: a learnable dataset token to help disentangle dataset-specific effects.
  - $\tilde{g}_{(k)}$: the projected embedding of the $k$-th most highly expressed gene in the cell.
- Gene selection is cell-specific. If a cell expresses fewer than $L$ genes, the expression set is padded to length $L$ by randomly sampling from unexpressed genes to maintain a fixed-length input.
4.2.3.3. Expression-Aware Embeddings
To incorporate expression values directly, SE uses an expression embedding scheme inspired by soft binning [12]. For the $k$-th most expressed gene in a cell, with expression value $e_k$, a soft bin assignment is computed:

$$w_k = \mathrm{softmax}\big(f_{\mathrm{bin}}(e_k)\big)$$

where:

- $f_{\mathrm{bin}}$: consists of two linear layers with LeakyReLU activation.
- $\{b_m\}_{m=1}^{M}$: a set of $M$ learnable bin embeddings of dimension $D$.

The resulting expression embedding is the soft-assignment-weighted sum of bin embeddings, $z_k = \sum_{m=1}^{M} w_{k,m}\, b_m$, which is added to the corresponding gene identity embedding:

$$\tilde{g}_{(k)} \leftarrow \tilde{g}_{(k)} + z_k$$
4.2.3.4. Transformer Encoding
The input expression set (composed of expression-aware gene embeddings and the special tokens) is passed through the transformer encoder $T_{\mathrm{enc}}$:

$$H = T_{\mathrm{enc}}\big(\big[\,\mathrm{CLS},\; \tilde{g}_{(1)} + z_1,\; \ldots,\; \tilde{g}_{(L)} + z_L,\; \mathrm{DS}\,\big]\big)$$

Here, $\tilde{g}_{(k)} + z_k$ is as defined in the previous step. Note that expression encodings are set to zero for the CLS and DS tokens.

- The cell embedding is extracted from the CLS token at position 0 and normalized:
  $$h_{\mathrm{cell}} = \mathrm{normalize}\big(H_0\big)$$
  This embedding serves as a summary representation of the cell's transcriptomic state.
- Similarly, the dataset representation is extracted from the DS token at position $L+1$:
  $$h_{\mathrm{ds}} = H_{L+1}$$
  This embedding captures and accounts for dataset-specific effects.

The final cell embedding is the concatenation of these two quantities:

$$h = \big[\, h_{\mathrm{cell}};\; W_{\mathrm{ds}}\, h_{\mathrm{ds}} \,\big]$$

where $W_{\mathrm{ds}}$ projects $h_{\mathrm{ds}}$ to a 10-dimensional space. This cell embedding serves as the input representation for individual cells in ST.
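A minimal sketch of how the final cell embedding could be assembled from the encoder output, assuming the token layout [CLS, gene tokens, DS] described above; the encoder output here is random and the 10-dimensional projection is untrained.

```python
# Minimal sketch: extract CLS/DS summaries and concatenate into one cell embedding.
import torch
import torch.nn as nn
import torch.nn.functional as F

L, D = 2048, 512                                        # hypothetical set length / model dim
encoder_out = torch.randn(1, L + 2, D)                  # transformer output for one cell

cls_out = F.normalize(encoder_out[:, 0, :], dim=-1)     # CLS token -> normalized cell summary
ds_out = encoder_out[:, -1, :]                          # DS token -> dataset summary
ds_proj = nn.Linear(D, 10)(ds_out)                      # project dataset summary to 10 dims

cell_embedding = torch.cat([cls_out, ds_proj], dim=-1)  # final embedding fed to ST
print(cell_embedding.shape)                             # torch.Size([1, 522])
```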
4.2.3.5. Pretraining Objectives
SE is trained with a self-supervised learning framework using two objectives: gene expression prediction and an auxiliary dataset classification task.
- Gene Expression Prediction: The model predicts expression values for a selected set of 1,280 genes per cell. Target genes are drawn from three categories to ensure coverage across the expression dynamic range:
  - 512 highly expressed genes (from the top $L$ genes).
  - 512 unexpressed genes (randomly sampled from genes not in the top $L$).
  - A shared set $\mathcal{S}$ of 256 genes randomly sampled from the full gene set, shared for all cells in the batch.
  This strategy encourages reconstruction fidelity for both expressed and silent genes.
- Expression Prediction Decoder: An MLP decoder combines multiple sources of information to predict gene expression:
  $$\hat{y}_{c,j} = f_{\mathrm{dec}}\big(\big[\, h_c;\; \tilde{g}_j;\; r_c \,\big]\big)$$
  where:
  - $h_c$: the learned cell embedding.
  - $\tilde{g}_j$: the embedding of the target gene $j$.
  - $r_c$: a scalar read depth indicator (the mean log expression of expressed genes in the input expression set).
  These are concatenated and passed through $f_{\mathrm{dec}}$ (two skip-connected blocks and a linear output layer) to predict the log expression.

  For each cell $c$ in a batch, let $\hat{y}_c$ be the predicted and $y_c$ the true expression values for the 1,280 target genes. The gene-level loss $\mathcal{L}_{\mathrm{gene}}$ measures the similarity between predicted and true gene expression patterns within each cell.

  To capture variation across cells, a cell-level loss is computed using the shared subset of genes $\mathcal{S}$. Let $\hat{y}_{c,\mathcal{S}}$ and $y_{c,\mathcal{S}}$ be the predicted and true expression values for genes in $\mathcal{S}$ in cell $c$. These are stacked across cells and transposed, and the cell-level loss $\mathcal{L}_{\mathrm{cell}}$ measures the similarity between predicted and true expression across cells for each gene in $\mathcal{S}$.

  The final training loss for expression prediction combines both axes: $\mathcal{L}_{\mathrm{expr}} = \mathcal{L}_{\mathrm{gene}} + \mathcal{L}_{\mathrm{cell}}$.
- Dataset Classification Modeling: An auxiliary dataset prediction task helps disentangle technical batch effects from biological variation. Using the DS token embedding, the model predicts the dataset of origin:
  $$\hat{d}_c = f_{\mathrm{DS}}\big(h_{\mathrm{ds}}\big)$$
  The loss for this task is cross-entropy:
  $$\mathcal{L}_{\mathrm{DS}} = \mathrm{CrossEntropy}\big(\hat{d}_c, d_c\big)$$
  where $d_c$ is the true dataset label. This encourages the model to pool dataset-specific information into this token, disentangling it from true biological signal.
- Total Loss: The SE model is trained using a combination of both losses:
  $$\mathcal{L}_{\mathrm{SE}} = \mathcal{L}_{\mathrm{expr}} + \mathcal{L}_{\mathrm{DS}}$$
4.2.4. Theoretical Analysis of ST and Optimal Transport
The paper also provides a theoretical analysis of ST's capacity to learn optimal transport (OT) mappings between cellular distributions.

- Background on Optimal Transport: OT aims to find a mapping or coupling that minimizes a fixed cost between two probability distributions [74, 75]. Neural OT uses neural networks to parameterize and solve OT problems [9, 27].
- ST as a Transformation Learner: While ST does not explicitly solve an OT problem, it learns a transformation that aligns unperturbed and perturbed cell distributions. The authors connect ST to recent work showing that transformers can implement gradient descent for OT through engineered prompts [76, 77].
- Asymptotic Behavior and Solution Family: The analysis focuses on the asymptotic setting where the cell set size $S$ tends to infinity. In this limit, ST's self-attention mechanism processes information from the entire distribution. The model defines an operator $T_{p,b}$ that maps the control distribution $\mu^{\mathrm{ctrl}}$ to the predicted perturbed distribution $\hat{\mu}^{\mathrm{pert}}$, where $T_{p,b}$ represents the transformation learned by ST for a given perturbation $p$ and batch $b$.
- Lemma 1 (MMD and Distributional Matching): If ST achieves zero empirical MMD with the energy kernel as $S \to \infty$, then the predicted distribution equals the true perturbed distribution with probability 1, and vice versa. The empirical MMD is a V-statistic [79, 80] and converges almost surely to the population MMD. With the energy kernel, the MMD equals the energy distance, and zero energy distance implies equal distributions [81].
- Theorem 2 (Optimal Transport Mapping within the Solution Family of STATE): Under certain regularity conditions (the densities are absolutely continuous and bounded, and the support sets are strictly convex and compact), the unique continuous optimal transport map $T^{*}$ from $\mu^{\mathrm{ctrl}}$ to $\mu^{\mathrm{pert}}$ (associated with the squared-distance cost) is contained within the solution family of ST (i.e., the set of mappings $T$ such that $T_{\#}\mu^{\mathrm{ctrl}} = \mu^{\mathrm{pert}}$) with probability 1. This implies that ST can learn the optimal transport map if it perfectly minimizes the MMD loss. However, the solution family contains infinitely many additional elements beyond $T^{*}$, so the optimal transport solution is not guaranteed to be learned.
- Theorem 3 (Constrained ST Model for Unique OT Map): If explicit constraints are imposed on the Jacobian of the ST model's mapping (specifically, that the Jacobian is symmetric and positive semi-definite), then the constrained solution family uniquely contains the continuous OT mapping with probability 1. This condition ($T = \nabla\varphi$, where $\varphi$ is a convex function) is the definition of an optimal transport map for the quadratic cost, as per Brenier's theorem [83].

This theoretical analysis suggests that ST has the inherent capacity to learn fundamental cellular transformations akin to optimal transport, especially under certain architectural or regularization constraints. The paper hypothesizes that the implicit bias [85-87] of gradient descent during ST training may naturally lead to solutions resembling optimal transport mappings by favoring "minimal shifts."
4.3. ST Training Details
The ST model is implemented using PyTorch Lightning with distributed data parallel (DDP) training and automatic mixed precision (AMP).
- Hyperparameters: Key hyperparameters include cell_set_size, hidden_dim, n_encoder_layers, n_decoder_layers, batch_encoder (boolean), transformer_backbone_key (LLaMA or GPT2), attn_heads, and params.
- Transformer Backbone: Uses either a LLaMA [66] or GPT2 [67] backbone from HuggingFace. GPT2 is preferred for sparser datasets such as Replogle-Nadig due to LayerNorm's properties in low-data regimes [68].
- Attention: All models use bi-directional attention.
- Positional Encodings: No positional encodings are used, since cell order within each set is arbitrary.
- Dropout: Not applied within the transformer.
- Initialization: Before training, weights are initialized using Kaiming Uniform [69], except for the transformer backbone, which is initialized separately.
- Fine-tuning: For fine-tuning tasks, a new ST model is initialized with pretrained weights. The perturbation encoder ($f_{\mathrm{pert}}$) is reinitialized for transfer across perturbation modalities. If ST uses cell embeddings, the gene decoder ($f_{\mathrm{decode}}$) is also reinitialized to adapt to differences in gene coverage. The remaining components ($f_{\mathrm{cell}}$, $f_{\mathrm{batch}}$, $f_{\mathrm{ST}}$) retain pretrained weights and are fine-tuned.
4.4. SE Training Details
The SE model is a 600M parameter encoder-decoder model.
- Architecture: The encoder has 16 transformer layers, each with 16 attention heads. Each layer uses pre-normalization, an expanding feed-forward network, and GELU activation. Dropout with probability 0.1 is applied to the attention and feed-forward layers. The decoder is an MLP.
- Optimizer: AdamW [70] with weight decay of 0.01 and gradient clipping with zclip [71].
- Learning Rate Schedule: Linear warmup for the first 3% of steps, followed by cosine annealing to 30% of the maximum learning rate.
- Initialization: All SE weights are initialized with Kaiming Uniform.
- Training Scale: Trained for 4 epochs on 14,420 AnnData files spanning 167 million human cells across the Arc scBaseCount [32], CZ CELLxGENE [31], and Tahoe-100M [30] datasets.
- Data Leakage Prevention: Datasets are split into separate training and validation sets at the dataset level.
- Efficiency: Utilizes Flash Attention 2 [72] and mixed-precision (bf16) training [73] for efficient, large-scale training.
- Hardware: Distributed across 4 compute nodes, each with 8 NVIDIA H100 GPUs.
- Batch Size: Effective batch size of 3,072 (per-device batch size of 24, gradient accumulation over 4 steps).
5. Experimental Setup
5.1. Datasets
5.1.1. Datasets Used for ST Training
STATE was evaluated on several large-scale single-cell perturbation datasets:
-
Tahoe-100M[30]: A large dataset focused ondrug-based perturbations. -
Replogle-Nadig[4, 43]: Containsgenome-scale genetic perturbations. -
Parse-PBMC[42]: Featurescytokine signaling perturbations. -
Jiang[64]: Anothergenetic perturbationdataset. -
McFaline[65]: Containschemical genomicsdata. -
Srivatsan[41]: Anotherchemical transcriptomicsdataset.The following are the results from Table 1 of the original paper:
Dataset # of Cells # of Perturbations # of Contexts Replogle-Nadig 624,158 1,677 4 Jiang 234,845 24 30 Srivatsan 762,795 189 3 Mcfaline-Figueroa 354,758 122 3 Tahoe-100M 100,648,790 1138 50 Parse-PBMC 9,697,974 90 12/18 (Donors)/(Cell Types)
Preprocessing for ST Training:

- Gene filtering: All datasets were filtered to retain measurements for 19,790 human protein-coding Ensembl genes.
- Normalization: Counts were subsequently normalized to a total UMI (Unique Molecular Identifier) depth of 10,000 per cell.
- Log-transformation: Raw count data were log-transformed using scanpy.pp.log1p.
- Highly Variable Genes (HVGs): For analyses involving HVGs, the top 2,000 HVGs were identified per dataset using scanpy.pp.highly_variable_genes. These log-transformed HVG values were used as gene-level features.
- PCA Embeddings: PCA (Principal Component Analysis) embeddings of cells were computed using scanpy.pp.pca.
5.1.2. Additional preprocessing for genetic perturbation datasets
Specific filtering steps were applied to genetic perturbation datasets to ensure high knockdown efficacy:
- Perturbation-level filtering: Only perturbations (excluding controls) whose average knockdown efficiency resulted in sufficiently low residual target-gene expression were retained.
- Cell-level filtering: Within these selected perturbations, only cells individually meeting a stricter knockdown threshold on residual expression were kept.
- Minimum cell count: Any perturbation with fewer than 30 remaining valid cells was dropped, while control cells were always preserved.
5.1.3. Datasets Used for SE Training
The SE model was trained on an even larger corpus of observational data to learn robust cell embeddings.
- SE was trained on 167 million human cells across three datasets: Arc scBaseCount [32], CZ CELLxGENE [31], and Tahoe-100M [30].

The following are the results from Table 2 of the original paper:

| Dataset | # Training Cells | # Validation Cells |
|---|---|---|
| Arc scBaseCount [32] | 71,676,369 | 4,137,674 |
| CZ CELLxGENE [31] | 59,233,790 | 6,500,519 |
| Tahoe-100M [30] | 36,157,383 | 2,780,587 |
Preprocessing for SE Training:
- To avoid data leakage in context generalization benchmarks, SE was trained on 20 Tahoe-100M cell lines separate from the five held-out cell lines used in benchmarking.
- scBaseCount data was filtered to include only cells with at least 1,000 non-zero expression measurements and 2,000 UMIs per cell.
- A subset of AnnData files was reserved for computing validation loss.
5.2. Evaluation Metrics
The evaluation of STATE focuses on its ability to discriminate perturbation effects, accurately identify differentially expressed genes, and generalize across contexts.
5.2.1. Perturbation Evaluation Metrics
- Perturbation Discrimination Score (PDiscNorm):
  - Conceptual Definition: This metric assesses how well a model's predicted pseudobulk (average) expression profile for a perturbation matches the true pseudobulk profile for that perturbation, compared to other perturbations. A perfect score means the true profile is uniquely closest to its prediction.
  - Mathematical Formula: For any perturbation $t$, let $\hat{y}_t$ denote the predicted pseudobulked expression and $y_t$ the observed pseudobulked expression. Using a distance metric $d$ (Manhattan or Euclidean distance in this case), the per-perturbation score is:
    $$\mathrm{PDisc}_t = \frac{r_t}{T}, \qquad r_t = \sum_{k \neq t} \mathbb{1}\big\{ d(\hat{y}_t, y_k) < d(\hat{y}_t, y_t) \big\}$$
    The overall PDisc is the mean $\overline{\mathrm{PDisc}} = \frac{1}{T}\sum_{t=1}^{T} \mathrm{PDisc}_t$, and the normalized inverse perturbation discrimination score is reported:
    $$\mathrm{PDiscNorm} = 1 - 2\, \overline{\mathrm{PDisc}}$$
  - Symbol Explanation:
    - $t$: an index for a specific perturbation.
    - $T$: the total number of distinct perturbations.
    - $\hat{y}_t$: the model's predicted pseudobulked expression for perturbation $t$.
    - $y_t$: the true (observed) pseudobulked expression for perturbation $t$.
    - $d$: a distance function (e.g., Manhattan or Euclidean distance) between expression profiles.
    - $\mathbb{1}\{\cdot\}$: an indicator function that is 1 if the condition inside is true, and 0 otherwise.
    - $r_t$: the number of other perturbations ($k \neq t$) whose true pseudobulked expression is closer to the prediction $\hat{y}_t$ than the correct true profile $y_t$.
    - $\mathrm{PDisc}_t$: the normalized rank of the true perturbation profile among all profiles for a given prediction; it ranges over $[0, 1)$, with 0 being a perfect match.
    - $\overline{\mathrm{PDisc}}$: the average $\mathrm{PDisc}_t$ across all perturbations.
    - $\mathrm{PDiscNorm}$: the final reported score, normalized so that a random predictor scores approximately 0.0 and a perfect predictor scores 1.0. Higher values indicate better discrimination.
  A minimal implementation sketch follows this list.
5.2.2. Differential Expression
- DE Overlap Accuracy:
  - Conceptual Definition: This metric quantifies the overlap between the set of differentially expressed genes (DEGs) identified from the model's predictions and the DEGs identified from the true observed data. It measures how accurately the model can pinpoint the genes whose expression changes.
  - Mathematical Formula: For each perturbation $t$, let $G_t^{\mathrm{true}}$ denote the set of top-$k$ true DEGs and $G_t^{\mathrm{pred}}$ the set of top-$k$ predicted DEGs. The overlap is computed as:
    $$\mathrm{Overlap}_t(k) = \frac{\big| G_t^{\mathrm{true}} \cap G_t^{\mathrm{pred}} \big|}{k}$$
    When $k = \mathrm{all}$, $k$ represents the total number of DEGs in the true set.
  - Symbol Explanation:
    - $t$: an index for a specific perturbation.
    - $k$: the number of top DEGs considered (e.g., 50, 100, 200, or all DEGs).
    - $G_t^{\mathrm{true}}$: the set of top-$k$ differentially expressed genes identified from the true data for perturbation $t$. Genes are filtered by adjusted p-value and ranked by absolute log fold change.
    - $G_t^{\mathrm{pred}}$: the set of top-$k$ differentially expressed genes identified from the model's predicted data for perturbation $t$.
    - $|A \cap B|$: the size of the intersection between sets A and B (i.e., the number of genes common to both the true and predicted DEG sets).
    - $\mathrm{Overlap}_t(k)$: the fraction of overlap for perturbation $t$ at a given $k$. Higher values indicate better accuracy.
- Effect Sizes (SizeCorr):
  - Conceptual Definition: This metric evaluates how well the model can predict the magnitude of a perturbation's effect by correlating the number of differentially expressed genes in the true data with the number in the predicted data. It measures the ability to rank perturbations by their overall impact.
  - Mathematical Formula: Let $n_t^{\mathrm{true}}$ be the number of true DEGs and $n_t^{\mathrm{pred}}$ the number of predicted DEGs for perturbation $t$. Then:
    $$\mathrm{SizeCorr} = \mathrm{Spearman}\big(\, (n_1^{\mathrm{true}}, \ldots, n_T^{\mathrm{true}}),\; (n_1^{\mathrm{pred}}, \ldots, n_T^{\mathrm{pred}}) \,\big)$$
  - Symbol Explanation:
    - $n_t^{\mathrm{true}}$: the total number of differentially expressed genes (by adjusted p-value) identified from the true data for perturbation $t$.
    - $n_t^{\mathrm{pred}}$: the total number of differentially expressed genes identified from the model's predicted data for perturbation $t$.
    - $\mathrm{Spearman}$: the Spearman rank correlation coefficient, which measures the monotonic relationship between the two lists of DEG counts. A higher Spearman correlation indicates better agreement in ranking perturbation effect sizes.
  A short sketch of both DE metrics follows this list.
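The two DE-based metrics reduce to a set overlap and a rank correlation; the sketch below uses toy gene lists and DEG counts to illustrate both.

```python
# Minimal sketch of DE overlap accuracy and SizeCorr on toy data.
import numpy as np
from scipy.stats import spearmanr

def de_overlap(true_degs, pred_degs, k):
    """true_degs, pred_degs: DEGs ranked by |log fold change| (significant genes only)."""
    return len(set(true_degs[:k]) & set(pred_degs[:k])) / k

true_degs = ["G1", "G2", "G3", "G4", "G5"]
pred_degs = ["G2", "G1", "G7", "G4", "G9"]
print(de_overlap(true_degs, pred_degs, k=4))            # 0.75: 3 of the top 4 genes overlap

# SizeCorr: Spearman correlation of per-perturbation DEG counts (true vs. predicted).
n_true = np.array([120, 5, 300, 42, 8])
n_pred = np.array([90, 11, 250, 60, 6])
rho, _ = spearmanr(n_true, n_pred)
print(rho)                                              # rank agreement of effect sizes
```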
5.2.3. Cell Embedding Evaluation Metrics
-
Intrinsic Evaluation:
- Conceptual Definition: This assesses the quality of the
cell embeddingsthemselves by testing how muchperturbation-specific informationthey retain. A shallowMLP classifieris trained to predict theperturbation labeldirectly from thecell embedding. High classification performance suggests that the embeddings capture distinct perturbation-induced states. - Mathematical Formula: Not explicitly provided for this specific evaluation, but typically involves
AUROC(Area Under the Receiver Operating Characteristic curve) andaccuracyfor amulti-class classificationtask withcross-entropy loss. - Symbol Explanation:
AUROCmeasures the ability of a classifier to distinguish between classes.Accuracyis the proportion of correctly classified instances.
- Conceptual Definition: This assesses the quality of the
- Extrinsic Evaluation:
  - Conceptual Definition: This evaluates the practical utility of cell embeddings by using them as input for the ST model and assessing ST's downstream performance. It aims to see whether embeddings that are rich in information (high intrinsic score) also lead to better perturbation prediction by ST.
  - Mathematical Formula: Not explicitly provided, but it involves comparing the performance of ST (using the metrics described above, such as PDiscNorm, DE Overlap, and SizeCorr) when trained with different embedding spaces.
  - Symbol Explanation: ST's performance metrics are used here to compare the utility of different cell embeddings.
5.2.4. Other Evaluations
- Cell Set Scaling:
  - Conceptual Definition: An ablation study to understand the impact of the cell set size ($S$) on STATE's performance. It investigates how varying $S$ affects validation loss while controlling computational cost (FLOPs). This also compares STATE to simpler models such as a pseudobulk model (mean pooling instead of self-attention) and a single-cell variant ($S = 1$).
  - Mathematical Formula: Not explicitly provided, but it involves plotting validation loss against FLOPs for different cell set sizes.
  - Symbol Explanation:
    - cell set size ($S$): the number of cells grouped together in one set for ST processing.
    - batch size: the number of cell sets processed concurrently.
    - FLOPs: floating-point operations, a measure of computational cost.
    - validation loss: the model's error on a held-out validation dataset.
5.2.5. Task Description for Generalization
The paper defines two generalization tasks:
- Underrepresented Context Generalization Task:
  - Conceptual Definition: This task evaluates a model's ability to generalize to a cellular context where only a small fraction of perturbations has been observed during training. It simulates scenarios where limited data is available for a new cell type or condition.
  - Procedure: For a chosen test context $c^{*}$, a small proportion $\rho$ of its perturbations forms a support set for training, and the rest are held out as a target set for testing. The model trains on all perturbations from other contexts plus the support set from $c^{*}$, and is then evaluated on the target set of $c^{*}$.
  - Data Splits:
    - Support set: $S_{c^{*}} \subset P_{c^{*}}$, where $|S_{c^{*}}| = \rho\, |P_{c^{*}}|$.
    - Target set: $T_{c^{*}} = P_{c^{*}} \setminus S_{c^{*}}$.
    - Training data: all perturbations in contexts other than $c^{*}$, plus the support set $S_{c^{*}}$.
    - Test data: the target set $T_{c^{*}}$ in context $c^{*}$.
  - Notes: Unperturbed cells from all contexts (including $c^{*}$) are available during training. If the number of perturbations is large, a fixed subset is used; if it is small, an iterative leave-one-out approach is used.
- Zero-Shot Context Generalization Task:
  - Conceptual Definition: This task assesses a model's ability to predict perturbation effects in entirely novel cellular contexts for which no perturbation data was seen during fine-tuning. It relies on the model's ability to learn generalizable relationships.
  - Procedure: The model is pre-trained on large perturbation datasets (e.g., Tahoe-100M). Then, for a query dataset, one context $c^{*}$ is held out and the remaining contexts are used for fine-tuning. The fine-tuned model is evaluated on its ability to predict perturbations within the held-out context of the query dataset.
  - Data Splits:
    - Fine-tuning data: all perturbations in the query dataset's contexts other than $c^{*}$.
    - Test data: perturbations in the held-out context $c^{*}$.
  - Notes: Unperturbed cells from all contexts (including the held-out context) are available.
5.3. Baselines
STATE was benchmarked against both simple statistical methods and state-of-the-art deep learning models.
- Perturbation Mean Baseline:
  - Conceptual Definition: This baseline assumes that the effect of a perturbation is a global average shift applicable across all cell types. It predicts a perturbed profile by adding a global, perturbation-specific offset (calculated from training data) to the control mean of the target cell type. A toy sketch of this baseline and the linear baseline appears after this list.
  - Mathematical Formula: For each cell type $c$ and perturbation $p$:
    - Cell-type specific control mean: $\mu_c^{\mathrm{ctrl}} = \frac{1}{|C_c|}\sum_{i \in C_c} x_i$, where $C_c$ is the set of control cells of type $c$.
    - Cell-type specific perturbed mean: $\mu_{c,p} = \frac{1}{|C_{c,p}|}\sum_{i \in C_{c,p}} x_i$, where $C_{c,p}$ is the set of perturbed cells of type $c$ receiving perturbation $p$.
    - Cell-type specific perturbation offset: $\delta_{c,p} = \mu_{c,p} - \mu_c^{\mathrm{ctrl}}$.
    - Global perturbation offset: $\bar{\delta}_p = \frac{1}{|\mathcal{C}_p|}\sum_{c \in \mathcal{C}_p} \delta_{c,p}$, where $\mathcal{C}_p$ is the set of contexts in which perturbation $p$ was observed.
    Given a test cell type $c^{*}$ and perturbation label $p$, the model predicts:
    $$\hat{x} = \mu_{c^{*}}^{\mathrm{ctrl}} + \bar{\delta}_p$$
  - Symbol Explanation:
    - $C_c$: set of control cells of cell type $c$.
    - $C_{c,p}$: set of cells of cell type $c$ perturbed with $p$.
    - $x_i$: expression vector of cell $i$.
    - $\mu_c^{\mathrm{ctrl}}$: mean expression of control cells in context $c$.
    - $\mu_{c,p}$: mean expression of cells perturbed with $p$ in context $c$.
    - $\delta_{c,p}$: difference between perturbed and control means for context $c$ and perturbation $p$.
    - $\mathcal{C}_p$: set of contexts where perturbation $p$ was observed.
    - $\bar{\delta}_p$: global average offset for perturbation $p$ across contexts.
    - $\hat{x}$: predicted expression profile.
    - $\mu_{c^{*}}^{\mathrm{ctrl}}$: control mean for the test cell type $c^{*}$.
    - $\bar{\delta}_{\mathrm{ctrl}}$: offset for control perturbations (defined as zero).
- Context Mean Baseline:
  - Conceptual Definition: This baseline predicts a cell's post-perturbation profile by simply returning the average perturbed expression observed in the training set for cells of the same cell type. If the cell is a control, its original profile is passed through.
  - Mathematical Formula: For every cell type $c$, the pseudo-bulk mean of all non-control perturbed cells in the training set is formed:
    $$\mu_c^{\mathrm{pert}} = \frac{1}{|A_c|}\sum_{i \in A_c} x_i$$
    At inference time, for a test cell $i$ with cell type $c(i)$ and perturbation label $p_i$, the prediction is:
    $$\hat{x}_i = \begin{cases} \mu_{c(i)}^{\mathrm{pert}} & \text{if } p_i \neq \mathrm{ctrl} \\ x_i & \text{if } p_i = \mathrm{ctrl} \end{cases}$$
  - Symbol Explanation:
    - $\mu_c^{\mathrm{pert}}$: mean expression of all non-control perturbed cells in context $c$ from the training set.
    - $A_c$: set of training cells of type $c$ that are perturbed (not controls).
    - $x_i$: expression vector of cell $i$.
    - $c(i)$: cell type of cell $i$.
    - $p_i$: perturbation label of cell $i$.
    - $\hat{x}_i$: predicted expression profile for cell $i$.
- Linear Baseline (as described in [46]):
  - Conceptual Definition: This model treats a perturbation as a low-rank, gene-wide linear displacement added to a cell's control expression. It learns a linear map that transforms gene embeddings and perturbation embeddings into an expression change. See the sketch after this list.
  - Mathematical Formula:
    - $G$: fixed gene-embedding matrix (e.g., pretrained protein feature vectors).
    - $P$: fixed perturbation-embedding matrix (one-hot).
    - From the training set, an "expression-change" pseudobulk matrix $\Delta$ is built:
      $$\Delta_{g,p} = \frac{1}{|C_p|}\sum_{i \in C_p} \big( x_{i,g} - x_{i,g}^{\mathrm{ctrl}} \big)$$
      so $\Delta_{g,p}$ stores the average change, relative to each cell's matched control, for every gene $g$ and perturbation $p$.
    - The model seeks a low-rank map $K$ and a gene-wise bias $b$ such that $\Delta \approx G K P^{\top} + b\,\mathbf{1}^{\top}$.
    - $K$ is obtained by solving the ridge-regularized least-squares problem:
      $$\min_{K, b}\; \big\lVert \Delta - G K P^{\top} - b\,\mathbf{1}^{\top} \big\rVert_F^2 + \lambda \lVert K \rVert_F^2$$
    - This problem admits a closed-form ridge-regression solution for $K$ given $G$, $P$, and $\Delta$.
    - For a test cell with control profile $x^{\mathrm{ctrl}}$ and perturbation label $p$, the prediction is:
      $$\hat{x} = x^{\mathrm{ctrl}} + G K P_p^{\top} + b$$
  - Symbol Explanation:
    - $G$: gene embedding matrix.
    - $P$: perturbation embedding matrix.
    - $\Delta_{g,p}$: average expression change for gene $g$ under perturbation $p$.
    - $C_p$: set of cells receiving perturbation $p$.
    - $x_{i,g}$: expression of gene $g$ in perturbed cell $i$.
    - $x_{i,g}^{\mathrm{ctrl}}$: expression of gene $g$ in the matched control cell.
    - $K$: low-rank map.
    - $b$: gene-wise bias.
    - $\mathbf{1}$: vector of ones.
    - $\lVert \cdot \rVert_F^2$: squared Frobenius norm.
    - $\lambda$: ridge regularization parameter.
    - $P_p$: row of $P$ corresponding to perturbation $p$.
- Deep Learning Baselines:
  - scVI [25]: An autoencoder-based model that models gene expression distributions while accounting for technical noise and batch effects.
  - CPA [8, 47]: An autoencoder-based model that learns a compositional latent space to capture additive effects of perturbation, dosage, and cell type.
  - scGPT [11]: A transformer-based foundation model that uses generative pretraining on large datasets (over 33 million cells) to achieve zero-shot generalization across tasks, including perturbation prediction.
5.4. ST Hyperparameters
The following are the results from Table 3 of the original paper:
| Dataset | cell_set_size | hidden_dim | n_encoder_layers | n_decoder_layers | batch_encoder | transformer_backbone_key | attn_heads | params |
|---|---|---|---|---|---|---|---|---|
| Tahoe-100M | 256 | 1488 | 4 | 4 | false | LLaMA | 12 | 244M |
| Parse-PBMC | 512 | 1440 | 4 | | true | LLaMA | 12 | 244M |
| Replogle-Nadig | 32 | 128 | 4 | 4 | false | GPT2 | 8 | 10M |
The following are the results from Table 4 of the original paper:
| Component | Architecture | Layer Dimensions | Activation | Normalization | Dropout |
|---|---|---|---|---|---|
| fcell | 4-layer MLP | (G or E) → h → h → h → h | GELU | LayerNorm | None |
| fpert | 4-layer MLP | Dpert → h → h → h → h | GELU | LayerNorm | None |
| fbatch | Embedding Lyr. | Dbatch → h | N/A | N/A | None |
| fST | LLaMA Transf. | h (input to each of 4 layers) | SwiGLU | RMSNorm | None |
| recon | Linear Layer | h → G or E | N/A | N/A | None |
| fdecode | 3-layer MLP | h → 1024 → 512 → G | GELU | LayerNorm | 0.1 |
| fconf | 3-layer MLP | h → h/2 → h/4 → 1 | GELU | LayerNorm | None |
Notes on Architectural Details:
- G: Number of genes.
- E: Dimension of cell embeddings (if used).
- Dpert: Dimension of perturbation features.
- Dbatch: Dimension of batch features.
- h: Shared hidden dimension.
- All MLPs use GELU activation and Layer Normalization before activation, unless otherwise specified. fST (the transformer backbone) uses SwiGLU (Swish-Gated Linear Unit) activation and RMSNorm (Root Mean Square Normalization).
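As a concrete illustration of the MLP pattern in Table 4 (Linear layers with LayerNorm applied before the GELU activation), here is a minimal PyTorch-style sketch; the dimensions and names are placeholders rather than values from released code:

```python
import torch.nn as nn

def make_encoder_mlp(in_dim: int, h: int, n_layers: int = 4) -> nn.Sequential:
    """Stack of Linear -> LayerNorm -> GELU blocks, mirroring fcell / fpert in Table 4."""
    layers, d = [], in_dim
    for _ in range(n_layers):
        layers += [nn.Linear(d, h), nn.LayerNorm(h), nn.GELU()]
        d = h
    return nn.Sequential(*layers)

# Example: a cell encoder over G genes and a perturbation encoder over Dpert features,
# both projected into the shared hidden dimension h (values are placeholders).
f_cell = make_encoder_mlp(in_dim=2000, h=1488)   # (G or E) -> h -> h -> h -> h
f_pert = make_encoder_mlp(in_dim=512, h=1488)    # Dpert -> h -> ... -> h
```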
6. Results & Analysis
6.1. Core Results Analysis
The experimental results consistently demonstrate STATE's superior performance in predicting cellular responses to perturbations, particularly highlighting its ability to generalize across diverse contexts and leverage large datasets.
The multi-panel Figure 2 from the original paper, titled "STATE improves perturbation prediction in context generalization, leverages data scale, and enables cross-dataset transfer learning," provides visual evidence for these claims.
Figure 2: STATE improves perturbation prediction in context generalization, leverages data scale, and enables cross-dataset transfer learning. (A) Underrepresented context generalization task, in which models are trained on perturbations largely held out from an independent target context. (B) Training and evaluation setup across the benchmark datasets. (C) Performance is measured with the perturbation discrimination metric [4], overlap of predicted vs. true DEGs, and the Spearman correlation of perturbations ranked by effect size (number of predicted DEGs). (D) Intrinsic performance measures the classification accuracy of perturbed cells in the embedding generated using observed data; extrinsic performance measures the classification accuracy over perturbed embeddings predicted by ST trained on the cell embeddings. (E) Zero-shot performance after pre-training on Tahoe-100M.
Performance in Context Generalization (Figure 2C)
STATE was evaluated on several metrics in a few-shot context generalization task, where each test cell context contributed only 30% of its perturbations to training (Figure 2B).
- Perturbation Discrimination Score (PDiscNorm): STATE achieved substantial absolute improvements:
  - 54% on the Tahoe-100M dataset (a large drug perturbation dataset).
  - 29% on the Parse-PBMC dataset (a cytokine signaling perturbation dataset).
  - It also matched the next best performing baseline on genetic perturbations.
  This indicates STATE's strong ability to distinguish between different perturbation effects, even with limited exposure in a new context.
- DE Overlap Accuracy: STATE showed significantly improved accuracy in identifying differentially expressed genes (DEGs):
  - Twice as good as the next best baseline on Tahoe-100M.
  - 43% better on the Parse-PBMC dataset.
  - It was the second best model on genetic perturbations.
  This highlights STATE's biological relevance, as accurately predicting DEGs is crucial for understanding molecular mechanisms.
- Effect Sizes (SizeCorr): STATE accurately ranked perturbations by their relative effect sizes, demonstrating its capacity to predict the overall impact of perturbations (a simple sketch of this ranking metric follows this list):
  - 53% higher Spearman correlations on Parse-PBMC.
  - 22% higher than baselines on Replogle-Nadig.
  - 70% higher on Tahoe-100M, approaching an absolute correlation of 0.8.
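The effect-size ranking metric above can be illustrated with a short sketch: treat the number of DEGs per perturbation as its effect size and compute the Spearman correlation between predicted and observed rankings. This is a simplified illustration, not the paper's evaluation code, and the toy inputs are hypothetical:

```python
import numpy as np
from scipy.stats import spearmanr

def effect_size_correlation(pred_degs: dict, true_degs: dict) -> float:
    """Spearman correlation between predicted and true effect sizes,
    where effect size = number of differentially expressed genes per perturbation."""
    perts = sorted(set(pred_degs) & set(true_degs))
    pred_counts = np.array([len(pred_degs[p]) for p in perts])
    true_counts = np.array([len(true_degs[p]) for p in perts])
    return spearmanr(pred_counts, true_counts).correlation

# Toy example with hypothetical DEG sets per perturbation:
pred = {"drugA": {"G1", "G2", "G3"}, "drugB": {"G1"}, "drugC": set()}
true = {"drugA": {"G1", "G2"}, "drugB": {"G2"}, "drugC": set()}
print(effect_size_correlation(pred, true))
```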
Leveraging Data Scale
A notable finding is that STATE provides the biggest benefit as the data scale increases. While performing well on genetic perturbations (where effect sizes are weaker and data is less abundant), its performance is significantly better on signaling perturbations and drug perturbations (which represent one and two more orders of magnitude of data, respectively). This suggests that previous perturbation models are in a data-poor regime [49, 50], and STATE's architecture is better designed to leverage vast amounts of data effectively.
Shared Cell Embedding (SE) for Cross-Dataset Transfer (Figure 2D and 2E)
- Information Content in SE (Figure 2D):
SE's embeddings were evaluated for their intrinsic and extrinsic performance. Intrinsic evaluation measured how well a classifier could predict perturbation labels from the embeddings (a toy version of this probe is sketched after this list); extrinsic evaluation measured ST's performance when using these embeddings. STATE embeddings significantly improved over other models in both aspects, outperforming even dataset-specific representations built on highly variable genes. This indicates that SE effectively captures perturbation-specific information while being robust.
- Zero-Shot Transfer Learning (Figure 2E): STATE models built on SE embeddings were pretrained on Tahoe-100M and fine-tuned on smaller datasets for zero-shot context generalization. Across all tested datasets, the SE-based model enabled better transfer than ST trained directly on highly variable genes and consistently outperformed other baseline methods. This demonstrates the critical role of SE in enabling STATE to generalize to novel cellular contexts where no perturbation data was seen during fine-tuning, paving the way for virtual cell models.
6.2. Data Presentation (Tables)
The following are the results from Table 1 of the original paper, summarizing datasets used in ST experiments:
| Dataset | # of Cells | # of Perturbations | # of Contexts |
|---|---|---|---|
| Replogle-Nadig | 624,158 | 1,677 | 4 |
| Jiang | 234,845 | 24 | 30 |
| Srivatsan | 762,795 | 189 | 3 |
| Mcfaline-Figueroa | 354,758 | 122 | 3 |
| Tahoe-100M | 100,648,790 | 1138 | 50 |
| Parse-PBMC | 9,697,974 | 90 | 12 (donors) / 18 (cell types) |
The following are the results from Table 2 of the original paper, summarizing datasets used in SE experiments:
| Dataset | # Training Cells | # Validation Cells |
|---|---|---|
| Arc scBaseCount [32] | 71,676,369 | 4,137,674 |
| CZ CellXGene [31] | 59,233,790 | 6,500,519 |
| Tahoe-100M [30] | 36,157,383 | 2,780,587 |
Tables 3 and 4, detailing key model hyperparameters by dataset and architectural details for ST components, are shown in Section 5.4 above.
6.3. Ablation Studies / Parameter Analysis
The paper specifically highlights an ablation study on Cell Set Scaling (Figure 1C from the original paper, as described in the image alt text).
Figure 1: STATE is a multi-scale machine learning architecture that operates across genes, individual cells, and cell populations. (A) The State Transition model (ST) learns perturbation effects by training on sets of perturbed cells. (B) Validation loss on the Tahoe-100M dataset when covariate-matched groups are chunked into sets of 256 cells: the full ST model significantly outperforms a pseudobulk model (STATE with mean-pooling instead of self-attention) and a single-cell variant (STATE with set size 1), and removing the self-attention mechanism (STATE w/o self-attention) substantially degrades performance. (C) The State Embedding model (SE) is an encoder-decoder model trained on large observational single-cell transcriptomics data: the transformer encoder constructs a dense embedding of the cell, and the MLP decoder reconstructs gene expression from the embedding. ST can operate directly on gene expression profiles or on cell representations from SE; when trained with SE, a separate MLP decodes from predicted embeddings to perturbed gene expression.
This study investigated the impact of the cell set size in STATE.
- The authors varied the number of cells per set and the batch size such that their product (batch size × cell set size) remained constant at 16,384, and measured the validation loss as a function of floating-point operations (FLOPs); a small illustration of this constant-compute chunking follows this list.
- Key Findings:
  - As the cell set size increased, the validation loss significantly improved on held-out cell lines.
  - An optimal cell set size of 256 was observed, beyond which returns diminished. This supports the ST module's design choice of modeling cell populations rather than individual cells.
  - Comparison to Baselines: The full ST model significantly outperformed a pseudobulk model (which replaced self-attention with mean-pooling). This indicates the advantage of self-attention in capturing intra-set heterogeneity over simple averaging.
  - Importance of Self-Attention: A single-cell variant (STATE with a set size of 1) performed poorly. Crucially, removing the self-attention mechanism (STATE w/o self-attention) substantially degraded performance. This highlights the indispensable role of the self-attention mechanism in enabling ST to effectively model perturbation-induced transformations and account for cellular heterogeneity.

Another implicit ablation is the comparison, in transfer learning (Figure 2E), between ST trained on SE's cell embeddings and ST trained on highly variable genes. This comparison demonstrates that training ST using SE's rich cell embeddings significantly improves cross-dataset transferability compared to using highly variable genes directly, and validates the design choice of SE as a crucial component for STATE's generalization capabilities.
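A small illustration of the constant-compute sweep described above: the product of batch size and cell set size is held at 16,384 while the set size varies, and covariate-matched cells are chunked into fixed-size sets. The specific values and helper names below are illustrative:

```python
import numpy as np

# Constant-compute sweep: batch_size * cell_set_size is held at 16,384 cells per step.
TOTAL_CELLS_PER_STEP = 16_384
for set_size in [1, 32, 64, 128, 256, 512]:
    batch_size = TOTAL_CELLS_PER_STEP // set_size
    print(f"cell_set_size={set_size:4d}  batch_size={batch_size:5d}")

def chunk_into_sets(cell_indices: np.ndarray, set_size: int, seed: int = 0) -> list:
    """Randomly chunk a covariate-matched group of cells into fixed-size sets."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(cell_indices)
    n_full = len(shuffled) // set_size
    return [shuffled[i * set_size:(i + 1) * set_size] for i in range(n_full)]
```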
7. Conclusion & Reflections
7.1. Conclusion Summary
The paper introduces STATE, a novel multi-scale, transformer-based model that significantly advances the prediction of cellular responses to perturbations across diverse contexts. STATE's architecture comprises two key modules: the State Transition (ST) model, which uses self-attention to learn perturbation effects on sets of cells, thereby capturing cellular heterogeneity; and the State Embedding (SE) model, a pre-trained encoder-decoder that generates robust cell embeddings from vast observational data.
Key findings include STATE's superior performance in perturbation discrimination and differential gene expression identification on large-scale datasets, with improvements exceeding 30% over existing models. The model effectively leverages giga-cell datasets (over 100 million perturbed cells and 167 million observational cells), showing that its performance scales with data availability. Furthermore, STATE demonstrates strong context generalization and zero-shot prediction capabilities in novel cellular environments, which were unseen during training. This is largely attributed to the high-quality SE embeddings and ST's ability to model distributional shifts via MMD loss. The theoretical analysis also establishes a connection between ST's solution family and optimal transport maps, providing a principled understanding of its capacity to learn cellular transformations.
Overall, STATE sets a new benchmark for perturbation response prediction, laying the groundwork for developing scalable virtual cell models for various biological and therapeutic applications.
7.2. Limitations & Future Work
The authors implicitly or explicitly mention several limitations and avenues for future work:
- Theoretical Guarantees for Optimal Transport: While Theorem 2 shows that the optimal transport map lies within ST's solution family, it also notes that this family contains infinitely many other elements. This means ST is not guaranteed to learn the unique optimal transport solution without further constraints.
- Constraining the Jacobian: Theorem 3 suggests that imposing explicit constraints on the Jacobian of ST's mapping (requiring it to be symmetric and positive semi-definite) would ensure it learns the unique OT map. However, implementing these constraints in a deep learning architecture, especially for transformers, is a non-trivial challenge that might require novel architectural designs, similar to Input Convex Neural Networks (ICNNs) [9, 84].
- Implicit Bias Characterization: The authors hypothesize that the implicit bias [85-87] of gradient descent during ST's optimization might lead it to discover solutions that resemble optimal transport mappings by favoring "minimal shifts." Characterizing this implicit bias in transformer-based transition models is explicitly stated as a promising area for future research.
- Computational Cost: Training STATE requires immense computational resources, as evidenced by its use of Flash Attention 2, mixed precision training, and distributed training across multiple NVIDIA H100 GPUs on 100M+ cells. This might limit its accessibility for researchers without high-performance computing infrastructure.
- Interpretability: While transformers are powerful, their complex, black-box nature can make it challenging to fully interpret why specific predictions are made or how heterogeneity is being modeled, which could be important for biological discovery.
7.3. Personal Insights & Critique
This paper presents a highly innovative and impactful approach to a critical problem in systems biology and drug discovery. My personal insights include:
- Holistic Multi-Scale Design: The most compelling aspect of STATE is its holistic, multi-scale architecture. By integrating protein language models for gene embeddings, a massive observational pre-training for robust cell embeddings (SE), and a self-attention-based transformer (ST) for population-level perturbation modeling, it addresses the problem from multiple angles simultaneously. This layered approach is a significant improvement over models that focus solely on gene expression or single-cell representations.
- Leveraging Data Scale Effectively: The explicit demonstration that STATE's performance significantly improves with increasing data scale is a powerful validation. This positions STATE as a true foundation model for cell biology, capable of extracting deep insights from the ever-growing single-cell omics datasets. The finding that other models are "data-poor" provides crucial context for the field.
- Distributional Approach with MMD: The use of MMD as a loss function to compare distributions of cell states is elegant and biologically sound, especially given the destructive nature of scRNA-seq. It allows the model to learn population-level changes and heterogeneity rather than just point-wise predictions, which is more reflective of real biological phenomena.
- Theoretical Grounding: The theoretical connection to Optimal Transport is a strong addition, providing a mathematical framework to understand STATE's capacity. While the current model doesn't explicitly enforce OT constraints, the possibility to do so in future work opens exciting avenues for more principled cellular state transformations.
- Potential for Diverse Applications: The model's ability to generalize across genetic, signaling, and chemical perturbations and to novel contexts makes it incredibly versatile. It could directly impact drug repurposing, target identification, and even guide autonomous experimental design systems by providing accurate in-silico predictions.
Potential Issues or Areas for Improvement:
- Computational Accessibility: While its scale is a strength, it also presents a barrier. Future work might explore methods for distillation or more efficient architectures to make such powerful models accessible to a wider range of researchers without giga-scale computational resources.
- Interpretability of Heterogeneity: While STATE models heterogeneity, it would be valuable to have tools to interpret what kind of heterogeneity is being captured by the self-attention mechanism and SE embeddings. This could lead to novel biological discoveries about cell state transitions.
- Beyond Gene Expression: The current model focuses on gene expression. Integrating other omics modalities (e.g., epigenomics, proteomics) within STATE's multi-scale framework could further enhance its predictive power and provide a more complete virtual cell model.
- Dynamic Modeling: STATE models a static transition from control to perturbed states. Future extensions could explore dynamic modeling of cellular processes over time, which would be crucial for understanding disease progression or drug response kinetics.

In conclusion, STATE is a landmark paper that pushes the boundaries of single-cell perturbation prediction, demonstrating the power of large-scale transformer models and principled distributional learning to unlock deeper biological understanding and accelerate biomedical discovery.