
Predicting cellular responses to perturbation across diverse contexts with STATE

Published: 06/27/2025
This analysis is AI-generated and may not be fully accurate. Please refer to the original paper.

TL;DR Summary

STATE, a transformer model trained on over 100 million cells, predicts perturbation responses across diverse cellular contexts with 30% improved effect discrimination and more accurate identification of differentially expressed genes, enabling generalization to unseen backgrounds and advancing scalable virtual cell models.

Abstract

Cellular responses to perturbations are a cornerstone for understanding biological mechanisms and selecting drug targets. While machine learning models offer tremendous potential for predicting perturbation effects, they currently struggle to generalize to unobserved cellular contexts. Here, we introduce State, a transformer model that predicts perturbation effects while accounting for cellular heterogeneity within and across experiments. State predicts perturbation effects across sets of cells and is trained using gene expression data from over 100 million perturbed cells. State improved discrimination of effects on large datasets by more than 30% and identified differentially expressed genes across genetic, signaling and chemical perturbations with significantly improved accuracy. Using its cell embedding trained on observational data from 167 million cells, State identified strong perturbations in novel cellular contexts where no perturbations were observed during training. Overall, the perfo…


In-depth Reading

English Analysis

1. Bibliographic Information

1.1. Title

Predicting cellular responses to perturbation across diverse contexts with STATE

1.2. Authors

The authors are listed as "Anonymous Author(s)" in the provided paper draft, which is common for submissions to conferences with a double-blind review process.

1.3. Journal/Conference

The paper is "Submitted to 39th Conference on Neural Information Processing Systems (NeurIPS 2025)." NeurIPS is one of the most prestigious and influential conferences in the field of artificial intelligence and machine learning, indicating a high-profile venue for research dissemination.

1.4. Publication Year

The paper was submitted to NeurIPS 2025, suggesting a target publication year of 2025. However, it is currently a preprint as indicated by its submission status and the anonymous authors.

1.5. Abstract

Cellular responses to perturbations are critical for understanding biological mechanisms and identifying drug targets. Existing machine learning models for predicting these responses struggle to generalize across diverse cellular contexts. This paper introduces STATE, a transformer-based model designed to predict perturbation effects while explicitly accounting for cellular heterogeneity both within and across experiments. STATE operates by predicting perturbation effects across sets of cells and was trained using gene expression data from over 100 million perturbed cells. The model demonstrated significant improvements, increasing effect discrimination by more than 30% on large datasets and achieving substantially higher accuracy in identifying differentially expressed genes across genetic, signaling, and chemical perturbations. Leveraging cell embeddings trained on a vast observational dataset of 167 million cells, STATE successfully predicts strong perturbations in novel cellular contexts that were not observed during its training. The authors suggest that STATE's performance and flexibility pave the way for the development of scalable virtual cell models for various biological applications.

The original source link provided is /files/papers/690878ad1ccaadf40a4344dd/paper.pdf. Given the context of "Submitted to 39th Conference on Neural Information Processing Systems (NeurIPS 2025)" and anonymous authors, this link points to a preprint version of the paper.

2. Executive Summary

2.1. Background & Motivation

The core problem STATE aims to solve is the accurate prediction of cellular responses to various perturbations (e.g., genetic, chemical, signaling) across diverse and often unseen cellular contexts. Understanding these responses is fundamental for basic biological research, disease modeling, and crucial for rational drug target selection in pharmaceutical development.

Despite the rapid growth in the size and scope of perturbation datasets, current machine learning models face significant challenges:

  1. Generalization to Unobserved Contexts: Many models struggle to accurately predict how cells will respond in cellular environments or with perturbations not encountered during training.

  2. Destructive Nature of Single-Cell Assays: Single-cell RNA sequencing (scRNA-seq) measurements destroy the cell, preventing observation of its pre-perturbation state. This makes it difficult to infer individual cell responses directly.

  3. Cellular Heterogeneity: Cells within a population, even under identical conditions, exhibit considerable biological variation. Existing models often oversimplify this within-population heterogeneity or fail to distinguish it from true perturbation signals, especially when perturbation effects are subtle.

  4. Data Scale Limitations: Although datasets are growing, many existing models appear to remain in a data-poor regime: their performance improves little with additional data, which limits how well the perturbation effects they learn generalize.

    The paper's entry point is to explicitly account for cellular heterogeneity and cross-dataset variability using a novel transformer-based architecture and a vast amount of single-cell data. It proposes to move beyond single-cell prediction towards modeling distributions of cell responses.

2.2. Main Contributions / Findings

The paper makes several significant contributions:

  1. Introduction of STATE: A novel multi-scale, transformer-based machine learning architecture designed for predicting cellular responses to perturbations. STATE consists of two complementary modules:

    • State Transition model (ST): A transformer that models perturbation-induced transformations across sets of cells, explicitly capturing biological heterogeneity and avoiding reliance on explicit distributional assumptions.
    • State Embedding model (SE): A pre-trained encoder-decoder transformer that generates robust cell embeddings by learning gene expression variation across diverse observational datasets, optimized for detecting perturbation effects and reducing technical variation.
  2. Scalable Training with Large Datasets: STATE leverages an unprecedented scale of data, utilizing 167 million cells of observational data to train SE and over 100 million perturbed cells to train ST.

  3. Superior Performance and Generalization:

    • STATE achieved an absolute improvement of 54% and 29% in perturbation discrimination on Tahoe-100M and Parse-PBMC datasets, respectively.
    • It significantly improved differential gene expression (DEG) identification, being twice as good as the next best baseline on Tahoe-100M and 43% better on Parse-PBMC.
    • The model accurately ranked perturbations by their relative effect sizes, showing Spearman correlations 53% higher on Parse-PBMC, 22% higher on ReplogleNadig, and 70% higher on Tahoe-100M compared to baselines.
    • STATE demonstrated strong context generalization and zero-shot prediction capabilities, successfully predicting strong perturbations in novel cellular contexts not seen during training.
  4. Leveraging Data Scale: The performance gains of STATE were most pronounced with increasing data scale, suggesting its architecture is better equipped to utilize large datasets compared to existing models.

  5. Theoretical Foundations for Optimal Transport: The paper provides theoretical analysis suggesting that ST's solution family covers the optimal transport map between control and perturbed cell distributions asymptotically, hinting at its capacity to learn fundamental cellular transformations.

    These findings position STATE as a crucial step towards developing scalable virtual cell models, which could accelerate causal target discovery and deepen the understanding of cellular function and disease.

3. Prerequisite Knowledge & Related Work

3.1. Foundational Concepts

To understand the STATE model and its significance, a foundational understanding of several biological and machine learning concepts is essential.

  • Cellular Perturbations: In biology, a perturbation refers to any intentional modification or intervention applied to a biological system (like a cell) to observe its response. This can include:

    • Genetic perturbations: Manipulating gene expression (e.g., using CRISPR to knock out or activate specific genes).
    • Chemical perturbations: Treating cells with drugs or small molecules.
    • Signaling perturbations: Modulating cellular signaling pathways (e.g., with cytokines). Studying these responses is crucial for understanding biological mechanisms and identifying drug targets.
  • Gene Expression Data / Single-Cell RNA Sequencing (scRNA-seq):

    • Gene expression is the process by which information from a gene is used in the synthesis of a functional gene product, such as a protein. Cells vary widely in which genes they express and at what levels.
    • scRNA-seq is a high-throughput technology that measures the gene expression profile of individual cells. Each cell is represented by a vector of gene counts or expression levels (e.g., log-normalized counts, as used in the paper; log-normalization typically scales each cell's counts to a common total and then applies $\log(1+x)$, which stabilizes variance and reduces skew).
    • A key challenge with scRNA-seq is its destructive nature: measuring a cell's transcriptome destroys it, meaning we cannot observe the same cell's state before and after a perturbation.
  • Machine Learning Models for Perturbation Prediction: The goal is to develop models that can predict how the gene expression profile of a cell (or a population of cells) will change in response to a given perturbation. This involves learning complex relationships between basal cell states, perturbation types, and resulting transcriptomic changes.

  • Transformers and Self-Attention:

    • Transformers are a neural network architecture, originally developed for natural language processing, that have become highly influential due to their self-attention mechanism.
    • Self-attention (or intra-attention) allows the model to weigh the importance of different parts of an input sequence (or set) relative to each other when processing each part. For a set of cell representations, self-attention means that when processing one cell's data, the model can consider and weigh the information from all other cells in the set. This is particularly powerful for capturing relationships and heterogeneity within a set of cells. The general Attention formula is: $ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $ where:
      • $Q$ (Query), $K$ (Key), $V$ (Value) are matrices derived from the input embeddings via learned linear projections.
      • $QK^T$ computes the attention scores (how strongly each element relates to every other element).
      • $\sqrt{d_k}$ is a scaling factor, where $d_k$ is the dimension of the keys; dividing by it keeps the dot products from growing with $d_k$ and saturating the softmax.
      • softmax normalizes these scores into a probability distribution.
      • The result is a weighted sum of the Value vectors, allowing the model to focus on relevant information.
  • Cell Embeddings: An embedding is a lower-dimensional, dense vector representation of an entity (in this case, a cell) that captures its essential features. Cell embeddings aim to represent a cell's transcriptomic state (its unique gene expression profile) in a way that is robust to technical noise and highlights biological similarities or differences. These embeddings can then be used as input for downstream tasks or as a more abstract space to model cellular processes.

  • Differential Expression Analysis (DEG): This statistical analysis identifies genes whose expression levels change significantly between two or more biological conditions (e.g., perturbed vs. unperturbed cells). Differentially Expressed Genes are crucial indicators of a cell's response to a perturbation.

  • Maximum Mean Discrepancy (MMD): MMD is a distance between probability distributions, widely used as the statistic in two-sample tests of whether two samples are drawn from the same distribution. It measures the distance between the mean embeddings of the two distributions in a Reproducing Kernel Hilbert Space (RKHS). Intuitively, if the average feature representation of samples from distribution A is very different from that of samples from distribution B, the distributions are likely different. The paper uses the energy distance kernel, $k(\mathbf{u}, \mathbf{v}) = -\|\mathbf{u} - \mathbf{v}\|_2$, which makes MMD equivalent to the energy distance. A key property of the energy distance is that it is zero if and only if the two distributions are identical (for distributions with finite first moments).

  • Optimal Transport (OT): Optimal Transport is a mathematical framework for comparing and transforming probability distributions. It seeks the most efficient way to "move" mass from one distribution to another, minimizing a given "cost" (e.g., Euclidean distance between points). In the context of cells, Optimal Transport could theoretically map a distribution of unperturbed cell states to a distribution of perturbed cell states.
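To make the self-attention formula from the Transformers entry concrete, here is a minimal single-head NumPy sketch over a toy set of cell representations. The sizes and random projection matrices are hypothetical stand-ins, and ST's actual backbone is a full multi-layer transformer rather than this one layer. The sketch only shows that attention mixes information across all cells in a set, and that permuting the cells permutes the output in the same way.

```python
import numpy as np

def softmax(scores: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def self_attention(H, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over a set of S cells.

    H: (S, d) cell representations; W_q, W_k, W_v: (d, d_k) projections.
    Returns (S, d_k): each row is a weighted mix of every cell's value vector.
    """
    Q, K, V = H @ W_q, H @ W_k, H @ W_v
    d_k = K.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (S, S), rows sum to 1
    return weights @ V

rng = np.random.default_rng(0)
S, d, d_k = 5, 8, 4            # hypothetical toy sizes
H = rng.normal(size=(S, d))    # stand-in for 5 cell embeddings
W_q, W_k, W_v = (rng.normal(size=(d, d_k)) for _ in range(3))
out = self_attention(H, W_q, W_k, W_v)

# Reordering the cells reorders the output rows identically (permutation equivariance).
perm = rng.permutation(S)
out_perm = self_attention(H[perm], W_q, W_k, W_v)
print(np.allclose(out_perm, out[perm]))  # True
```

Permutation equivariance is what makes attention a natural fit for unordered sets of cells: no cell has a privileged position.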

3.2. Previous Works

The paper discusses several categories of prior approaches and specific models for perturbation prediction:

  • Mapping Perturbed to Unperturbed Cells: Some earlier methods (e.g., [10-12]) assumed within-population heterogeneity was negligible and tried to map perturbed cells to randomly selected unperturbed cells with shared covariates (like cell type).

    • Limitation: These models often fail when perturbation effects are subtle, as unperturbed population heterogeneity can be larger than the perturbation signal itself, making accurate mapping difficult.
  • Distribution-Based Models: Other models treat cell populations as distributions, aiming to learn data-generating distributions or disentangle labeled and unlabeled sources of variation (e.g., [8, 19-26]).

    • Limitation: In practice, the paper notes that these models often do not significantly outperform methods that don't explicitly model distributional structure [14]. Specific examples mentioned:
      • scVI [25]: A deep generative model that aims to model gene expression distributions while accounting for technical noise and batch effects.
      • CPA (Compositional Perturbation Autoencoder) [8, 47]: An autoencoder-based model that learns a compositional latent space to capture additive effects of perturbation, dosage, and cell type.
  • Optimal Transport-based Methods: These methods attempt to map unperturbed populations to perturbed populations (e.g., [9, 27-29]).

    • Limitation: Their applicability has been limited by strong assumptions and poor scalability. For instance, CeT [9] uses Input Convex Neural Networks (ICNNs) to parameterize convex potentials for learning optimal transport maps.
  • Foundation Models for Single-Cell Omics: Recent advancements include large-scale transformer-based models trained on vast amounts of single-cell data, aiming for broad generalization:

    • scGPT [11]: A transformer-based foundation model that leverages generative pretraining on over 33 million cells for tasks like perturbation prediction.
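    • GEARS [10]: A model for predicting transcriptional outcomes of novel multi-gene perturbations. Unlike scGPT, it is not a transformer foundation model but a graph neural network that draws on a gene-gene knowledge graph.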

3.3. Technological Evolution

The field of perturbation response prediction has evolved alongside advancements in single-cell technologies and deep learning.

  1. Early Statistical Models: Initially, methods relied on simpler statistical models or linear regressions to identify differentially expressed genes.

  2. Increased Data Scale: The advent of high-throughput screening technologies like pooled CRISPR perturbations (e.g., Perturb-seq [1-6]) has dramatically increased the scale and complexity of available perturbation datasets.

  3. Autoencoder-based Models: As deep learning gained traction, autoencoders (like CPA and scVI) were adapted to learn latent representations of cell states, aiming to capture underlying biological variations and predict perturbation effects in a lower-dimensional space.

  4. Foundation Models and Transformers: The most recent wave involves transformer architectures and foundation models such as scGPT, alongside knowledge-graph-based models such as GEARS. Foundation models are trained on massive datasets (millions to hundreds of millions of cells) to learn general-purpose cell embeddings and biological rules, with the aim of zero-shot generalization to unseen contexts.

    This paper's work, STATE, represents the cutting edge of this evolution by:

  • Utilizing transformer architecture for its self-attention capabilities to model cellular heterogeneity at the population level.
  • Integrating pre-trained cell embeddings (SE) from an even larger observational dataset (167 million cells) to enhance robustness and transferability.
  • Employing Maximum Mean Discrepancy (MMD) to directly compare distributions of predicted and observed cell states, moving beyond single-cell reconstruction errors.
  • Providing a theoretical link to Optimal Transport, suggesting a principled way to learn cellular transformations.

3.4. Differentiation Analysis

STATE differentiates itself from previous approaches primarily through its multi-scale architecture, explicit handling of heterogeneity, and massive data leverage:

  • Explicitly Modeling Heterogeneity (vs. Neglecting or Averaging):

    • Unlike methods that assume within-population heterogeneity is negligible or simply map to average unperturbed cells (Perturbation Mean Baseline, Context Mean Baseline), ST uses self-attention over sets of cells. This allows it to explicitly model and learn residual, unannotated heterogeneity within cell populations, which is crucial for subtle perturbation effects.
    • Compared to pseudobulk models (which average responses), ST retains more information about cell-to-cell variability.
  • Robust Cell Embeddings for Transfer (vs. Context-Specific Representations):

    • Many models operate directly on gene expression profiles or learn embeddings that are specific to a dataset or context. STATE introduces SE, a pre-trained cell embedding model trained on 167 million observational cells, leveraging protein language models for robust gene representations.
    • These SE embeddings are designed to unify cell representation across diverse datasets, making STATE's predictions more robust to technical variation and significantly improving cross-dataset transfer learning and zero-shot generalization compared to models trained on highly variable genes (HVGs) or other foundation models like scGPT.
  • Distributional Learning with MMD (vs. Pointwise Prediction or Autoencoder Reconstruction):

    • Instead of predicting individual cell responses or relying solely on autoencoder reconstruction losses (CPA, scVI), ST is trained using Maximum Mean Discrepancy (MMD). MMD directly minimizes the distance between the distributions of predicted and observed cell sets. This distributional perspective is more biologically relevant for population-level responses and handles the destructive nature of scRNA-seq more naturally.
  • Scalability and Data Leverage:

    • STATE is explicitly designed to leverage very large datasets (100M perturbed and 167M observational cells). The paper shows that STATE's performance improves markedly as data scale grows, whereas the compared models appear to remain in a data-poor regime. This suggests STATE's architecture is better able to extract signal from massive biological datasets.

      In essence, STATE's innovations lie in its integrated approach to handling complex cellular heterogeneity, generating broadly transferable cell embeddings, learning distributional transformations, and scaling to datasets of hundreds of millions of cells.

4. Methodology

4.1. Principles

The core idea behind STATE is to develop a multi-scale machine learning architecture that can predict how cells respond to perturbations. It achieves this by explicitly accounting for cellular heterogeneity—the natural variations among cells—both within a single experiment and across different experiments. The model is built on two complementary modules:

  1. State Transition model (ST): This module is responsible for learning the actual perturbation effects. Instead of trying to predict the response of a single cell (which is problematic given that scRNA-seq assays destroy cells), ST operates on sets of cells. It uses a transformer with self-attention to model how a distribution of control cells transforms into a distribution of perturbed cells. This set-based approach allows it to capture and predict changes in cellular heterogeneity beyond what can be explained by simple covariates like cell type.

  2. State Embedding model (SE): This module is designed to create robust and generalizable cell embeddings. SE is pre-trained on a massive amount of observational single-cell data (data from cells that were not intentionally perturbed) to learn representations that are less sensitive to technical noise and better capture underlying biological signals. These cell embeddings then serve as high-quality inputs for the ST module, enabling STATE to transfer knowledge across different datasets and experimental conditions more effectively.

    The multi-scale aspect refers to STATE's ability to operate on gene-level features, individual cell representations (via SE embeddings), and cell population transformations (via ST on sets of cells). By combining these, STATE aims to create a virtual cell model that is flexible, scalable, and capable of generalizing to novel contexts.

4.2. Core Methodology In-depth (Layer by Layer)

The STATE model is composed of the State Transition (ST) model and the State Embedding (SE) model. We will first describe the conceptual Data Generation Process and then detail each module.

4.2.1. Data Generation Process

The paper begins by framing the problem with a conceptual data generation process. The observed log-normalized perturbed expression state of a cell, denoted $X_p$, is treated as a random variable generated from an unobservable unperturbed state $X_0$. $X_0$ is itself a random variable representing the cell's underlying expression state, drawn from a basal cell distribution $\mathcal{D}_{\mathrm{basal}}$ specific to a given set of covariates (e.g., cell line, batch condition).

The perturbation effect is modeled as:
$$ X_p = X_0 + T_p(X_0) + \varepsilon, \quad X_0 \sim \mathcal{D}_{\mathrm{basal}} $$
where:

  • $T_p(X_0)$: The true effect of perturbation $p$ on the unperturbed cell state $X_0$.

  • $\varepsilon$: Experiment-specific technical noise, assumed to be independent of $X_0$.

    Since single-cell transcriptomic measurements destroy the cell, $X_0$ is never observed for a cell that is measured after perturbation, so this equation cannot be fit directly. STATE therefore shifts to a distributional view: it operates on the basal cell distribution $\mathcal{D}_{\mathrm{basal}}$, which is observable through unperturbed control cells, to predict the perturbed state, denoted $\hat{X}_p$.
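The generative story can be simulated in a few lines to see why cell-level pairing is unavailable while population-level comparison still works. This toy NumPy sketch assumes a Gaussian basal distribution, a hypothetical effect $T_p$ that lowers gene 0 by 80% of its basal level, and Gaussian noise $\varepsilon$; none of these choices come from the paper.

```python
import numpy as np

rng = np.random.default_rng(42)
n_cells, n_genes = 2000, 3  # hypothetical toy sizes

def T_p(x: np.ndarray) -> np.ndarray:
    """Hypothetical perturbation effect: reduce gene 0 by 80% of its basal level."""
    effect = np.zeros_like(x)
    effect[:, 0] = -0.8 * x[:, 0]
    return effect

# X0 ~ D_basal (a Gaussian is assumed purely for illustration).
X0 = rng.normal(loc=1.0, scale=0.5, size=(n_cells, n_genes))
eps = rng.normal(scale=0.1, size=X0.shape)  # technical noise, independent of X0
Xp = X0 + T_p(X0) + eps                     # measured perturbed cells

# The assay destroys each cell, so the pairs (X0[i], Xp[i]) are never observed.
# What an experiment does provide is a separate sample of control cells:
controls = rng.normal(loc=1.0, scale=0.5, size=(n_cells, n_genes))

# Population-level summaries are still comparable, e.g. the per-gene mean shift.
shift = Xp.mean(axis=0) - controls.mean(axis=0)
print(np.round(shift, 2))  # gene 0 shifts by roughly -0.8; genes 1 and 2 stay near 0
```

Shuffling the rows of `Xp` leaves `shift` unchanged, so this quantity depends only on the two populations and not on any cell-level pairing.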

This forms the basis of STATE's model:
$$ \hat{X}_p \sim H(\mathcal{D}_{\mathrm{basal}}) + \hat{T}_p(\mathcal{D}_{\mathrm{basal}}) + \varepsilon, \qquad \varepsilon \sim P_{\varepsilon}, \qquad X_p \stackrel{d}{\approx} \hat{X}_p $$
In this approximation:

  • $\hat{T}_p(\mathcal{D}_{\mathrm{basal}})$: The population-level counterpart of the true effect $T_p(X_0)$, defined on the entire basal population rather than on a single cell.
  • $H(\mathcal{D}_{\mathrm{basal}})$: Explicitly represents the biological heterogeneity inherent in the basal population. In the first equation this heterogeneity was only implicit in the sampling step $X_0 \sim \mathcal{D}_{\mathrm{basal}}$; here it is made explicit to reflect the distributional view.
  • $\hat{X}_p$: A distributional analogue of $X_p$. The relation $X_p \stackrel{d}{\approx} \hat{X}_p$ means the two are approximately equal in distribution, so the model predicts the perturbed state from observable population characteristics rather than from unobservable individual cell states.

4.2.2. State Transition Model (ST)

ST is the core model for learning perturbation effects across populations of cells. It uses a transformer architecture with self-attention across cells to predict perturbation effects between control and perturbed cell distributions.

A dataset $\mathcal{D}$ of single-cell RNA-sequencing measurements is defined as:
$$ \mathcal{D} = \big\{ \big( \mathbf{x}^{(i)}, p_i, \ell_i, b_i \big) \big\}_{i=1}^N, \qquad \mathbf{x}^{(i)} \in \mathbb{R}^G. $$
Here:

  • $\mathbf{x}^{(i)}$: The log-normalized expression vector for cell $i$, where $G$ is the number of genes.
  • $p_i$: Perturbation label for cell $i$, either a specific perturbation ID or the control state ($\mathsf{ctrl}$).
  • $\ell_i$: Biological context or cell line label for cell $i$.
  • $b_i$: Optional batch label for cell $i$.

4.2.2.1. Formation of Cell Sets

Cells are grouped by biological context ($\ell$), perturbation ($p$), and batch label ($b$):
$$ \mathcal{C}_{\ell,p,b} = \big\{ \mathbf{x}^{(i)} \in \mathcal{D} \ \big| \ \ell_i = \ell, \ p_i = p, \ b_i = b \big\}. $$
From these groups, fixed-size cell sets of size $S$ are formed. If a group $\mathcal{C}_{\ell,p,b}$ has $N_{\ell,p,b}$ cells, it is chunked into cell sets $S_{\ell,p,b}^{(k)} \in \mathbb{R}^{S \times G}$, where $k$ indexes the sets. If $N_{\ell,p,b}$ is not divisible by $S$, the remaining cells form a smaller set, which is padded to size $S$ by sampling additional cells with replacement from that smaller set.
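A minimal sketch of the grouping and chunking step follows, assuming cells arrive as a NumPy expression matrix with parallel label arrays. The helper names `group_cells` and `make_cell_sets` are hypothetical, and cells are chunked in the order given, since the description above does not say whether groups are shuffled first.

```python
import numpy as np

def group_cells(X, cell_line, pert, batch):
    """Split an (N, G) expression matrix into groups C_{l,p,b} keyed by (l, p, b)."""
    groups = {}
    for key in set(zip(cell_line, pert, batch)):
        mask = (cell_line == key[0]) & (pert == key[1]) & (batch == key[2])
        groups[key] = X[mask]
    return groups

def make_cell_sets(group, S, rng):
    """Chunk one group of shape (N, G) into sets of shape (S, G).

    A final partial chunk is padded to S rows by resampling, with replacement,
    from the cells of that partial chunk itself.
    """
    sets = []
    for start in range(0, group.shape[0], S):
        chunk = group[start:start + S]
        if chunk.shape[0] < S:
            extra = rng.choice(chunk.shape[0], size=S - chunk.shape[0], replace=True)
            chunk = np.concatenate([chunk, chunk[extra]], axis=0)
        sets.append(chunk)
    return sets

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 4))                 # 10 cells, G = 4 genes (toy data)
cell_line = np.array(["A"] * 10)
pert = np.array(["ctrl"] * 4 + ["KO1"] * 6)  # hypothetical perturbation labels
batch = np.array(["b1"] * 10)

groups = group_cells(X, cell_line, pert, batch)
sets = make_cell_sets(groups[("A", "KO1", "b1")], S=4, rng=rng)
print([s.shape for s in sets])  # [(4, 4), (4, 4)]
```

Here the `KO1` group has 6 cells, so it yields one full set of 4 and one partial set of 2 that is padded with 2 resampled copies of its own cells.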

4.2.2.2. Training on Cell Sets

During training, each perturbed cell set $S_{\ell,p,b}^{(k)}$ (where $p \neq \mathsf{ctrl}$) is paired with a corresponding control cell set. The control set is generated by a mapping function $\mathtt{map}(\cdot)$ that samples $S$ control cells from the same cell line $\ell$ and batch $b$:
$$ \mathtt{map}(S_{\ell,p,b}^{(k)}) = \mathtt{stack}\big([\mathbf{x}^{(i)}]_{\mathbf{x}^{(i)} \sim \mathcal{C}_{\ell,\mathrm{ctrl},b}}\big) \in \mathbb{R}^{S \times G} $$
This mapping function is crucial because it controls which sources of variation are explicitly accounted for: by conditioning on specific covariates (here, cell line and batch), it reduces known sources of heterogeneity.

For a mini-batch of $B$ such set pairs, the following tensors are constructed:
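The mapping $\mathtt{map}(\cdot)$ can be sketched as a lookup of the control group with the same cell line and batch, followed by random sampling. The function name `map_to_controls` and the toy groups are hypothetical, and the sampling scheme is an assumption: without replacement when the control pool has at least $S$ cells, with replacement otherwise.

```python
import numpy as np

def map_to_controls(groups, cell_line, batch, S, rng):
    """Draw S control cells from C_{l,ctrl,b}, matching the perturbed set's l and b."""
    pool = groups[(cell_line, "ctrl", batch)]
    # Assumption: sample without replacement unless the pool has fewer than S cells.
    idx = rng.choice(pool.shape[0], size=S, replace=pool.shape[0] < S)
    return pool[idx]

rng = np.random.default_rng(1)
groups = {  # toy control groups for one cell line in two batches
    ("A", "ctrl", "b1"): rng.normal(loc=0.0, size=(50, 3)),
    ("A", "ctrl", "b2"): rng.normal(loc=5.0, size=(50, 3)),
}
ctrl_set = map_to_controls(groups, "A", "b1", S=8, rng=rng)
print(ctrl_set.shape)  # (8, 3)
```

Conditioning the lookup on both cell line and batch keeps batch-specific shifts, like the offset between `b1` and `b2` in the toy data, out of the control-versus-perturbed comparison.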

  • $\mathbf{X}_{\mathrm{target}} = \mathtt{stack}([S_{\ell_1,p_1,b_1}^{(k_1)}, \ldots, S_{\ell_B,p_B,b_B}^{(k_B)}]) \in \mathbb{R}^{B \times S \times G}$ (observed perturbed cell sets).
  • $\mathbf{X}_{\mathrm{ctrl}} = \mathtt{stack}([\mathtt{map}(S_{\ell_1,p_1,b_1}^{(k_1)}), \ldots, \mathtt{map}(S_{\ell_B,p_B,b_B}^{(k_B)})]) \in \mathbb{R}^{B \times S \times G}$ (corresponding control cell sets).
  • $\mathbf{Z}_{\mathrm{pert}} \in \mathbb{R}^{B \times S \times D_{\mathrm{pert}}}$ (perturbation embeddings).
  • $\mathbf{Z}_{\mathrm{batch}} \in \mathbb{R}^{B \times S \times D_{\mathrm{batch}}}$ (optional batch covariates). Here, $D_{\mathrm{pert}}$ and $D_{\mathrm{batch}}$ are the dimensionalities of the perturbation and batch embeddings, respectively. For STATE, one-hot encodings are used, so $D_{\mathrm{pert}}$ is the number of unique perturbations and $D_{\mathrm{batch}}$ is the number of unique batch labels.
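Assembling the four mini-batch tensors is mostly stacking. The detail worth showing is that each set's one-hot perturbation and batch vectors are repeated for all $S$ cells, which gives $\mathbf{Z}_{\mathrm{pert}}$ and $\mathbf{Z}_{\mathrm{batch}}$ their $B \times S \times D$ shapes. This is a hypothetical NumPy sketch with integer-coded labels, not the authors' data loader.

```python
import numpy as np

def one_hot(index: int, depth: int) -> np.ndarray:
    vec = np.zeros(depth)
    vec[index] = 1.0
    return vec

def build_minibatch(pert_sets, ctrl_sets, pert_ids, batch_ids, n_perts, n_batches):
    """Stack B (perturbed, control) set pairs into the ST input/target tensors."""
    X_target = np.stack(pert_sets)  # (B, S, G) observed perturbed cells
    X_ctrl = np.stack(ctrl_sets)    # (B, S, G) matched control cells
    S = X_target.shape[1]
    # Every cell in a set shares that set's labels, so tile each one-hot row S times.
    Z_pert = np.stack([np.tile(one_hot(p, n_perts), (S, 1)) for p in pert_ids])
    Z_batch = np.stack([np.tile(one_hot(b, n_batches), (S, 1)) for b in batch_ids])
    return X_target, X_ctrl, Z_pert, Z_batch

rng = np.random.default_rng(0)
B, S, G = 3, 4, 5  # hypothetical toy sizes
pert_sets = [rng.normal(size=(S, G)) for _ in range(B)]
ctrl_sets = [rng.normal(size=(S, G)) for _ in range(B)]
X_target, X_ctrl, Z_pert, Z_batch = build_minibatch(
    pert_sets, ctrl_sets, pert_ids=[2, 0, 2], batch_ids=[1, 1, 0], n_perts=4, n_batches=2)
print(Z_pert.shape, Z_batch.shape)  # (3, 4, 4) (3, 4, 2)
```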

The ST model takes $\mathbf{X}_{\mathrm{ctrl}}$ (control cell sets) and $\mathbf{Z}_{\mathrm{pert}}$ (perturbation embeddings), plus the optional $\mathbf{Z}_{\mathrm{batch}}$, as input and learns to predict $\mathbf{X}_{\mathrm{target}}$ (the corresponding perturbed states). In other words, ST learns to transform control cell populations into their predicted perturbed states.

4.2.2.3. Neural Network Modules

ST uses specialized encoders to map cellular expression profiles, perturbation labels, and batch labels into a shared hidden dimension $d_h$, which serves as the input dimension of the transformer.

  • Control Cell Encoder: Each log-normalized expression vector $\mathbf{x}^{(i)} \in \mathbb{R}^G$ is mapped to an embedding via a 4-layer MLP (Multi-Layer Perceptron) with GELU (Gaussian Error Linear Unit) activations. This MLP, denoted $f_{\mathrm{cell}}$, is applied to each cell independently across the entire control tensor:
    $$ \mathbf{H}_{\mathrm{cell}} = f_{\mathrm{cell}}(\mathbf{X}_{\mathrm{ctrl}}) \in \mathbb{R}^{B \times S \times d_h} $$
    This transforms the input shape from $(B \times S \times G)$ to $(B \times S \times d_h)$.

  • Perturbation Encoder: Perturbation labels are encoded into the same embedding dimension $d_h$. For one-hot encoded perturbations, the input vector is passed through a 4-layer MLP with GELU activations:
    $$ \mathbf{H}_{\mathrm{pert}} = f_{\mathrm{pert}}(\mathbf{Z}_{\mathrm{pert}}) \in \mathbb{R}^{B \times S \times d_h} $$
    This transforms the input shape $(B \times S \times D_{\mathrm{pert}})$ to $(B \times S \times d_h)$. The perturbation embedding is identical for all cells within a given cell set. If perturbations are represented by continuous features (e.g., molecular descriptors), these features are used directly as $\mathbf{H}_{\mathrm{pert}}$, and $d_h$ is set to $D_{\mathrm{pert}}$.

  • Batch Encoder: To account for technical batch effects, batch labels $b_i \in \{1, \ldots, D_{\mathrm{batch}}\}$ (indexing the $D_{\mathrm{batch}}$ unique batch labels, not the mini-batch of size $B$) are encoded into embeddings of dimension $d_h$:
    $$ \mathbf{H}_{\mathrm{batch}} = f_{\mathrm{batch}}(\mathbf{Z}_{\mathrm{batch}}) \in \mathbb{R}^{B \times S \times d_h} $$
    Here, $f_{\mathrm{batch}}$ is an embedding layer, transforming the input shape from $(B \times S \times D_{\mathrm{batch}})$ to $(B \times S \times d_h)$.

  • Transformer Inputs and Outputs: The final input to ST is constructed by summing the control cell embeddings with the perturbation and batch embeddings:
    $$ \mathbf{H} = \mathbf{H}_{\mathrm{cell}} + \mathbf{H}_{\mathrm{pert}} + \mathbf{H}_{\mathrm{batch}} $$
    This composite representation is then passed to the transformer backbone $f_{\mathrm{ST}}$ to model perturbation effects across the cell set. The output is computed as:
    $$ \mathbf{O} = \mathbf{H} + f_{\mathrm{ST}}(\mathbf{H}) $$
    where $\mathbf{O} \in \mathbb{R}^{B \times S \times d_h}$ is the final output. This residual connection encourages the transformer $f_{\mathrm{ST}}$ to learn perturbation effects as residuals relative to the input representation $\mathbf{H}$.

  • Gene Reconstruction Head: When ST operates directly in gene expression space, a gene reconstruction head maps the transformer output $\mathbf{O}$ back to gene expression space using a linear projection applied independently to each cell's $d_h$-dimensional hidden representation. For mini-batch element $b$ (one cell set), with transformer output $\mathbf{O}^{(b)} \in \mathbb{R}^{S \times d_h}$, the predicted gene expression $\hat{\mathbf{X}}_{\mathrm{target}}^{(b)} \in \mathbb{R}^{S \times G}$ is:
    $$ \hat{\mathbf{X}}_{\mathrm{target}}^{(b)} = f_{\mathrm{recon}}(\mathbf{O}^{(b)}) = \mathbf{O}^{(b)}\mathbf{W}_{\mathrm{recon}} + \mathbf{b}_{\mathrm{recon}} $$
    Here, $\mathbf{W}_{\mathrm{recon}} \in \mathbb{R}^{d_h \times G}$ and $\mathbf{b}_{\mathrm{recon}} \in \mathbb{R}^G$ are learnable parameters. This yields reconstructed log-transformed gene expression values for each cell in the set.
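The encoders, residual transformer, and reconstruction head described above can be strung together in a NumPy forward pass to check the tensor shapes. This is a sketch under loud assumptions: weights are random rather than trained, GELU uses the common tanh approximation, no activation follows the last MLP layer, and the transformer backbone $f_{\mathrm{ST}}$ is replaced by a single attention layer applied within each cell set. All sizes and names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, G, d_h = 2, 6, 20, 16  # hypothetical toy sizes
D_pert, D_batch = 5, 3

def gelu(x):
    """GELU, tanh approximation."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def init_mlp(d_in, d_out, n_layers=4):
    """Random weights for an n_layers MLP with hidden width d_out."""
    dims = [d_in] + [d_out] * n_layers
    return [(rng.normal(scale=1.0 / np.sqrt(a), size=(a, b)), np.zeros(b))
            for a, b in zip(dims[:-1], dims[1:])]

def mlp(params, x):
    """Apply each linear layer to the last axis, with GELU between layers."""
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:
            x = gelu(x)
    return x

def set_attention(H, W_q, W_k, W_v):
    """Stand-in for f_ST: single-head self-attention within each set of S cells."""
    Q, K, V = H @ W_q, H @ W_k, H @ W_v                       # (B, S, d_h)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(K.shape[-1])  # (B, S, S)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

f_cell = init_mlp(G, d_h)                  # 4-layer MLP: G -> d_h
f_pert = init_mlp(D_pert, d_h)             # 4-layer MLP: D_pert -> d_h
E_batch = rng.normal(size=(D_batch, d_h))  # embedding table; one-hot @ table = row lookup
W_q, W_k, W_v = (rng.normal(scale=1.0 / np.sqrt(d_h), size=(d_h, d_h)) for _ in range(3))
W_recon = rng.normal(scale=1.0 / np.sqrt(d_h), size=(d_h, G))
b_recon = np.zeros(G)

def st_forward(X_ctrl, Z_pert, Z_batch):
    H = mlp(f_cell, X_ctrl) + mlp(f_pert, Z_pert) + Z_batch @ E_batch  # (B, S, d_h)
    O = H + set_attention(H, W_q, W_k, W_v)                            # residual
    return O @ W_recon + b_recon                                       # (B, S, G)

X_ctrl = rng.normal(size=(B, S, G))
Z_pert = np.zeros((B, S, D_pert)); Z_pert[0, :, 1] = 1.0; Z_pert[1, :, 3] = 1.0
Z_batch = np.zeros((B, S, D_batch)); Z_batch[:, :, 0] = 1.0
X_hat = st_forward(X_ctrl, Z_pert, Z_batch)
print(X_hat.shape)  # (2, 6, 20)
```

Every step before the attention layer acts on each cell independently, and attention is permutation-equivariant, so reordering the cells within a set only reorders the predictions; the model treats each set as unordered.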

4.2.2.4. Learning Perturbation Effects with Maximum Mean Discrepancy

ST is trained to minimize the discrepancy between the predicted ($\hat{\mathbf{X}}_{\mathrm{target}}$) and observed ($\mathbf{X}_{\mathrm{target}}$) transcriptomic responses. This discrepancy is quantified using the Maximum Mean Discrepancy (MMD) [57].

For each mini-batch element $b$, ST compares the set of $S$ predicted cell expression vectors ($\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}$) with the set of observed cell expression vectors ($\mathbf{X}_{\mathrm{target}}^{(b)}$). The squared MMD between these two sets is: $$\mathrm{MMD}^2\bigl(\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}, \mathbf{X}_{\mathrm{target}}^{(b)}\bigr) = \frac{1}{S^2}\sum_{i=1}^{S}\sum_{j=1}^{S}\Big[k(\hat{\mathbf{x}}^{(i)}, \hat{\mathbf{x}}^{(j)}) + k(\mathbf{x}^{(i)}, \mathbf{x}^{(j)}) - 2k(\hat{\mathbf{x}}^{(i)}, \mathbf{x}^{(j)})\Big]$$ where:

  • $\hat{\mathbf{x}}^{(i)}, \hat{\mathbf{x}}^{(j)}$: Predicted cell expression vectors from $\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}$.
  • $\mathbf{x}^{(i)}, \mathbf{x}^{(j)}$: Observed cell expression vectors from $\mathbf{X}_{\mathrm{target}}^{(b)}$.
  • $k(\cdot, \cdot)$: The kernel function. The three terms represent (1) similarity within the predicted set, (2) similarity within the observed set, and (3) cross-similarity between the predicted and observed sets.

The paper uses the energy distance kernel: $$k(\mathbf{u}, \mathbf{v}) = -\|\mathbf{u} - \mathbf{v}\|_2$$ This kernel is implemented via the geomloss library [59].

For a training mini-batch of $B$ cell sets, the batch-averaged MMD loss is: $$\mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}}) = \frac{1}{B}\sum_{b=1}^{B}\mathrm{MMD}^2\big(\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}, \mathbf{X}_{\mathrm{target}}^{(b)}\big)$$ Minimizing this loss encourages the model to generate sets of perturbed cell expression vectors whose overall statistical properties, as captured by the MMD, match those of the observed cell sets.

The total loss for ST is: $$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}})$$
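The loss can be written in a few lines of NumPy. The paper computes it with geomloss [59]; the sketch below is instead a plain reimplementation of the V-statistic formula with the energy kernel (the function names are hypothetical), so its scaling may differ from geomloss's energy loss.

```python
import numpy as np

def energy_kernel(u, v):
    """k(u, v) = -||u - v||_2 for all pairs of rows: (n, d), (m, d) -> (n, m)."""
    diff = u[:, None, :] - v[None, :, :]
    return -np.sqrt((diff ** 2).sum(axis=-1))

def mmd2(x_hat, x):
    """Biased (V-statistic) squared MMD between two sets of S cells, shape (S, G)."""
    S = x.shape[0]
    total = (energy_kernel(x_hat, x_hat).sum()
             + energy_kernel(x, x).sum()
             - 2 * energy_kernel(x_hat, x).sum())
    return total / S ** 2

def mmd_loss(X_hat, X):
    """Batch-averaged loss over B cell sets; both arrays have shape (B, S, G)."""
    return float(np.mean([mmd2(X_hat[b], X[b]) for b in range(X.shape[0])]))
```

Since every pair of cells is compared, the loss treats each set as an unordered sample: shuffling the rows of a set leaves it unchanged, and a set compared with itself gives zero.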

4.2.2.5. Training ST in Embedding Spaces

ST can be trained either directly in gene expression space (using the top 2,000 highly variable genes (HVGs), a variant referred to as ST+HVG) or in a specified embedding space.

When trained in an embedding space:

  • The dimensionality of the embedding space is $E$, where typically $E \ll G$.

  • The input tensors $\mathbf{X}_{\mathrm{target}}$ and $\mathbf{X}_{\mathrm{ctrl}}$ become $\mathbf{X}_{\mathrm{target}}^{\mathrm{emb}}$ and $\mathbf{X}_{\mathrm{ctrl}}^{\mathrm{emb}}$, with shape $B \times S \times E$.

  • The control cell encoder $f_{\mathrm{cell}}$ is modified to map $(B \times S \times E)$ to $(B \times S \times d_h)$.

  • The gene reconstruction head $f_{\mathrm{recon}}$ is modified to map $(B \times S \times d_h)$ to $(B \times S \times E)$, producing $\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}$.

  • All other encoders and the transformer remain unchanged.

    To recover gene expression from the embedding space, an additional decoder head $f_{\mathrm{decode}}$ is trained. It is an MLP with dropout that maps from the embedding space back to the full gene expression space: $$\hat{\mathbf{X}}_{\mathrm{target}} = f_{\mathrm{decode}}(\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}})$$ This maps the predicted embeddings from $(B \times S \times E)$ to $(B \times S \times G)$.

The total loss when training in embedding space is a weighted combination of MMD losses in the embedding and gene expression spaces: $$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}, \mathbf{X}_{\mathrm{target}}^{\mathrm{emb}}) + 0.1 \cdot \mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}})$$ The expression loss is down-weighted by a factor of 0.1 to balance the two terms and prioritize learning perturbation effects in the embedding space.

4.2.3. State Embedding Model (SE)

SE is a self-supervised model pre-trained to learn high-quality cell representations from single-cell RNA sequencing data. These embeddings then serve as inputs to ST, enabling more robust transfer across datasets.

4.2.3.1. Gene Representation via Protein Language Models

The SE model incorporates gene representations obtained from pretrained protein language models.

  • Gene embeddings $\mathbf{g}_j \in \mathbb{R}^{5120}$ are computed using ESM-2 (esm2_t48_15B_UR50D [60]): per-amino-acid embeddings are averaged for each protein-coding transcript of a gene, and these transcript embeddings are then averaged across all transcripts of that gene. The resulting embeddings capture evolutionary and functional relationships between genes.
  • These gene embeddings are then projected into the model's embedding dimension $h$ via a learnable encoder: $$\tilde{\mathbf{g}}_j = \mathrm{SiLU}\big(\mathrm{LayerNorm}(\mathbf{g}_j\mathbf{W}_g + \mathbf{b}_g)\big)$$ where:
    • $\mathbf{W}_g \in \mathbb{R}^{5120 \times h}$ and $\mathbf{b}_g \in \mathbb{R}^h$ are learnable parameters.
    • SiLU (Sigmoid Linear Unit) is the activation function [63].
    • LayerNorm is Layer Normalization.
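A minimal NumPy sketch of the two steps, assuming random arrays in place of real ESM-2 outputs, a small hidden size $h = 64$, and LayerNorm without its learnable scale and shift. Note that the two-level average weights every transcript equally, regardless of its length.

```python
import numpy as np

rng = np.random.default_rng(0)
ESM_DIM, h = 5120, 64  # h is a small illustrative hidden size

def gene_embedding(transcripts):
    """Average per-residue embeddings within each transcript, then across transcripts.
    `transcripts` is a list of (n_residues, dim) arrays."""
    per_transcript = [t.mean(axis=0) for t in transcripts]
    return np.mean(per_transcript, axis=0)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def silu(x):
    return x / (1.0 + np.exp(-x))

# Random stand-ins for the learnable W_g and b_g.
W_g = rng.normal(scale=ESM_DIM ** -0.5, size=(ESM_DIM, h))
b_g = np.zeros(h)

def project(g):
    """g_tilde = SiLU(LayerNorm(g W_g + b_g))."""
    return silu(layer_norm(g @ W_g + b_g))

# Three hypothetical isoforms of one gene (random, not real ESM-2 outputs).
transcripts = [rng.normal(size=(n, ESM_DIM)) for n in (120, 95, 140)]
g_j = gene_embedding(transcripts)   # (5120,)
g_tilde_j = project(g_j)            # (h,)
```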

4.2.3.2. Cell Representation

Each cell ii is represented as a sequence of its most highly expressed genes.

  • An "expression set" is constructed by selecting the top $L = 2048$ genes ranked by log fold expression level.
  • This set is augmented with two special tokens: $$\tilde{\mathbf{c}}^{(i)} = [\mathbf{z}_{\mathrm{cls}}, \tilde{\mathbf{g}}_1^{(i)}, \tilde{\mathbf{g}}_2^{(i)}, \ldots, \tilde{\mathbf{g}}_L^{(i)}, \mathbf{z}_{\mathrm{ds}}] \in \mathbb{R}^{(L+2) \times h}$$ where:
    • $\mathbf{z}_{\mathrm{cls}} \in \mathbb{R}^h$: A learnable classification token used to aggregate cell-level information.
    • $\mathbf{z}_{\mathrm{ds}} \in \mathbb{R}^h$: A learnable dataset token that helps disentangle dataset-specific effects.
    • $\tilde{\mathbf{g}}_{\ell}^{(i)}$: The projected embedding of the $\ell$-th most highly expressed gene in cell $i$.
  • Gene selection is cell-specific. If a cell expresses fewer than $L$ genes, the expression set is padded to length $L$ by randomly sampling from unexpressed genes, so that every input has a fixed length.
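The cell-specific construction can be sketched as follows. The function names, the tiny $L$, and the padding details (sampling without replacement, appended after the expressed genes) are assumptions for illustration. Ranking by raw expression gives the same order as ranking by log expression, since the log is monotone.

```python
import numpy as np

L = 8  # the paper uses L = 2048; a small value keeps the example readable

def expression_set(x, L, rng):
    """Gene indices for one cell: expressed genes sorted by decreasing expression
    (at most L), padded with randomly sampled unexpressed genes up to length L."""
    expressed = np.flatnonzero(x > 0)
    ranked = expressed[np.argsort(-x[expressed], kind="stable")][:L]
    n_pad = L - ranked.size
    if n_pad > 0:
        unexpressed = np.flatnonzero(x == 0)
        ranked = np.concatenate([ranked, rng.choice(unexpressed, size=n_pad, replace=False)])
    return ranked

def build_tokens(gene_ids, gene_emb, z_cls, z_ds):
    """Stack [z_cls, g_1, ..., g_L, z_ds] into an (L + 2, h) token matrix."""
    return np.vstack([z_cls, gene_emb[gene_ids], z_ds])

rng = np.random.default_rng(0)
n_genes, h = 30, 4
x = np.zeros(n_genes)
x[[3, 7, 11, 20, 25]] = [2.0, 5.0, 1.0, 3.0, 4.0]  # only 5 of 30 genes expressed
ids = expression_set(x, L, rng)                      # 5 ranked genes + 3 padding genes
tokens = build_tokens(ids, rng.normal(size=(n_genes, h)), np.zeros(h), np.ones(h))
```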

4.2.3.3. Expression-Aware Embeddings

To incorporate expression values directly, SE uses an expression embedding scheme inspired by soft binning [12]. For the $\ell$-th most expressed gene $j_{\ell}^{(i)}$ in cell $i$, with expression value $\mathbf{x}_{j_{\ell}^{(i)}}^{(i)}$, a soft bin assignment is computed: $$\boldsymbol{\alpha}_{\ell}^{(i)} = \mathrm{Softmax}\big(\mathrm{MLP}_{\mathrm{count}}(\mathbf{x}_{j_{\ell}^{(i)}}^{(i)})\big) \in \mathbb{R}^{10}, \qquad \mathbf{e}_{\ell}^{(i)} = \sum_{k=1}^{10} \alpha_{\ell,k}^{(i)}\,\mathbf{b}_k$$ where:

  • $\mathrm{MLP}_{\mathrm{count}}: \mathbb{R} \to \mathbb{R}^{10}$: Two linear layers (dimensions $1 \to 512 \to 10$) with a LeakyReLU activation.
  • $\{\mathbf{b}_k\}_{k=1}^{10}$: A set of learnable embeddings of dimension $h$. The resulting expression embeddings $\mathbf{e}_{\ell}^{(i)}$ are added to the corresponding gene identity embeddings $\tilde{\mathbf{g}}_{\ell}^{(i)}$: $$\mathbf{g}_{\ell}^{(i)} = \tilde{\mathbf{g}}_{\ell}^{(i)} + \mathbf{e}_{\ell}^{(i)}$$
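A NumPy sketch of the soft-binning embedding with random stand-in weights. The LeakyReLU slope of 0.01 (PyTorch's default) and the small $h$ are assumptions. Because the softmax weights sum to 1, each expression embedding is a convex combination of the 10 bin embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
h, n_bins, hidden = 16, 10, 512

# Hypothetical random parameters standing in for the learned ones.
W1, b1 = rng.normal(scale=0.1, size=(1, hidden)), np.zeros(hidden)
W2, b2 = rng.normal(scale=0.1, size=(hidden, n_bins)), np.zeros(n_bins)
bins = rng.normal(size=(n_bins, h))   # learnable bin embeddings b_1..b_10

def leaky_relu(x, slope=0.01):
    return np.where(x > 0, x, slope * x)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def expression_embedding(values):
    """Map L scalar expression values to (L, h) soft-binned embeddings."""
    a = leaky_relu(values[:, None] @ W1 + b1)   # (L, 512): MLP_count hidden layer
    alpha = softmax(a @ W2 + b2)                # (L, 10): soft bin assignment
    return alpha @ bins, alpha                  # weighted mix of bin embeddings

values = np.array([0.5, 1.2, 3.4])              # log expression of 3 genes
g_tilde = rng.normal(size=(3, h))               # gene identity embeddings
e, alpha = expression_embedding(values)
g = g_tilde + e                                 # expression-aware embeddings
```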

4.2.3.4. Transformer Encoding

The input expression set (the expression-aware gene embeddings plus the two special tokens) is passed through the transformer encoder $f_{\mathrm{SE}}$: $$\mathbf{E}^{(i)} = f_{\mathrm{SE}}([\mathbf{z}_{\mathrm{cls}}, \mathbf{g}_1^{(i)}, \mathbf{g}_2^{(i)}, \ldots, \mathbf{g}_L^{(i)}, \mathbf{z}_{\mathrm{ds}}]) \in \mathbb{R}^{(L+2) \times h}$$ Here, $\mathbf{g}_{\ell}^{(i)}$ is as defined in the previous step. The expression encodings are set to zero for the CLS and DS tokens.

  • The cell embedding is extracted from the CLS token at position 0 and normalized: $$\mathbf{e}_{\mathrm{cls}}^{(i)} = \mathrm{LayerNorm}(\mathbf{E}_0^{(i)}) \in \mathbb{R}^h$$ This embedding serves as a summary representation of the cell's transcriptomic state.
  • Similarly, the dataset representation is extracted from the DS token at position $L+1$: $$\mathbf{e}_{\mathrm{ds}}^{(i)} = \mathrm{LayerNorm}(\mathbf{E}_{L+1}^{(i)}) \in \mathbb{R}^h$$ This embedding captures dataset-specific effects.

The final cell embedding $\mathbf{z}_{\mathrm{cell}}^{(i)}$ is the concatenation of these two quantities: $$\mathbf{z}_{\mathrm{cell}}^{(i)} = [\mathbf{e}_{\mathrm{cls}}^{(i)}, f_{\mathrm{proj}}(\mathbf{e}_{\mathrm{ds}}^{(i)})] \in \mathbb{R}^{h+10}$$ where $f_{\mathrm{proj}}$ projects $\mathbf{e}_{\mathrm{ds}}^{(i)}$ to a 10-dimensional space. This cell embedding $\mathbf{z}_{\mathrm{cell}}$ serves as the input representation for individual cells in ST.
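The readout can be sketched as below, assuming $f_{\mathrm{proj}}$ is a single linear map and using a random matrix in place of the encoder output $\mathbf{E}^{(i)}$ (sizes are illustrative, and LayerNorm omits its learnable scale and shift).

```python
import numpy as np

rng = np.random.default_rng(0)
L, h = 6, 8

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

W_proj = rng.normal(scale=h ** -0.5, size=(h, 10))  # f_proj as one linear map (assumption)

def cell_embedding(E):
    """E: (L + 2, h) encoder output for one cell -> z_cell of size h + 10."""
    e_cls = layer_norm(E[0])          # CLS token at position 0
    e_ds = layer_norm(E[L + 1])       # DS token at position L + 1 (the last one)
    return np.concatenate([e_cls, e_ds @ W_proj])

E = rng.normal(size=(L + 2, h))       # stand-in for f_SE output
z_cell = cell_embedding(E)            # (h + 10,)
```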

4.2.3.5. Pretraining Objectives

SE is trained with a self-supervised learning framework using two objectives: gene expression prediction and an auxiliary dataset classification task.

  • Gene Expression Prediction: The model predicts expression values for a selected set of 1,280 genes per cell. Target genes are drawn from three categories to cover the full expression dynamic range:

    • $\mathcal{P}^{(i)}$: 512 highly expressed genes (from the top $L$ genes).
    • $\mathcal{N}^{(i)}$: 512 unexpressed genes (randomly sampled from genes not in the top $L$).
    • $\mathcal{R}$: 256 genes randomly sampled from the full gene set, shared by all cells in the batch. This strategy encourages reconstruction fidelity for both expressed and silent genes.
  • Expression Prediction Decoder: An MLP decoder combines several sources of information to predict gene expression: $$\hat{\mathbf{x}}_j^{(i)} = \mathrm{MLP}_{\mathrm{dec}}\big([\mathbf{z}_{\mathrm{cell}}^{(i)}; \tilde{\mathbf{g}}_j; r^{(i)}]\big)$$ where:

    • $\mathbf{z}_{\mathrm{cell}}^{(i)} \in \mathbb{R}^{h+10}$: The learned cell embedding.
    • $\tilde{\mathbf{g}}_j \in \mathbb{R}^h$: The embedding of the target gene.
    • $r^{(i)} \in \mathbb{R}$: A scalar read-depth indicator (the mean log expression of the expressed genes in the input expression set). These are concatenated and passed through $\mathrm{MLP}_{\mathrm{dec}}$ (two skip-connected blocks and a linear output layer) to predict the log expression.

    For each cell $i$ in a batch of $B$ cells, let $\hat{\mathbf{Y}}^{(i)}$ and $\mathbf{Y}^{(i)}$ be the predicted and true expression values for its 1,280 target genes. The gene-level loss is: $$\mathcal{L}_{\mathrm{gene}} = \frac{1}{B}\sum_{i=1}^{B}\|\hat{\mathbf{Y}}^{(i)} - \mathbf{Y}^{(i)}\|_2$$ It measures the discrepancy between predicted and true expression patterns within each cell.

    To capture variation across cells, a cell-level loss is computed on the shared gene subset $\mathcal{R}$. Let $\hat{\mathbf{S}}^{(i)}$ and $\mathbf{S}^{(i)}$ be the predicted and true expression values of the genes in $\mathcal{R}$ for cell $i$. Stacking these across the $B$ cells and transposing gives matrices $\hat{\mathbf{S}}^{\prime}$ and $\mathbf{S}^{\prime}$ whose $r$-th rows $\hat{\mathbf{S}}^{\prime(r)}, \mathbf{S}^{\prime(r)} \in \mathbb{R}^{B}$ hold the expression of gene $r$ across all cells. The cell-level loss is: $$\mathcal{L}_{\mathrm{cell}} = \frac{1}{|\mathcal{R}|}\sum_{r=1}^{|\mathcal{R}|}\|\hat{\mathbf{S}}^{\prime(r)} - \mathbf{S}^{\prime(r)}\|_2$$ It measures the discrepancy between predicted and true expression across cells for each gene in $\mathcal{R}$.

    The final training loss for expression prediction combines both axes: $$\mathcal{L}_{\mathrm{expression}} = \lambda_1\mathcal{L}_{\mathrm{gene}} + \lambda_2\mathcal{L}_{\mathrm{cell}}$$

  • Dataset Classification Modeling: An auxiliary dataset prediction task helps disentangle technical batch effects from biological variation. Using the DS token embedding $\mathbf{e}_{\mathrm{ds}}^{(i)}$, the model predicts the dataset of origin: $$\hat{d}^{(i)} = \mathrm{MLP}_{\mathrm{dataset}}(\mathbf{e}_{\mathrm{ds}}^{(i)})$$ The loss for this task is the cross-entropy: $$\mathcal{L}_{\mathrm{dataset}} = \frac{1}{B}\sum_{i=1}^{B}\mathrm{CrossEntropy}(\hat{d}^{(i)}, d^{(i)})$$ where $d^{(i)}$ is the true dataset label. This encourages the model to pool dataset-specific information into the DS token, separating it from the biological signal.

  • Total Loss: The SE model is trained on the sum of both losses: $$\mathcal{L} = \mathcal{L}_{\mathrm{expression}} + \mathcal{L}_{\mathrm{dataset}}$$
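The target-gene sampling and the two expression-loss axes can be sketched as follows. This is a hypothetical illustration: random "top-L" sets stand in for real cells, the genes in $\mathcal{R}$ are placed in the last 256 columns by construction, and $\lambda_1 = \lambda_2 = 1$ are placeholders because the weights are not stated here. The two terms differ only in the axis of the norm: per cell over genes for $\mathcal{L}_{\mathrm{gene}}$, per gene over cells for $\mathcal{L}_{\mathrm{cell}}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, L, B = 5000, 2048, 4
n_pos, n_neg, n_shared = 512, 512, 256

def sample_targets(top_L, shared, rng):
    """1,280 target genes for one cell: 512 from its top-L genes, 512 from genes
    outside the top L, then the 256 genes shared across the batch."""
    pos = rng.choice(top_L, size=n_pos, replace=False)
    outside = np.setdiff1d(np.arange(n_genes), top_L)
    neg = rng.choice(outside, size=n_neg, replace=False)
    return np.concatenate([pos, neg, shared])

shared = rng.choice(n_genes, size=n_shared, replace=False)   # R: one draw per batch
targets = [sample_targets(rng.choice(n_genes, size=L, replace=False), shared, rng)
           for _ in range(B)]

def expression_loss(Y_hat, Y, lam1=1.0, lam2=1.0):
    """Y_hat, Y: (B, 1280) predictions and targets; the last 256 columns are R."""
    l_gene = np.linalg.norm(Y_hat - Y, axis=1).mean()        # per cell, over genes
    S_hat, S = Y_hat[:, -n_shared:].T, Y[:, -n_shared:].T    # (|R|, B) after transpose
    l_cell = np.linalg.norm(S_hat - S, axis=1).mean()        # per gene, over cells
    return lam1 * l_gene + lam2 * l_cell
```

For a constant error of 1 everywhere, the gene term equals $\sqrt{1280}$ and the cell term equals $\sqrt{B}$, which makes the two axes easy to tell apart.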

4.2.4. Theoretical Analysis of ST and Optimal Transport

The paper also provides a theoretical analysis of ST's capacity to learn optimal transport (OT) mappings between cellular distributions.

  • Background on Optimal Transport: OT seeks a mapping or coupling between two probability distributions that minimizes the total transport cost under a given cost function [74, 75]. Neural OT uses neural networks to parameterize and solve OT problems [9, 27].

  • ST as a Transformation Learner: Although ST does not explicitly solve an OT problem, it learns a transformation that aligns unperturbed and perturbed cell distributions. The authors connect ST to recent work showing that transformers can implement gradient descent for OT through engineered prompts [76, 77].

  • Asymptotic Behavior and Solution Family: The analysis focuses on the asymptotic setting where the cell set size $S$ tends to infinity. In this limit, ST's self-attention processes information from the entire distribution. The model defines an operator $F_{\mathrm{pert,batch}}$ that maps the control distribution $\mathcal{D}_{\mathrm{ctrl}}$ to the predicted perturbed distribution $\hat{\mathcal{D}}_{\mathrm{pert}}$: $$\hat{\mathbf{X}}_{\mathrm{pert},s}^{(b)} = \Big[\big(f_{\mathrm{recon}} \circ f_{\mathrm{cell}} + f_{\mathrm{ST},\mathcal{D}_{\mathrm{ctrl}}}^{\mathrm{pert,batch}}\big)(\mathbf{X}_{\mathrm{ctrl}}^{(b)})\Big]_s := F_{\mathrm{pert,batch}}(\mathbf{X}_{\mathrm{ctrl},s}^{(b)})$$ Here, $s$ indexes cells within the set, and $F_{\mathrm{pert,batch}}$ is the transformation learned by ST for a given perturbation and batch.

  • Lemma 1 (MMD and Distributional Matching): If ST achieves zero empirical MMD with the energy kernel as $S \to \infty$, then the predicted distribution $\hat{\mathcal{D}}_{\mathrm{pert}}$ equals the true perturbed distribution $\mathcal{D}_{\mathrm{pert}}$ with probability 1, and vice versa: $$\widehat{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{pert}}^{(b)}, \mathbf{X}_{\mathrm{pert}}^{(b)}) = 0 \iff \hat{\mathcal{D}}_{\mathrm{pert}} = \mathcal{D}_{\mathrm{pert}}$$ The empirical MMD ($\widehat{\mathrm{MMD}}$) is a V-statistic [79, 80] and converges almost surely to the population MMD. With the energy kernel, $\mathrm{MMD}^2 \equiv D^2$, where $D^2$ denotes the energy distance, and zero energy distance implies equal distributions [81].

  • Theorem 2 (Optimal Transport Mapping within the Solution Family of STATE): Under certain regularity conditions (absolutely continuous distributions with bounded densities; strictly convex, compact supports with $\mathcal{C}^2$ boundaries), the unique continuous optimal transport map $T$ from $\mathcal{D}_{\mathrm{ctrl}}$ to $\mathcal{D}_{\mathrm{pert}}$ for the squared-distance cost $c(x, y) = \|x - y\|^2$ belongs, with probability 1, to the solution family $\hat{\mathcal{F}}$ of ST, i.e., the set of mappings $F$ that push $\mathcal{D}_{\mathrm{ctrl}}$ forward to $\mathcal{D}_{\mathrm{pert}}$ ($F_{\#}\mathcal{D}_{\mathrm{ctrl}} = \mathcal{D}_{\mathrm{pert}}$) and therefore attain zero MMD loss. The OT map is thus one of the solutions ST can reach by perfectly minimizing the MMD loss. However, $\hat{\mathcal{F}}$ also contains infinitely many other maps, so ST is not guaranteed to learn the OT solution.

  • Theorem 3 (Constrained ST Model for Unique OT Map): If explicit constraints are imposed on the Jacobian of ST's mapping, namely that it is symmetric and positive semi-definite, then the constrained solution family $\hat{\mathcal{F}}^*$ contains exactly one element, the continuous OT map $T$, with probability 1. The reasoning follows Brenier's theorem [83]: on the convex supports assumed above, a symmetric Jacobian makes the map a gradient field $T = \nabla\psi$, positive semi-definiteness makes $\psi$ convex, and for the quadratic cost the gradient of a convex function that pushes $\mathcal{D}_{\mathrm{ctrl}}$ forward to $\mathcal{D}_{\mathrm{pert}}$ is the unique OT map.
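The identity $\mathrm{MMD}^2 \equiv D^2$ used in Lemma 1 follows from substituting the energy kernel $k(\mathbf{u}, \mathbf{v}) = -\|\mathbf{u} - \mathbf{v}\|_2$ into the population MMD. With $X, X'$ drawn i.i.d. from $P$ and $Y, Y'$ drawn i.i.d. from $Q$:

$$
\begin{aligned}
\mathrm{MMD}^2(P, Q) &= \mathbb{E}\,k(X, X') + \mathbb{E}\,k(Y, Y') - 2\,\mathbb{E}\,k(X, Y) \\
&= -\mathbb{E}\|X - X'\|_2 - \mathbb{E}\|Y - Y'\|_2 + 2\,\mathbb{E}\|X - Y\|_2 \\
&= D^2(P, Q).
\end{aligned}
$$

Replacing each expectation with an average over the $S$ cells of each set gives exactly the V-statistic used as the ST training loss.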

This theoretical analysis suggests that ST has the inherent capacity to learn fundamental cellular transformations akin to optimal transport, especially under certain architectural or regularization constraints. The paper hypothesizes that the implicit bias [85-87] of gradient descent in training ST may naturally lead to solutions resembling optimal transport mappings by favoring "minimal shifts."
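A toy one-dimensional illustration (not from the paper) of why zero MMD does not pin down the OT map: any bijection from the control sample onto the perturbed sample reproduces the perturbed distribution exactly, but in 1-D only the monotone, sorted-to-sorted matching is optimal for the quadratic cost. The single-gene Gaussian data and all names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
S = 1000
x_ctrl = rng.normal(0.0, 1.0, size=S)    # control cells, one gene (illustrative)
y_pert = rng.normal(2.0, 0.5, size=S)    # perturbed cells

def energy_mmd2(a, b):
    """V-statistic squared MMD with k(u, v) = -|u - v| for 1-D samples."""
    def mean_dist(u, v):
        return np.abs(u[:, None] - v[None, :]).mean()
    return 2 * mean_dist(a, b) - mean_dist(a, a) - mean_dist(b, b)

# Two maps on the control sample, each a bijection onto the perturbed sample.
rank = np.argsort(np.argsort(x_ctrl))    # rank of each control cell
y_sorted = np.sort(y_pert)
F_monotone = y_sorted[rank]              # i-th smallest -> i-th smallest (1-D OT map)
F_reversed = y_sorted[::-1][rank]        # i-th smallest -> i-th largest

for name, pred in [("monotone", F_monotone), ("reversed", F_reversed)]:
    cost = np.mean((pred - x_ctrl) ** 2)
    print(f"{name}: MMD^2 = {energy_mmd2(pred, y_pert):.2e}, transport cost = {cost:.2f}")
```

Both maps reach (numerically) zero MMD, yet their transport costs differ, which is the non-uniqueness that Theorem 2 describes and that the Jacobian constraints of Theorem 3 remove.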

4.3. ST Training Details

The ST model is implemented using PyTorch Lightning with distributed data parallel (DDP) training and automatic mixed precision (AMP).

  • Hyperparameters: Key hyperparameters include cell_set_size, hidden_dim, n_encoder_layers, n_decoder_layers, batch_encoder (boolean), transformer_backbone_key (LLaMA or GPT2), attn_heads, and params.
  • Transformer Backbone: Uses either a LLaMA [66] or a GPT2 [67] backbone from HuggingFace. GPT2 is preferred for sparser datasets such as Replogle-Nadig because of the properties of LayerNorm, which GPT2 uses (LLaMA uses RMSNorm), in low-data regimes [68].
  • Attention: All models use bidirectional attention.
  • Positional Encodings: No positional encodings are used since cell order within each set is arbitrary.
  • Dropout: Not applied within the transformer.
  • Initialization: Before training, weights are initialized using Kaiming uniform initialization [69], except for the transformer backbone, which is initialized from $\mathcal{N}(0, 0.02^2)$.
  • Fine-tuning: For fine-tuning tasks, a new ST model is initialized with pretrained weights. The perturbation encoder ($f_{\mathrm{pert}}$) is reinitialized to allow transfer across perturbation modalities. If ST uses cell embeddings, the gene decoder ($f_{\mathrm{decode}}$) is also reinitialized to adapt to differences in gene coverage. The other components ($f_{\mathrm{cell}}$, $f_{\mathrm{batch}}$, $f_{\mathrm{ST}}$) retain their pretrained weights and are fine-tuned.

4.4. SE Training Details

The SE model is a 600M parameter encoder-decoder model.

  • Architecture: The encoder has 16 transformer layers, each with 16 attention heads and hidden dimension $h = 2048$. Each layer uses pre-normalization, a feed-forward network expanding to $3 \times h$, and GELU activation. Dropout with probability 0.1 is applied to the attention and feed-forward layers. The decoder is an MLP.
  • Optimizer: AdamW [70] with a maximum learning rate of $10^{-5}$, weight decay of 0.01, and gradient clipping with zclip [71].
  • Learning Rate Schedule: Linear warmup for the first 3% of steps, followed by cosine annealing to 30% of the maximum learning rate.
  • Initialization: All SE weights are initialized with Kaiming uniform initialization.
  • Training Scale: Trained on 14,420 AnnData files spanning 167 million human cells across Arc scBaseCount [32], CZ CELLxGENE [31], and Tahoe-100M [30] datasets for 4 epochs.
  • Data Leakage Prevention: Datasets were split into separate training and validation sets at the dataset level.
  • Efficiency: Utilizes Flash Attention 2 [72] and mixed precision (bf16) training [73] for efficient, large-scale training.
  • Hardware: Distributed across 4 compute nodes, each with 8 NVIDIA H100 GPUs.
  • Batch Size: Effective batch size of 3,072 (per-device batch size of 24 × 4 gradient accumulation steps × 32 GPUs).
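A minimal sketch of the stated learning-rate schedule. The 3% warmup, the peak of $10^{-5}$, and the 30% floor come from the text; the step indexing and the exact warmup shape are assumptions, and the function name is hypothetical.

```python
import math

MAX_LR = 1e-5
WARMUP_FRAC, FINAL_FRAC = 0.03, 0.30

def learning_rate(step, total_steps):
    """Linear warmup over the first 3% of steps, then cosine annealing from
    MAX_LR down to 30% of MAX_LR at the final step (step = total_steps - 1)."""
    warmup = max(1, round(WARMUP_FRAC * total_steps))
    if step < warmup:
        return MAX_LR * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - 1 - warmup)
    floor = FINAL_FRAC * MAX_LR
    return floor + (MAX_LR - floor) * 0.5 * (1 + math.cos(math.pi * progress))

schedule = [learning_rate(s, 1000) for s in range(1000)]
```

Annealing to a nonzero floor rather than to zero keeps the final updates from vanishing entirely.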

5. Experimental Setup

5.1. Datasets

5.1.1. Datasets Used for ST Training

STATE was evaluated on several large-scale single-cell perturbation datasets:

  • Tahoe-100M [30]: A large dataset focused on drug-based perturbations.

  • Replogle-Nadig [4, 43]: Contains genome-scale genetic perturbations.

  • Parse-PBMC [42]: Features cytokine signaling perturbations.

  • Jiang [64]: Another genetic perturbation dataset.

  • McFaline-Figueroa [65]: Contains chemical genomics data.

  • Srivatsan [41]: Another chemical transcriptomics dataset.

The following dataset statistics are from Table 1 of the original paper:

| Dataset | # of Cells | # of Perturbations | # of Contexts |
|---|---|---|---|
| Replogle-Nadig | 624,158 | 1,677 | 4 |
| Jiang | 234,845 | 24 | 30 |
| Srivatsan | 762,795 | 189 | 3 |
| McFaline-Figueroa | 354,758 | 122 | 3 |
| Tahoe-100M | 100,648,790 | 1,138 | 50 |
| Parse-PBMC | 9,697,974 | 90 | 12 donors / 18 cell types |

Preprocessing for ST Training:

  • All datasets were filtered to retain measurements for 19,790 human protein-coding Ensembl genes.
  • Normalization: Counts were then normalized to a total UMI (Unique Molecular Identifier) depth of 10,000 per cell.
  • Log-transformation: The normalized counts were log-transformed using scanpy.pp.log1p.
  • Highly Variable Genes (HVGs): For analyses involving HVGs, the top 2,000 HVGs were identified per dataset using scanpy.pp.highly_variable_genes. These log-transformed HVG values were used as gene-level features.
  • PCA Embeddings: PCA (Principal Component Analysis) embeddings of cells were computed using scanpy.pp.pca.
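The normalization and log steps can be mirrored in NumPy on a dense count matrix, for readers without an AnnData object at hand. This sketch assumes every cell has a nonzero total count and is meant to match what scanpy.pp.normalize_total (with target_sum=1e4) followed by scanpy.pp.log1p computes. The variance-based gene ranking is a simplified stand-in for scanpy.pp.highly_variable_genes (whose default criterion is dispersion-based), and the SVD projection of mean-centred data stands in for scanpy.pp.pca; all function names are hypothetical.

```python
import numpy as np

def normalize_log1p(counts, target_sum=1e4):
    """Scale each cell (row) to target_sum total counts, then apply log(1 + x)."""
    totals = counts.sum(axis=1, keepdims=True)
    return np.log1p(counts / totals * target_sum)

def top_variance_genes(X, n_top=2000):
    """Simplified HVG stand-in: indices of the n_top genes with the largest variance."""
    return np.argsort(-X.var(axis=0), kind="stable")[:n_top]

def pca(X, n_comps=50):
    """Project mean-centred cells onto the top principal components via SVD."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_comps].T

rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(200, 500)).astype(float)  # toy cells x genes
X = normalize_log1p(counts)
hvg = top_variance_genes(X, n_top=100)
emb = pca(X[:, hvg], n_comps=20)
```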

5.1.2. Additional preprocessing for genetic perturbation datasets

Specific filtering steps were applied to genetic perturbation datasets to ensure high knockdown efficacy:

  1. Perturbation-level filtering: Only perturbations (excluding controls) whose average knockdown left a residual expression of $\leq 0.30$ were retained.
  2. Cell-level filtering: Within these perturbations, only cells that individually met an additional per-cell knockdown threshold (residual expression $\leq 0.50$) were kept.
  3. Minimum cell count: Perturbations with fewer than 30 remaining valid cells were dropped; control cells were always preserved.
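The three filtering steps can be sketched as one function. The residual-expression definition (target-gene expression divided by its mean in control cells), the control label "non-targeting", and the function name are hypothetical; only the thresholds come from the text.

```python
import numpy as np

def filter_knockdowns(labels, residual, control="non-targeting",
                      pert_max=0.30, cell_max=0.50, min_cells=30):
    """Return a boolean mask of cells to keep.

    labels:   (n_cells,) perturbation label per cell.
    residual: (n_cells,) residual target-gene expression per cell (assumed to be
              expression relative to the control mean; NaN for control cells).
    """
    labels = np.asarray(labels)
    keep = labels == control                       # controls are always preserved
    for p in np.unique(labels[labels != control]):
        idx = labels == p
        if residual[idx].mean() > pert_max:        # step 1: weak average knockdown
            continue
        good = idx & (residual <= cell_max)        # step 2: per-cell threshold
        if good.sum() >= min_cells:                # step 3: minimum cell count
            keep |= good
    return keep
```

Applying the per-perturbation average first means a few poorly edited cells do not disqualify an otherwise strong knockdown; they are removed individually in step 2.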

5.1.3. Datasets Used for SE Training

The SE model was trained on an even larger corpus of observational data to learn robust cell embeddings.

  • SE was trained on 167 million human cells across three datasets: Arc scBaseCount [32], CZ CELLxGENE [31], and Tahoe-100M [30].

The following training and validation cell counts are from Table 2 of the original paper:

| Dataset | # Training Cells | # Validation Cells |
|---|---|---|
| Arc scBaseCount [32] | 71,676,369 | 4,137,674 |
| CZ CELLxGENE [31] | 59,233,790 | 6,500,519 |
| Tahoe-100M [30] | 36,157,383 | 2,780,587 |

Preprocessing for SE Training:

  • To avoid data leakage in the context generalization benchmarks, SE was trained on 20 Tahoe-100M cell lines, disjoint from the five cell lines held out for benchmarking.
  • scBaseCount data was filtered to include only cells with at least 1,000 non-zero expression measurements and 2,000 UMIs per cell.
  • A subset of AnnData files was reserved for computing validation loss.

5.2. Evaluation Metrics

The evaluation of STATE focuses on its ability to discriminate perturbation effects, accurately identify differentially expressed genes, and generalize across contexts.

5.2.1. Perturbation Evaluation Metrics

  • Perturbation Discrimination Score (PDiscNorm):
    1. Conceptual Definition: This metric assesses how well a model's predicted pseudobulk (average) expression profile for a perturbation matches the true pseudobulk profile for that perturbation, compared to other perturbations. A perfect score means the true profile is uniquely closest to its prediction.
    2. Mathematical Formula: For any perturbation $t$, let $\bar{p}_t$ denote the predicted pseudobulked expression and $p_t$ the observed pseudobulked expression. Using a distance metric $d(\cdot, \cdot)$ (Manhattan or Euclidean distance in this case), $r_t$ is defined as: $$r_t = \sum_{t' \neq t}\mathbf{1}\{d(\bar{p}_t, p_{t'}) < d(\bar{p}_t, p_t)\}$$ The per-perturbation score $\mathrm{PDisc}_t$ is: $$\mathrm{PDisc}_t = \frac{r_t}{T}$$ The overall PDisc is the mean: $$\mathrm{PDisc} = \frac{1}{T}\sum_{t=1}^{T}\mathrm{PDisc}_t$$ The reported metric is the normalized inverse perturbation discrimination score: $$\mathrm{PDiscNorm} = 1 - 2\,\mathrm{PDisc}$$
    3. Symbol Explanation:
      • $t$: An index for a specific perturbation.
      • $T$: The total number of distinct perturbations.
      • $\bar{p}_t$: The model's predicted pseudobulked expression for perturbation $t$.
      • $p_t$: The true (observed) pseudobulked expression for perturbation $t$.
      • $d(\cdot, \cdot)$: A distance function (e.g., Manhattan or Euclidean distance) between expression profiles.
      • $\mathbf{1}\{\cdot\}$: An indicator function that is 1 if the condition inside is true, and 0 otherwise.
      • $r_t$: The number of other perturbations ($t' \neq t$) whose true pseudobulked expression $p_{t'}$ is closer to the prediction $\bar{p}_t$ than the correct true profile $p_t$ is.
      • $\mathrm{PDisc}_t$: The normalized rank of the true perturbation profile among all profiles for a given prediction. It lies in $[0, 1)$, with 0 being a perfect match.
      • $\mathrm{PDisc}$: The average of $\mathrm{PDisc}_t$ across all perturbations.
      • $\mathrm{PDiscNorm}$: The final reported score, normalized so that a random predictor scores approximately 0 (its expected $\mathrm{PDisc}$ is $(T-1)/(2T) \approx 0.5$) and a perfect predictor scores 1. Higher values indicate better discrimination.
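A direct NumPy implementation of the definitions above, written as a sketch (the function name and the row-per-perturbation array layout are assumptions). The strict inequality means that ties with the correct profile do not count against the prediction.

```python
import numpy as np

def pdisc_norm(pred, true, metric="manhattan"):
    """pred, true: (T, G) pseudobulk profiles; row t is perturbation t.
    Returns (PDiscNorm, per-perturbation PDisc_t)."""
    diff = pred[:, None, :] - true[None, :, :]          # (T, T, G)
    if metric == "manhattan":
        dist = np.abs(diff).sum(axis=-1)
    else:
        dist = np.sqrt((diff ** 2).sum(axis=-1))        # Euclidean
    T = len(true)
    own = np.diag(dist)[:, None]                        # d(p_bar_t, p_t)
    r = (dist < own).sum(axis=1)                        # strict "<": t never counts itself
    pdisc_t = r / T
    return 1.0 - 2.0 * pdisc_t.mean(), pdisc_t
```

For example, with true profiles [0], [1], [10] and predictions that swap the first two, $r = (1, 1, 0)$, so $\mathrm{PDisc} = 2/9$ and $\mathrm{PDiscNorm} = 5/9$.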

5.2.2. Differential Expression

  • DE Overlap Accuracy:

    1. Conceptual Definition: This metric quantifies the overlap between the set of differentially expressed genes (DEGs) identified by the model's predictions and the DEGs identified from the true observed data. It measures how accurately the model can pinpoint the genes whose expression changes.
    2. Mathematical Formula: For each perturbation $t$, let $G_{t,\mathrm{true}}^{(k)}$ denote the set of top $k$ true DEGs and $G_{t,\mathrm{pred}}^{(k)}$ the set of top $k$ predicted DEGs. The overlap is computed as: $$\mathrm{Overlap}_{t,k} = \frac{|G_{t,\mathrm{true}}^{(k)} \cap G_{t,\mathrm{pred}}^{(k)}|}{k}$$ When $k = N$, $N$ is the total number of DEGs in the true set.
    3. Symbol Explanation:
      • $t$: An index for a specific perturbation.
      • $k$: The number of top DEGs considered (e.g., 50, 100, 200, or $N$ for all DEGs).
      • $G_{t,\mathrm{true}}^{(k)}$: The set of top $k$ differentially expressed genes identified from the true data for perturbation $t$. Genes are filtered by adjusted p-value ($< 0.05$) and ranked by absolute log fold change.
      • $G_{t,\mathrm{pred}}^{(k)}$: The set of top $k$ differentially expressed genes identified from the model's predicted data for perturbation $t$.
      • $|A \cap B|$: The size of the intersection of sets $A$ and $B$ (here, the number of genes common to the true and predicted DEG sets).
      • $\mathrm{Overlap}_{t,k}$: The overlap fraction for perturbation $t$ at a given $k$. Higher values indicate better accuracy.
  • Effect Sizes (SizeCorr):

    1. Conceptual Definition: This metric evaluates how well the model can predict the magnitude of a perturbation's effect by correlating the number of differentially expressed genes in the true data with the number in the predicted data. It measures the ability to rank perturbations by their overall impact.
    2. Mathematical Formula: Let $n_t = |G_{t,\mathrm{true}}^{(\mathrm{DE})}|$ be the number of true DEGs for perturbation $t$, and $\hat{n}_t = |G_{t,\mathrm{pred}}^{(\mathrm{DE})}|$ the number of predicted DEGs for perturbation $t$. Then: $$\mathrm{SizeCorr} = \rho_{\mathrm{rank}}\big((n_t)_{t=1}^{T}, (\hat{n}_t)_{t=1}^{T}\big)$$
    3. Symbol Explanation:
      • $n_t$: The total number of differentially expressed genes (adjusted p-value $< 0.05$) identified from the true data for perturbation $t$.
      • $\hat{n}_t$: The total number of differentially expressed genes (adjusted p-value $< 0.05$) identified from the model's predicted data for perturbation $t$.
      • $(n_t)_{t=1}^{T}$: The vector of true DEG counts across all $T$ perturbations.
      • $(\hat{n}_t)_{t=1}^{T}$: The vector of predicted DEG counts across all $T$ perturbations.
      • $\rho_{\mathrm{rank}}(\cdot, \cdot)$: The Spearman rank correlation coefficient, which measures the monotonic relationship between the two lists of DEG counts. A higher Spearman correlation indicates better agreement in the ranking of perturbation effect sizes.
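Both DE metrics can be sketched in plain NumPy. The helper names are hypothetical, the DE test that produces log fold changes and adjusted p-values is out of scope, and rankdata here is a small reimplementation with averaged ties; scipy.stats.spearmanr would be the usual library choice for the correlation.

```python
import numpy as np

def top_k_degs(genes, log_fc, padj, k=None, alpha=0.05):
    """Genes with adjusted p < alpha, ranked by |log fold change|; top k (all if k is None)."""
    sig = padj < alpha
    order = np.argsort(-np.abs(log_fc[sig]), kind="stable")
    ranked = np.asarray(genes)[sig][order]
    return ranked if k is None else ranked[:k]

def de_overlap(true_degs, pred_degs, k):
    """Overlap_{t,k} = |G_true^(k) & G_pred^(k)| / k."""
    return len(set(true_degs[:k]) & set(pred_degs[:k])) / k

def rankdata(a):
    """1-based ranks; tied values share the mean of their ranks."""
    a = np.asarray(a, dtype=float)
    ranks = np.empty(len(a))
    ranks[np.argsort(a, kind="stable")] = np.arange(1, len(a) + 1)
    for v in np.unique(a):
        tie = a == v
        ranks[tie] = ranks[tie].mean()
    return ranks

def size_corr(n_true, n_pred):
    """SizeCorr: Spearman correlation, i.e. Pearson correlation of the ranks."""
    return np.corrcoef(rankdata(n_true), rankdata(n_pred))[0, 1]
```

With $k = N$, pass the full true DEG list length as k, as in the definition above.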

5.2.3. Cell Embedding Evaluation Metrics

  • Intrinsic Evaluation:

    1. Conceptual Definition: This assesses the quality of the cell embeddings themselves by testing how much perturbation-specific information they retain. A shallow MLP classifier is trained to predict the perturbation label directly from the cell embedding. High classification performance suggests that the embeddings capture distinct perturbation-induced states.
    2. Mathematical Formula: Not explicitly provided for this specific evaluation, but typically involves AUROC (Area Under the Receiver Operating Characteristic curve) and accuracy for a multi-class classification task with cross-entropy loss.
    3. Symbol Explanation: AUROC measures the ability of a classifier to distinguish between classes. Accuracy is the proportion of correctly classified instances.
  • Extrinsic Evaluation:

    1. Conceptual Definition: This evaluates the practical utility of cell embeddings by using them as input for the ST model and assessing ST's downstream performance. It aims to see if embeddings that are rich in information (high intrinsic score) also lead to better perturbation prediction by ST.
    2. Mathematical Formula: Not explicitly provided, but involves comparing the performance of ST (using the metrics described above, like PDiscNorm, DE Overlap, SizeCorr) when trained with different embedding spaces.
    3. Symbol Explanation: ST's performance metrics are used here to compare the utility of different cell embeddings.

5.2.4. Other Evaluations

  • Cell Set Scaling:
    1. Conceptual Definition: An ablation study of how the cell set size ($S$) affects STATE's performance. It examines how varying $S$ changes the validation loss while controlling computational cost (FLOPs), and compares STATE with simpler models: a pseudobulk model (mean pooling instead of self-attention) and a single-cell variant (cell set size $= 1$).
    2. Mathematical Formula: Not explicitly provided, but involves plotting validation loss against FLOPs for different cell set sizes.
    3. Symbol Explanation:
      • cell set size: The number of cells grouped together in one set for ST processing.
      • batch size: The number of cell sets processed concurrently.
      • FLOPs: Floating-point operations, a measure of computational cost.
      • validation loss: The model's error on a held-out validation dataset.

5.2.5. Task Description for Generalization

The paper defines two generalization tasks:

  • Underrepresented Context Generalization Task:

    1. Conceptual Definition: This task evaluates a model's ability to generalize to a cellular context where only a small fraction of perturbations have been observed during training. It simulates scenarios where limited data is available for a new cell type or condition.
    2. Procedure: For a chosen test context $c^*$, a small proportion ($\alpha = 0.30$) of its perturbations forms a support set for training, and the rest are held out as a target set for testing. The model trains on all perturbations from other contexts and on the support set from $c^*$. It is then evaluated on the target set of $c^*$.
    3. Data Splits:
      • Support set $\mathcal{P}_{c^*}^\mathrm{sup} \subset \mathcal{P}_{c^*}$, where $|\mathcal{P}_{c^*}^\mathrm{sup}| = \lfloor \alpha |\mathcal{P}_{c^*}| \rfloor$.
      • Target set $\mathcal{P}_{c^*}^\mathrm{target} = \mathcal{P}_{c^*} \setminus \mathcal{P}_{c^*}^\mathrm{sup}$.
      • Training data $\mathcal{D}_{\mathrm{train}} = \bigl\{ ( X_{p,c}, p, c ) : c \in \mathcal{C} \setminus \{ c^* \}, p \in \mathcal{P}_c \bigr\} \cup \bigl\{ ( X_{p,c^*}, p, c^* ) : p \in \mathcal{P}_{c^*}^\mathrm{sup} \bigr\}$.
      • Test data $\mathcal{D}_{\mathrm{test}} = \bigl\{ ( X_{p,c^*}, p, c^* ) : p \in \mathcal{P}_{c^*}^\mathrm{target} \bigr\}$.
    4. Notes: Unperturbed cells from all contexts (including $c^*$) are available during training. If $|\mathcal{C}|$ is large, a fixed subset $\mathcal{C}_{\mathrm{test}}$ of test contexts is used; if $|\mathcal{C}|$ is small, an iterative leave-one-out approach is used.
  • Zero-Shot Context Generalization Task:

    1. Conceptual Definition: This task assesses a model's ability to predict perturbation effects in entirely novel cellular contexts for which no perturbation data was seen during fine-tuning. It relies on the model's ability to learn generalizable relationships.
    2. Procedure: The model is pre-trained on large perturbation datasets (e.g., Tahoe-100M). Then, for a query dataset, one context is held out ($\mathcal{C}_{\mathrm{test}}$), and the remaining contexts ($\mathcal{C}_{\mathrm{fine-tune}}$) are used for fine-tuning. The fine-tuned model is evaluated on its ability to predict perturbations within the held-out context of the query dataset.
    3. Data Splits:
      • Fine-tuning data $\mathcal{D}_{\mathrm{fine-tune}} = \bigl\{ ( X_{p,c}, p, c ) : c \in \mathcal{C}_{\mathrm{fine-tune}}, p \in \mathcal{P} \bigr\}$.
      • Test data $\mathcal{D}_{\mathrm{test}} = \bigl\{ ( X_{p,c}, p, c ) : c \in \mathcal{C}_{\mathrm{test}}, p \in \mathcal{P} \bigr\}$.
    4. Notes: Unperturbed cells from all contexts (including the held-out context) are available.
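Both splits reduce to plain list manipulations. In this sketch, each training example is a (context, perturbation) pair that stands in for the cell matrix $X_{p,c}$. Both tasks make control cells available for every context, so the sketch leaves them out of the bookkeeping. The helper names and toy data are hypothetical.

```python
import math
import random

def underrepresented_split(perts_by_context, test_context, alpha=0.30, seed=0):
    """Few-shot split: only floor(alpha * |P_c*|) perturbations of c* are used for training.

    Control cells are not listed here; per the task definition they are
    available for every context, including c*, and would be added to training.
    """
    rng = random.Random(seed)
    target_perts = sorted(perts_by_context[test_context])
    rng.shuffle(target_perts)
    n_support = math.floor(alpha * len(target_perts))
    support, target = target_perts[:n_support], target_perts[n_support:]
    train = [(c, p) for c, perts in perts_by_context.items()
             if c != test_context for p in perts]
    train += [(test_context, p) for p in support]
    test = [(test_context, p) for p in target]
    return train, test

def zero_shot_split(perts_by_context, test_contexts):
    """Zero-shot split: no perturbed cells from held-out contexts are seen in fine-tuning."""
    finetune = [(c, p) for c, perts in perts_by_context.items()
                if c not in test_contexts for p in perts]
    test = [(c, p) for c, perts in perts_by_context.items()
            if c in test_contexts for p in perts]
    return finetune, test

# Hypothetical toy contexts and perturbation names.
perts = {"c1": ["p1", "p2", "p3"], "c2": [f"p{i}" for i in range(10)]}
train, test = underrepresented_split(perts, "c2")
ft, held_out = zero_shot_split(perts, {"c1"})
```

The outputs show the key difference between the tasks. The underrepresented task still trains on $\lfloor \alpha |\mathcal{P}_{c^*}| \rfloor$ perturbations of the test context. The zero-shot task sees none of the held-out context's perturbations.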

5.3. Baselines

STATE was benchmarked against both simple statistical methods and state-of-the-art deep learning models.

  • Perturbation Mean Baseline:

    1. Conceptual Definition: This baseline assumes that the effect of a perturbation is a global average shift applicable across all cell types. It predicts a perturbed profile by adding a global perturbation-specific offset (calculated from training data) to the control mean of the target cell type.
    2. Mathematical Formula: For each cell type $c$ and perturbation $p$:
      • Cell-type specific control mean: $\pmb{\mu}_c^{\mathrm{ctrl}} = \frac{1}{|\mathcal{C}_c|} \sum_{i \in \mathcal{C}_c} \mathbf{x}^{(i)}$, where $\mathcal{C}_c$ is the set of control cells of type $c$.
      • Cell-type specific perturbed mean: $\pmb{\mu}_{c,p}^{\mathrm{pert}} = \frac{1}{|\mathcal{P}_{c,p}|} \sum_{i \in \mathcal{P}_{c,p}} \mathbf{x}^{(i)}$, where $\mathcal{P}_{c,p}$ is the set of perturbed cells of type $c$ receiving perturbation $p$.
      • Cell-type specific perturbation offset: $\pmb{\delta}_{c,p} = \pmb{\mu}_{c,p}^{\mathrm{pert}} - \pmb{\mu}_c^{\mathrm{ctrl}}$.
      • Global perturbation offset: $\pmb{\delta}_p = \frac{1}{|\mathcal{C}_p|} \sum_{c \in \mathcal{C}_p} \pmb{\delta}_{c,p}$, where $\mathcal{C}_p = \{ c \mid |\mathcal{P}_{c,p}| > 0 \}$. Given a test cell type $t$ and perturbation label $p$, the model predicts $\hat{\mathbf{x}} = \pmb{\mu}_t^{\mathrm{ctrl}} + \pmb{\delta}_p$, with $\pmb{\delta}_{\mathrm{ctrl}} \equiv \mathbf{0}$.
    3. Symbol Explanation:
      • $\mathcal{C}_c$: Set of control cells of cell type $c$.
      • $\mathcal{P}_{c,p}$: Set of cells of cell type $c$ perturbed with $p$.
      • $\mathbf{x}^{(i)}$: Expression vector of cell $i$.
      • $\pmb{\mu}_c^{\mathrm{ctrl}}$: Mean expression of control cells in context $c$.
      • $\pmb{\mu}_{c,p}^{\mathrm{pert}}$: Mean expression of cells perturbed with $p$ in context $c$.
      • $\pmb{\delta}_{c,p}$: Difference between perturbed and control means for context $c$ and perturbation $p$.
      • $\mathcal{C}_p$: Set of contexts where perturbation $p$ was observed.
      • $\pmb{\delta}_p$: Global average offset for perturbation $p$ across contexts.
      • $\hat{\mathbf{x}}$: Predicted expression profile.
      • $\pmb{\mu}_t^{\mathrm{ctrl}}$: Control mean for the test cell type $t$.
      • $\pmb{\delta}_{\mathrm{ctrl}}$: Offset for control perturbations (defined as zero).
  • Context Mean Baseline:

    1. Conceptual Definition: This baseline predicts a cell's post-perturbation profile by simply returning the average perturbed expression observed in the training set for cells of the same cell type. If the cell is a control, its original profile is passed through.
    2. Mathematical Formula: For every cell type $c$, the pseudobulk mean of all non-control perturbed cells in the training set is formed: $\pmb{\mu}_c = \frac{1}{|\mathcal{T}_c|} \sum_{i \in \mathcal{T}_c} \mathbf{x}^{(i)}$, with $\mathcal{T}_c = \{ i \mid \mathrm{cell\_type}(i) = c,\ p^{(i)} \neq \mathrm{ctrl} \}$. At inference time, for a test cell $i$ with cell type $c^{(i)}$ and perturbation label $p^{(i)}$, the prediction is $\hat{\mathbf{x}}^{(i)} = \begin{cases} \mathbf{x}^{(i)} & p^{(i)} = \mathrm{ctrl}, \\ \pmb{\mu}_{c^{(i)}} & p^{(i)} \neq \mathrm{ctrl}. \end{cases}$
    3. Symbol Explanation:
      • $\pmb{\mu}_c$: Mean expression of all non-control perturbed cells in context $c$ from the training set.
      • $\mathcal{T}_c$: Set of training cells of type $c$ that are perturbed (not controls).
      • $\mathbf{x}^{(i)}$: Expression vector of cell $i$.
      • $c^{(i)}$: Cell type of cell $i$.
      • $p^{(i)}$: Perturbation label of cell $i$.
      • $\hat{\mathbf{x}}^{(i)}$: Predicted expression profile for cell $i$.
  • Linear Baseline (as described in [46]):

    1. Conceptual Definition: This model treats a perturbation as a low-rank, gene-wide linear displacement added to a cell's control expression. It learns a linear map that transforms gene embeddings and perturbation embeddings into an expression change.
    2. Mathematical Formula:
      • $G \in \mathbb{R}^{G \times d_g}$: Fixed gene-embedding matrix (e.g., pretrained protein feature vectors).
      • $P \in \mathbb{R}^{P \times d_p}$: Fixed perturbation-embedding matrix (one-hot).
      • From the training set, an "expression-change" pseudobulk $Y$ is built: $Y_{g,p} = \frac{1}{|\mathcal{P}_p|} \sum_{i \in \mathcal{P}_p} \bigl( x_g^{\mathrm{pert},(i)} - x_g^{\mathrm{ctrl},(i)} \bigr)$, with $\mathcal{P}_p = \{ i \mid p^{(i)} = p \}$. Thus $Y \in \mathbb{R}^{G \times P}$ stores, for every gene $g$ and perturbation $p$, the average change of perturbed cells relative to their matched controls.
      • The model seeks a low-rank map $K \in \mathbb{R}^{d_g \times d_p}$ and a gene-wise bias $\mathbf{b} \in \mathbb{R}^G$ such that $Y \approx G K P^\top + \mathbf{b} \mathbf{1}^\top$.
      • $K$ is obtained by solving the ridge-regularized least-squares problem $\min_{K} \| Y - G K P^\top - \mathbf{b} \mathbf{1}^\top \|_F^2 + \lambda \| K \|_F^2$, with $\mathbf{b} = \frac{1}{P} Y \mathbf{1}$.
      • Its closed-form solution is $K = ( G^\top G + \lambda I )^{-1} G^\top Y P ( P^\top P + \lambda I )^{-1}$.
      • For a test cell $i$ with control profile $\mathbf{x}^{\mathrm{ctrl},(i)}$ and perturbation label $p^{(i)}$, the prediction is $\hat{\mathbf{x}}^{(i)} = \begin{cases} \mathbf{x}^{\mathrm{ctrl},(i)} & p^{(i)} = \mathrm{ctrl}, \\ \mathbf{x}^{\mathrm{ctrl},(i)} + G K P_{p^{(i)}} + \mathbf{b} & p^{(i)} \neq \mathrm{ctrl}. \end{cases}$
    3. Symbol Explanation:
      • $G$: Gene embedding matrix.
      • $P$: Perturbation embedding matrix.
      • $Y_{g,p}$: Average expression change for gene $g$ under perturbation $p$.
      • $\mathcal{P}_p$: Set of cells receiving perturbation $p$.
      • $x_g^{\mathrm{pert},(i)}$: Expression of gene $g$ in perturbed cell $i$.
      • $x_g^{\mathrm{ctrl},(i)}$: Expression of gene $g$ in the control cell matched to cell $i$.
      • $K$: Low-rank map.
      • $\mathbf{b}$: Gene-wise bias.
      • $\mathbf{1}$: Vector of ones.
      • $\|\cdot\|_F^2$: Squared Frobenius norm.
      • $\lambda$: Ridge regularization parameter.
      • $I$: Identity matrix.
      • $P_{p^{(i)}}$: Row of $P$ corresponding to perturbation $p^{(i)}$.
  • Deep Learning Baselines:

    • scVI [25]: An autoencoder-based model that models gene expression distributions while accounting for technical noise and batch effects.
    • CPA [8, 47]: An autoencoder-based model that learns a compositional latent space to capture additive effects of perturbation, dosage, and cell type.
    • scGPT [11]: A transformer-based foundation model that uses generative pretraining on large datasets (over 33 million cells) to achieve zero-shot generalization across tasks, including perturbation prediction.
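The three non-deep baselines each take only a few lines of NumPy. This is a minimal sketch under stated assumptions, not the benchmark's code:

- inputs are dense expression matrices with string labels;
- every cell type has control cells;
- for the linear baseline, $K$ is fit to the centered matrix $Y - \mathbf{b}\mathbf{1}^\top$. This keeps the bias $\mathbf{b}$ added at prediction time from being counted twice (the closed form above is written with $Y$).

Function names and toy values are hypothetical.

```python
import numpy as np

def perturbation_mean_predict(X, cell_type, pert, test_type, test_pert, ctrl="ctrl"):
    """Perturbation Mean: mu_t^ctrl + delta_p, averaging delta_{c,p} over contexts in C_p."""
    types = np.unique(cell_type)
    mu_ctrl = {c: X[(cell_type == c) & (pert == ctrl)].mean(axis=0) for c in types}
    if test_pert == ctrl:
        return mu_ctrl[test_type]                        # delta_ctrl is defined as 0
    offsets = [X[(cell_type == c) & (pert == test_pert)].mean(axis=0) - mu_ctrl[c]
               for c in types if np.any((cell_type == c) & (pert == test_pert))]
    return mu_ctrl[test_type] + np.mean(offsets, axis=0)

def context_mean_predict(X, cell_type, pert, x_test, test_type, test_pert, ctrl="ctrl"):
    """Context Mean: controls pass through; others get their cell type's perturbed pseudobulk."""
    if test_pert == ctrl:
        return x_test
    return X[(cell_type == test_type) & (pert != ctrl)].mean(axis=0)

def fit_linear_baseline(Y, G, P, lam=0.1):
    """Linear baseline. Y: genes x perts; G: genes x d_g; P: perts x d_p."""
    b = Y.mean(axis=1)                                   # b = (1/P) Y 1
    Y_centered = Y - b[:, None]                          # assumption: fit K to Y - b 1^T
    left = np.linalg.solve(G.T @ G + lam * np.eye(G.shape[1]), G.T @ Y_centered)
    K = left @ P @ np.linalg.inv(P.T @ P + lam * np.eye(P.shape[1]))
    return K, b

def linear_predict(x_ctrl, K, G, p_row, b):
    """x_ctrl + G K P_p + b for a non-control perturbation whose embedding row is p_row."""
    return x_ctrl + G @ K @ p_row + b

# Toy data: two genes, cell types A/B/C, one hypothetical drug "drugX".
X = np.array([[0, 0], [2, 2], [3, 3], [10, 10], [13, 15], [5, 5]], dtype=float)
cell_type = np.array(["A", "A", "A", "B", "B", "C"])
pert = np.array(["ctrl", "ctrl", "drugX", "ctrl", "drugX", "ctrl"])
print(perturbation_mean_predict(X, cell_type, pert, "C", "drugX"))  # [7.5 8.5]

# Linear baseline with identity gene embeddings and one-hot perturbations.
rng = np.random.default_rng(0)
Y = rng.normal(size=(5, 3))                 # 5 genes x 3 perturbations
G_emb, P_emb = np.eye(5), np.eye(3)
K, b = fit_linear_baseline(Y, G_emb, P_emb, lam=1e-8)
pred = linear_predict(np.zeros(5), K, G_emb, P_emb[1], b)   # close to Y[:, 1]
```

A quick sanity check confirms that the two-sided ridge formula is wired correctly. With identity gene embeddings, one-hot perturbations, and a tiny $\lambda$, the linear baseline reproduces each training column of $Y$.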

5.4. ST Hyperparameters

The following reproduces Table 3 of the original paper (key model hyperparameters by dataset):

| Dataset | cell_set_size | hidden_dim | n_encoder_layers | n_decoder_layers | batch_encoder | transformer_backbone_key | attn_heads | params |
|---|---|---|---|---|---|---|---|---|
| Tahoe-100M | 256 | 1488 | 4 | 4 | false | LLaMA | 12 | 244M |
| Parse-PBMC | 512 | 1440 | 4 |  | true | LLaMA | 12 | 244M |
| Replogle-Nadig | 32 | 128 | 4 | 4 | false | GPT2 | 8 | 10M |

The following reproduces Table 4 of the original paper (architectural details for ST components):

| Component | Architecture | Layer Dimensions | Activation | Normalization | Dropout |
|---|---|---|---|---|---|
| f_cell | 4-layer MLP | (G or E) → h → h → h → h | GELU | LayerNorm | None |
| f_pert | 4-layer MLP | D_pert → h → h → h → h | GELU | LayerNorm | None |
| f_batch | Embedding layer | D_batch → h | N/A | N/A | None |
| f_ST | LLaMA transformer | h (input to each of 4 layers) | SwiGLU | RMSNorm | None |
| recon | Linear layer | h → G or E | N/A | N/A | None |
| f_decode | 3-layer MLP | h → 1024 → 512 → G | GELU | LayerNorm | 0.1 |
| f_conf | 3-layer MLP | h → h/2 → h/4 → 1 | GELU | LayerNorm | None |

Notes on Architectural Details:

  • $G$: Number of genes.
  • $E$: Dimension of cell embeddings (if used).
  • $D_{\mathrm{pert}}$: Dimension of perturbation features.
  • $D_{\mathrm{batch}}$: Dimension of batch features.
  • $h$: Shared hidden dimension.
  • All MLPs use GELU activation and Layer Normalization before activation, unless otherwise specified.
  • f_ST (transformer backbone) uses SwiGLU (Swish-Gated Linear Unit) activation and RMSNorm (Root Mean Square Normalization).
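The sketch below makes the layer shapes in Table 4 concrete by building the MLP modules with LayerNorm applied before GELU. It only traces tensor shapes. It omits the transformer $f_{\mathrm{ST}}$, the learnable normalization parameters, and the exact wiring between modules. It also assumes that each MLP's final linear layer has no activation. The sizes $G$, $h$, and $S$ are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the feature axis (learnable gain and bias omitted in this sketch)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def init_mlp(dims, rng):
    """One (weight, bias) pair per linear layer, e.g. dims = [G, h, h, h, h] gives 4 layers."""
    return [(rng.normal(scale=1.0 / np.sqrt(d_in), size=(d_in, d_out)), np.zeros(d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])]

def mlp_forward(params, x):
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:        # assumption: the final layer stays linear
            x = gelu(layer_norm(x))    # LayerNorm before the activation, as in the notes
    return x

# Hypothetical sizes: G genes, hidden width h, a set of S cells.
G, h, S = 2000, 64, 8
rng = np.random.default_rng(0)
f_cell = init_mlp([G, h, h, h, h], rng)
f_decode = init_mlp([h, 1024, 512, G], rng)
f_conf = init_mlp([h, h // 2, h // 4, 1], rng)

cells = rng.normal(size=(S, G))
hidden = mlp_forward(f_cell, cells)          # (S, h): one token per cell; f_ST is skipped here
expression = mlp_forward(f_decode, hidden)   # (S, G)
confidence = mlp_forward(f_conf, hidden)     # (S, 1)
```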

6. Results & Analysis

6.1. Core Results Analysis

The experimental results consistently demonstrate STATE's superior performance in predicting cellular responses to perturbations, particularly highlighting its ability to generalize across diverse contexts and leverage large datasets.

The multi-panel Figure 2 from the original paper, titled "STATE improves perturbation prediction in context generalization, leverages data scale, and enables cross-dataset transfer learning," provides visual evidence for these claims.

Figure 2: STATE improves perturbation prediction in context generalization, leverages data scale, and enables cross-dataset transfer learning. (A) Underrepresented context generalization task. Models were trained on […] perturbations, largely held out in the underrepresented target context. (B) Models were trained and evaluated on […] datasets. […] [4], […] predicted vs. true DEs, and Spearman correlation of perturbations by effect size (number of predicted DEGs). (D) […] accuracy […] perturbed cells in the embedding generated using observed data; extrinsic performance measures the classification accuracy over perturbed embeddings predicted by ST trained on the cell embeddings. (E) Zero-shot […], after pre-training on Tahoe-100M.

Performance in Context Generalization (Figure 2C)

STATE was evaluated on several metrics in the underrepresented (few-shot) context generalization task, in which each test cell context contributed only 30% of its perturbations to training (Figure 2B).

  • Perturbation Discrimination Score (PDiscNorm): STATE achieved substantial absolute improvements:
    • 54% on the Tahoe-100M dataset (a large drug perturbation dataset).
    • 29% on the Parse-PBMC dataset (a cytokine signaling perturbation dataset).
    • It also matched the next best performing baseline on genetic perturbations. This indicates STATE's strong ability to distinguish between different perturbation effects, even with limited exposure in a new context.
  • DE Overlap Accuracy: STATE showed significantly improved accuracy in identifying differentially expressed genes (DEGs):
    • Twice as good as the next best baseline on Tahoe-100M.
    • 43% better on the Parse-PBMC dataset.
    • It was the second best model on genetic perturbations. This highlights STATE's biological relevance, as accurately predicting DEGs is crucial for understanding molecular mechanisms.
  • Effect Sizes (SizeCorr): STATE accurately ranked perturbations by their relative effect sizes, demonstrating its capacity to predict the overall impact of perturbations:
    • 53% higher Spearman correlations on Parse-PBMC.
    • 22% higher than baselines on Replogle-Nadig.
    • 70% higher on Tahoe-100M, approaching an absolute correlation of 0.8.
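Two of the metrics above can be sketched directly. These definitions are assumptions for illustration, not the paper's exact evaluation code:

- DE overlap is the fraction of true differentially expressed genes that the prediction recovers.
- SizeCorr is the Spearman correlation between the predicted and true numbers of DEGs per perturbation, with tied values given average ranks.

Names and toy values are hypothetical.

```python
import numpy as np

def average_ranks(values):
    """Ranks 1..n; tied values share the mean of the positions they occupy."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="mergesort")
    ranks = np.empty(len(values))
    ranks[order] = np.arange(1, len(values) + 1)
    for v in np.unique(values):
        tied = values == v
        ranks[tied] = ranks[tied].mean()
    return ranks

def spearman(a, b):
    """Spearman correlation = Pearson correlation of the (average) ranks."""
    return float(np.corrcoef(average_ranks(a), average_ranks(b))[0, 1])

def de_overlap(pred_degs, true_degs):
    """Assumed definition: fraction of true DEGs that also appear among predicted DEGs."""
    true_set = set(true_degs)
    if not true_set:
        return float("nan")
    return len(set(pred_degs) & true_set) / len(true_set)

# Hypothetical numbers of DEGs per perturbation (true vs. predicted).
true_counts = [120, 5, 40, 0]
pred_counts = [100, 8, 35, 2]
size_corr = spearman(pred_counts, true_counts)
overlap = de_overlap(["GENE_A", "GENE_B", "GENE_C"], ["GENE_B", "GENE_C", "GENE_D", "GENE_E"])
```

Because SizeCorr is rank-based, it rewards ordering perturbations correctly by strength even when the absolute DEG counts are off.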

Leveraging Data Scale

A notable finding is that STATE's advantage over baselines grows with data scale. It performs well on genetic perturbations, where effect sizes are weaker and data are less abundant. Its margin is much larger on signaling and drug perturbations, whose datasets are roughly one and two orders of magnitude larger, respectively. This suggests that previous perturbation models operate in a data-poor regime [49, 50] and that STATE's architecture is better suited to exploit large amounts of data.

Shared Cell Embedding (SE) for Cross-Dataset Transfer (Figure 2D and 2E)

  • Information Content in SE (Figure 2D): SE's embeddings were evaluated for their intrinsic and extrinsic performance. Intrinsic evaluation measured how well a classifier could predict perturbation labels from the embeddings. Extrinsic evaluation measured ST's performance when trained on these embeddings. STATE embeddings significantly outperformed other models on both measures, beating even dataset-specific representations built from highly variable genes. This indicates that SE captures perturbation-specific information while remaining robust.
  • Zero-Shot Transfer Learning (Figure 2E): STATE models (ST+SE) were pretrained on Tahoe-100M and fine-tuned on smaller datasets for zero-shot context generalization. Across all tested datasets, ST+SE transferred better than ST+HVG (ST trained on highly variable genes) and consistently outperformed other baseline methods. This demonstrates the critical role of SE in enabling STATE to generalize to novel cellular contexts where no perturbation data was seen during fine-tuning, paving the way for virtual cell models.

6.2. Data Presentation (Tables)

The following are the results from Table 1 of the original paper, summarizing datasets used in ST experiments:

| Dataset | # of Cells | # of Perturbations | # of Contexts |
|---|---|---|---|
| Replogle-Nadig | 624,158 | 1,677 | 4 |
| Jiang | 234,845 | 24 | 30 |
| Srivatsan | 762,795 | 189 | 3 |
| Mcfaline-Figueroa | 354,758 | 122 | 3 |
| Tahoe-100M | 100,648,790 | 1,138 | 50 |
| Parse-PBMC | 9,697,974 | 90 | 12 donors / 18 cell types |

The following are the results from Table 2 of the original paper, summarizing datasets used in SE experiments:

| Dataset | # Training Cells | # Validation Cells |
|---|---|---|
| Arc scBaseCount [32] | 71,676,369 | 4,137,674 |
| CZ CellXGene [31] | 59,233,790 | 6,500,519 |
| Tahoe-100M [30] | 36,157,383 | 2,780,587 |

Tables 3 (key model hyperparameters by dataset) and 4 (architectural details for ST components) of the original paper are reproduced in Section 5.4.

6.3. Ablation Studies / Parameter Analysis

The paper specifically highlights an ablation study on cell set scaling (Figure 1 of the original paper).

Figure 1: STATE is a multi-scale machine learning architecture that operates across genes, individual cells, and cell populations. (A) The State Transition model (ST) learns perturbation effects by training on sets of […] on the Tahoe-100M dataset achieved when covariate-matched groups are chunked into sets of 256 cells. The full ST model significantly outperforms a pseudobulk model (STATE w/ mean-pooling instead of self-attention) and a single-cell variant (STATE with set size = 1). Removing the self-attention mechanism (STATE w/o self-attention) substantially degrades performance. (C) The State Embedding model (SE) is an encoder-decoder model trained on large-scale observational single-cell transcriptomic data. The transformer encoder constructs a dense embedding of the cell, and the MLP decoder reconstructs gene expression from the embedding. ST can operate directly on gene expression profiles or on cell representations from SE. When trained with SE, a separate MLP decodes from predicted embeddings to perturbed gene expression.

This study investigated the impact of the cell set size in STATE.

  • The authors varied the number of cells per set and the batch size such that the product (batch size × cell set size) remained constant at 16,384, and measured the validation loss as a function of floating-point operations (FLOPs).
  • Key Findings:
    • As the cell set size increased, validation loss on held-out cell lines decreased substantially.

    • An optimal cell set size of 256 was observed, beyond which returns diminished. This supports the ST module's design choice of modeling cell populations rather than individual cells.

    • Comparison to Baselines: The full ST model significantly outperformed a pseudobulk model (which replaced self-attention with mean-pooling). This indicates the advantage of self-attention in capturing intra-set heterogeneity over simple averaging.

    • Importance of Self-Attention: A single-cell variant (STATE with set size = 1) performed poorly. Crucially, removing the self-attention mechanism (STATE w/o self-attention) substantially degraded performance. This highlights the indispensable role of the self-attention mechanism in enabling ST to effectively model perturbation-induced transformations and account for cellular heterogeneity.

Another implicit ablation is the comparison between ST+HVG and ST+SE in transfer learning (Figure 2E). This comparison demonstrates that training ST using SE's rich cell embeddings significantly improves cross-dataset transferability compared to using highly variable genes directly. This validates the design choice of SE as a crucial component for STATE's generalization capabilities.
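Single-head attention over one cell set illustrates the contrast between the full model and the pseudobulk ablation. This is a simplified stand-in (one head, random weights, no normalization or feed-forward layers), not STATE's LLaMA or GPT-2 backbone. Self-attention gives every cell its own set-dependent output and is permutation-equivariant. Mean pooling hands every cell the same average. The toy set and weights are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def set_self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product attention among the S cells of one set."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # (S, S) cell-to-cell attention weights
    return A @ V                                   # a distinct output row for every cell

def set_mean_pool(X, Wv):
    """Pseudobulk-style ablation: every cell receives the same set average."""
    pooled = (X @ Wv).mean(axis=0, keepdims=True)
    return np.repeat(pooled, X.shape[0], axis=0)

rng = np.random.default_rng(0)
S, d = 6, 4
# Two hypothetical subpopulations within one cell set.
X = np.vstack([rng.normal(0.0, 1.0, size=(3, d)), rng.normal(3.0, 1.0, size=(3, d))])
Wq, Wk, Wv = (rng.normal(scale=0.3, size=(d, d)) for _ in range(3))
attn_out = set_self_attention(X, Wq, Wk, Wv)
pool_out = set_mean_pool(X, Wv)
```

Mean pooling erases which subpopulation a cell came from. Attention keeps per-cell outputs while still letting each cell see the rest of the set, which is the heterogeneity argument made above.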

7. Conclusion & Reflections

7.1. Conclusion Summary

The paper introduces STATE, a novel multi-scale, transformer-based model that significantly advances the prediction of cellular responses to perturbations across diverse contexts. STATE's architecture comprises two key modules: the State Transition (ST) model, which uses self-attention to learn perturbation effects on sets of cells, thereby capturing cellular heterogeneity; and the State Embedding (SE) model, a pre-trained encoder-decoder that generates robust cell embeddings from vast observational data.

Key findings include STATE's superior performance in perturbation discrimination and differential gene expression identification on large-scale datasets, with improvements exceeding 30% over existing models. The model effectively leverages giga-cell datasets (over 100 million perturbed cells and 167 million observational cells), showing that its performance scales with data availability. Furthermore, STATE demonstrates strong context generalization and zero-shot prediction capabilities in novel cellular environments, which were unseen during training. This is largely attributed to the high-quality SE embeddings and ST's ability to model distributional shifts via MMD loss. The theoretical analysis also establishes a connection between ST's solution family and optimal transport maps, providing a principled understanding of its capacity to learn cellular transformations.

Overall, STATE sets a new benchmark for perturbation response prediction, laying the groundwork for developing scalable virtual cell models for various biological and therapeutic applications.

7.2. Limitations & Future Work

The authors implicitly or explicitly mention several limitations and avenues for future work:

  • Theoretical Guarantees for Optimal Transport: While Theorem 2 shows that the optimal transport map $T$ is within ST's solution family $\hat{\mathcal{F}}$, it also notes that $\hat{\mathcal{F}}$ contains infinitely many other elements. This means ST is not guaranteed to learn the unique optimal transport solution without further constraints.
  • Constraining the Jacobian: Theorem 3 suggests that imposing explicit constraints on the Jacobian of ST's mapping (to be symmetric and positive semi-definite) would ensure it learns the unique OT map. However, implementing these constraints in a deep learning architecture, especially for transformers, is a non-trivial challenge that might require novel architectural designs, similar to Input Convex Neural Networks (ICNNs) [9, 84].
  • Implicit Bias Characterization: The authors hypothesize that the implicit bias [85-87] of gradient descent during ST's optimization might lead it to discover solutions that resemble optimal transport mappings by favoring "minimal shifts." Characterizing this implicit bias in transformer-based transition models is explicitly stated as a promising area for future research.
  • Computational Cost: Training STATE requires immense computational resources, as evidenced by its use of Flash Attention 2, mixed precision training, and distributed training across multiple NVIDIA H100 GPUs on 100M+ cells. This might limit its accessibility for researchers without high-performance computing infrastructure.
  • Interpretability: While transformers are powerful, their complex, black-box nature can make it challenging to fully interpret why specific predictions are made or how heterogeneity is being modeled, which could be important for biological discovery.
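The OT map is available in closed form between two Gaussians, and that case shows Theorem 3's Jacobian condition directly. For the quadratic cost, Brenier's theorem gives the optimal map as the gradient of a convex function, so its Jacobian is symmetric positive semi-definite. Between two Gaussians the map is affine: $T(\mathbf{x}) = \mathbf{m}_2 + A(\mathbf{x} - \mathbf{m}_1)$ with $A = \Sigma_1^{-1/2} (\Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2})^{1/2} \Sigma_1^{-1/2}$. The sketch below only illustrates that property and is not how ST is trained. The means and covariances are arbitrary toy values.

```python
import numpy as np

def sym_sqrt(S):
    """Principal square root of a symmetric PSD matrix via its eigendecomposition."""
    w, U = np.linalg.eigh(S)
    return (U * np.sqrt(np.clip(w, 0.0, None))) @ U.T

def gaussian_ot_map(m1, S1, m2, S2):
    """OT map (quadratic cost) from N(m1, S1) to N(m2, S2): T(x) = m2 + A (x - m1)."""
    S1_half = sym_sqrt(S1)
    S1_half_inv = np.linalg.inv(S1_half)
    A = S1_half_inv @ sym_sqrt(S1_half @ S2 @ S1_half) @ S1_half_inv
    A = 0.5 * (A + A.T)                 # remove floating-point asymmetry
    return A, lambda x: m2 + (x - m1) @ A.T

m1, m2 = np.zeros(2), np.array([1.0, -2.0])
S1 = np.array([[2.0, 0.5], [0.5, 1.0]])
S2 = np.array([[1.0, -0.3], [-0.3, 3.0]])
A, T = gaussian_ot_map(m1, S1, m2, S2)
# The Jacobian of T is A: symmetric with non-negative eigenvalues, and A S1 A = S2.
```

Any other map in $\hat{\mathcal{F}}$ that also pushes $N(\mathbf{m}_1, \Sigma_1)$ to $N(\mathbf{m}_2, \Sigma_2)$, for example one composed with a rotation, would generally lose this symmetry. This is why a Jacobian constraint singles out the OT solution.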

7.3. Personal Insights & Critique

This paper presents a highly innovative and impactful approach to a critical problem in systems biology and drug discovery. My personal insights include:

  • Holistic Multi-Scale Design: The most compelling aspect of STATE is its holistic, multi-scale architecture. It integrates protein language models for gene embeddings, massive observational pretraining for robust cell embeddings (SE), and a self-attention-based transformer (ST) for population-level perturbation modeling. Together, these address the problem from multiple angles simultaneously. This layered approach is a significant improvement over models that focus solely on gene expression or single-cell representations.
  • Leveraging Data Scale Effectively: The explicit demonstration that STATE's performance significantly improves with increasing data scale is a powerful validation. This positions STATE as a true foundation model for cell biology, capable of extracting deep insights from the ever-growing single-cell omics datasets. The finding that other models are "data-poor" provides crucial context for the field.
  • Distributional Approach with MMD: Using MMD as a loss function to compare distributions of cell states is elegant and biologically sound, especially given the destructive nature of scRNA-seq. Sequencing destroys each cell, so the same cell cannot be measured both before and after perturbation, and there are no paired observations to regress on. Matching distributions lets the model learn population-level changes and heterogeneity rather than just point-wise predictions, which better reflects real biological phenomena.
  • Theoretical Grounding: The theoretical connection to Optimal Transport is a strong addition, providing a mathematical framework to understand STATE's capacity. While the current model doesn't explicitly enforce OT constraints, the possibility to do so in future work opens exciting avenues for more principled cellular state transformations.
  • Potential for Diverse Applications: The model's ability to generalize across genetic, signaling, and chemical perturbations and to novel contexts makes it incredibly versatile. It could directly impact drug repurposing, target identification, and even guide autonomous experimental design systems by providing accurate in-silico predictions.
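A squared-MMD estimate between a predicted and an observed cell set illustrates the distributional loss idea. The Gaussian RBF kernel, the bandwidth of 1.0, and the biased (V-statistic) estimator are assumptions chosen for illustration; the kernel used in STATE may differ. The two sets need not have the same size or any cell-to-cell pairing, which is what makes this kind of loss usable with destructive scRNA-seq measurements.

```python
import numpy as np

def rbf_kernel(A, B, bandwidth):
    """Gaussian RBF kernel matrix between the rows of A and the rows of B."""
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq / (2.0 * bandwidth ** 2))

def mmd2(X, Y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD between two cell sets."""
    return float(rbf_kernel(X, X, bandwidth).mean()
                 + rbf_kernel(Y, Y, bandwidth).mean()
                 - 2.0 * rbf_kernel(X, Y, bandwidth).mean())

rng = np.random.default_rng(0)
observed = rng.normal(size=(200, 3))               # hypothetical observed perturbed cells
pred_good = rng.normal(size=(150, 3))              # predicted set from the same distribution
pred_shifted = rng.normal(loc=2.0, size=(150, 3))  # predicted set with a systematic shift
print(mmd2(observed, pred_good) < mmd2(observed, pred_shifted))  # True
```

The loss depends only on the two sets as wholes. Reordering the cells in either set leaves it unchanged, so no artificial pairing between control and perturbed cells is needed.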

Potential Issues or Areas for Improvement:

  • Computational Accessibility: While its scale is a strength, it also presents a barrier. Future work might explore methods for distillation or more efficient architectures to make such powerful models accessible to a wider range of researchers without giga-scale computational resources.

  • Interpretability of Heterogeneity: While STATE models heterogeneity, it would be valuable to have tools to interpret what kind of heterogeneity is being captured by the self-attention mechanism and SE embeddings. This could lead to novel biological discoveries about cell state transitions.

  • Beyond Gene Expression: The current model focuses on gene expression. Integrating other omics modalities (e.g., epigenomics, proteomics) within STATE's multi-scale framework could further enhance its predictive power and provide a more complete virtual cell model.

  • Dynamic Modeling: STATE models a static transition from control to perturbed states. Future extensions could explore dynamic modeling of cellular processes over time, which would be crucial for understanding disease progression or drug response kinetics.

In conclusion, STATE is a landmark paper that pushes the boundaries of single-cell perturbation prediction, demonstrating the power of large-scale transformer models and principled distributional learning to unlock deeper biological understanding and accelerate biomedical discovery.
