AiPaper
论文状态:已完成

Predicting cellular responses to perturbation across diverse contexts with STATE

发表:2025/06/27
原文链接
价格:0.10
已有 5 人读过
本分析由 AI 生成,可能不完全准确,请以原文为准。

TL;DR 精炼摘要

本文提出基于Transformer的STATE模型,利用逾1亿细胞扰动表达数据,精准预测不同细胞背景下的扰动响应。模型提升效应区分度30%以上,准确识别差异表达基因,且能泛化到未见新背景的强扰动,推动可扩展虚拟细胞模型的发展。

摘要

Predicting cellular responses to perturbation across diverse contexts with State Anonymous Author(s) Affiliation Address email Abstract 1 Cellular responses to perturbations are a cornerstone for understanding biological mechanisms and 2 selecting drug targets. While machine learning models offer tremendous potential for predicting perturbation 3 effects, they currently struggle to generalize to unobserved cellular contexts. Here, we introduce State , 4 a transformer model that predicts perturbation effects while accounting for cellular heterogeneity within 5 and across experiments. State predicts perturbation effects across sets of cells and is trained using gene 6 expression data from over 100 million perturbed cells. State improved discrimination of effects on large 7 datasets by more than 30% and identified differentially expressed genes across genetic, signaling and chemical 8 perturbations with significantly improved accuracy. Using its cell embedding trained on observational data 9 from 167 million cells, State identified strong perturbations in novel cellular contexts where no perturbations 10 were observed during training. Overall, the perfo

思维导图

论文精读

中文精读

1. 论文基本信息

1.1. 标题

预测跨多样化背景下细胞对扰动的响应,使用 STATE 模型(Predicting cellular responses to perturbation across diverse contexts with STATE)

1.2. 作者

匿名作者(Anonymous Author(s))

1.3. 发表期刊/会议

该论文已提交至第39届神经信息处理系统大会(Neural Information Processing Systems, NeurIPS 2025),目前处于提交阶段,未公开发布(Do not distribute)。这表明它可能是一篇预印本或正在评审中的论文,尚未经过同行评审的最终版本。

1.4. 发表年份

2025年(根据 NeurIPS 2025 推断)

1.5. 摘要

细胞对扰动(perturbation)的响应对于理解生物学和选择药物靶点至关重要,然而现有的机器学习模型在跨越多样化的细胞背景(cellular contexts)时,其泛化能力不足。本文引入了 STATE 模型,这是一个基于 Transformer 的模型,旨在预测扰动效应,同时考虑实验内部和实验之间细胞的异质性(heterogeneity)。STATE 使用来自超过1亿个被扰动细胞的基因表达数据,预测跨细胞集合的扰动效应。它在大型数据集上将效应区分度提高了30%以上,并更准确地识别了遗传、信号和化学扰动下的差分表达基因(differentially expressed genes, DEGs)。STATE 利用从1.67亿个观测细胞数据训练得到的细胞嵌入(cell embeddings),成功预测了训练期间未见的全新细胞背景中的强扰动。该模型的性能和灵活性推动了可扩展虚拟细胞模型(scalable virtual cell models)在多样化生物应用中的发展。

1.6. 原文链接

/files/papers/690878ad1ccaadf40a4344dd/paper.pdf (发布状态:预印本/提交中)


2. 整体概括

2.1. 研究背景与动机

论文试图解决的核心问题: 细胞对扰动的响应是生物学研究和药物发现的基石。通过选择性地干扰细胞系统中的特定组分,科学家能够识别表型(phenotypes)的因果驱动因素,这对于靶点识别和药物开发都至关重要。尽管机器学习模型在预测扰动效应方面具有巨大潜力,但它们目前难以泛化到未曾观测到的细胞背景。现有的计算方法在从大规模扰动数据中获得与数据增长相称的预测能力提升方面遇到了挑战。

为什么这个问题在当前领域是重要的?现有研究存在哪些具体的挑战或空白:

  1. 泛化能力差: 现有的机器学习模型在将其从训练数据中学到的扰动响应知识泛化到新的、未曾见过的细胞背景时表现不佳。
  2. 单细胞数据限制: 单细胞 RNA 测序(scRNA-seq)是一种破坏性检测方法,这意味着无法同时观测到细胞的扰动前状态和扰动后状态,使得准确推断每个细胞的特定扰动响应变得困难。
  3. 异质性处理不足:
    • 一些方法假设细胞群体内的异质性(within-population heterogeneity)相对于扰动效应是可忽略的,简单地将受扰细胞映射到具有共享协变量的未扰动细胞。这些方法在扰动效应微妙时,或未扰动群体中的异质性超过扰动信号时,往往会失败。
    • 其他模型将细胞群体视为分布,学习数据生成分布或显式地解耦标签和未标签的变异源。然而,在实践中,这些模型常常未能超越那些不显式建模分布结构的方法。
  4. 可扩展性问题: 基于最优传输(Optimal Transport, OT)的方法虽然已被提出,但其适用性受到假设限制和可扩展性差的影响。

这篇论文的切入点或创新思路: 为了克服这些挑战,本文引入了 STATE,一个灵活且富有表现力的架构,用于建模细胞异质性以及在多样化数据集内部和跨数据集的扰动效应。STATE 通过以下核心创新点切入:

  1. 多尺度架构: 结合了 State Transition (ST) 模型和 State Embedding (SE) 模型,分别在细胞群体层面和个体细胞层面捕捉信息。
  2. 基于Transformer的 ST ST 利用自注意力(self-attention)机制在细胞集合上建模扰动引起的转换,能够灵活捕捉生物异质性,并且不依赖显式分布假设。
  3. 预训练的 SE 嵌入: SE 模型通过预训练生成高质量的细胞嵌入,这些嵌入对技术变异更鲁棒,并优化了扰动效应的检测。这种预训练机制使得模型能够从大量观测数据中学习通用细胞表示,从而提高在未见过背景下的泛化能力。
  4. 充分利用大规模数据: STATE 能够有效利用 SE 模块的1.67亿观测细胞数据和 ST 模块的1亿受扰细胞数据进行训练。

2.2. 核心贡献/主要发现

  1. 引入了 STATE 模型: 一个多尺度机器学习架构,包含 ST (State Transition) 和 SE (State Embedding) 两个互补模块,用于预测细胞对扰动的响应。

  2. ST 模型的创新设计: ST 是一个基于 Transformer 的模型,通过在细胞集合上应用自注意力机制,能够灵活地捕捉细胞群体内部和跨实验的生物异质性,而无需依赖显式的分布假设。它将控制细胞群体转化为相应的扰动状态。

  3. SE 模型的强大能力: SE 是一个预训练的编码器-解码器模型,通过学习多样化数据集中细胞间的基因表达变异来生成鲁棒的细胞嵌入。这些嵌入对技术变异更具鲁棒性,并优化了扰动效应的检测。

  4. 显著的性能提升:

    • 在大型数据集上,STATE 将效应区分度(effect discrimination)提高了30%以上。
    • 在遗传、信号和化学扰动任务中,STATE 更准确地识别了差分表达基因(DEGs),准确率比次优基线高出2倍(Tahoe-100M)或43%(Parse-PBMC)。
    • STATE 能够准确地根据效应大小对扰动进行排序,其Spearman相关系数比基线高出53%到70%。
  5. 卓越的泛化能力:

    • 在上下文泛化(context generalization)任务中,STATE 表现出优越的性能,尤其是在数据规模增加时,显示出其架构能够更好地利用大规模数据。
    • 利用 SE 训练的细胞嵌入,STATE 成功预测了在训练期间从未观测到的新细胞背景中的强扰动,实现了零样本(zero-shot)泛化能力。
  6. 可扩展性和灵活性: STATE 的性能和灵活性为开发用于多样化生物应用的可扩展虚拟细胞模型奠定了基础。


3. 预备知识与相关工作

3.1. 基础概念

  • 细胞扰动 (Cellular Perturbation): 指对细胞施加外部干预,例如使用药物、基因编辑(如 CRISPR)、改变培养条件等,以观察细胞状态(如基因表达、表型)如何变化的实验过程。
  • 单细胞RNA测序 (Single-cell RNA sequencing, scRNA-seq): 一种高通量生物技术,用于测量单个细胞内的基因表达水平。它能够揭示细胞群体内的异质性,但由于其破坏性,无法在扰动前后对同一细胞进行测量。
  • 基因表达 (Gene Expression): 指细胞中基因被转录(生成RNA)和/或翻译(生成蛋白质)的过程。通常,基因表达水平通过细胞内特定基因的 mRNA 分子数量来衡量。
  • 细胞异质性 (Cellular Heterogeneity): 指在看似同质的细胞群体中,单个细胞之间在基因表达、功能、形态等方面存在的差异。理解这种异质性对于准确预测细胞对扰动的响应至关重要。
  • Transformer 模型 (Transformer Model): 一种最初为自然语言处理设计的深度学习架构,其核心是自注意力机制。Transformer 模型能够高效处理序列数据,捕捉长距离依赖关系,并已成功应用于生物信息学领域,处理基因序列、蛋白质序列和单细胞数据。
  • 自注意力机制 (Self-Attention Mechanism): Transformer 模型的核心组成部分。它允许模型在处理序列中的每个元素(例如,一个细胞或一个基因)时,计算该元素与序列中所有其他元素之间的相关性,并根据这些相关性来加权整合信息。这使得模型能够灵活地捕捉序列内部的依赖关系,而无需依赖固定的局部感受野。
  • 细胞嵌入 (Cell Embedding): 将复杂、高维的单细胞数据(例如,数千个基因的表达谱)映射到低维连续向量空间中的表示。这些嵌入能够捕获细胞的生物学状态和相互关系,并作为下游机器学习任务的输入。
  • 差分表达基因 (Differentially Expressed Genes, DEGs): 在两种或多种不同实验条件(例如,经过扰动与未经扰动)下,其表达水平发生统计学显著变化的基因。识别 DEGs 是理解细胞对扰动响应机制的关键一步。
  • 预训练 (Pre-training) 与 微调 (Fine-tuning): 深度学习中的常用策略。
    • 预训练: 指在大规模、多样化的数据集上训练一个通用模型(例如,一个语言模型或一个细胞嵌入模型),使其学习到广泛的特征和知识。
    • 微调: 指在特定任务或较小数据集上,使用预训练模型的权重作为初始值,进一步训练模型以适应特定任务。这通常能显著提高模型在目标任务上的性能,尤其是在目标数据量有限时。
  • 伪批量 (Pseudobulk): 在单细胞 RNA 测序数据分析中,通过将一组具有相似特征(如细胞类型、批次或扰动条件)的单个细胞的基因表达数据进行聚合(例如,求和或求平均),来创建一个模拟传统批量 RNA 测序样本的数据点。这有助于在单细胞分辨率上进行差分表达分析,并减少数据稀疏性。
  • 最大均值差异 (Maximum Mean Discrepancy, MMD): 一种非参数的统计检验方法,用于度量两个概率分布在再生核希尔伯特空间(Reproducing Kernel Hilbert Space, RKHS)中的距离。MMD 广泛用于比较两个样本是否来自相同的分布,或在生成模型中作为损失函数,鼓励生成样本的分布与真实数据分布相匹配。

3.2. 前人工作

  1. 大规模扰动筛选: 近年来,功能基因组学(functional genomics)的进步,特别是通过配对池化 CRISPR 扰动等方法 [1-6],使得在大规模上对特定细胞背景进行扰动筛选成为可能。
  2. 计算预测方法:
    • 基于映射的方法 [10-12]: 一些方法假设群体内的异质性(within-population heterogeneity)相对于扰动效应可忽略,并通过将受扰细胞映射到具有相似协变量的随机选择的未受扰细胞来预测。这些方法在扰动效应微妙时,或未受扰群体中的异质性超过扰动信号时,泛化能力差。
    • 基于分布模型的方法 [8, 19-26]: 其他模型将细胞群体视为分布,学习数据生成分布或显式地解耦标签和未标签的变异源。然而,这些模型在实践中往往未能超越不显式建模分布结构的方法。
    • 基于最优传输的方法 [9, 27-29]: 这类方法通过将未受扰群体映射到受扰群体来学习转换。然而,其适用性受到假设和可扩展性差的限制。
  3. 扰动预测基准 [14-18]: 尽管扰动数据集的规模和范围迅速增长,但预测能力并未实现相应比例的提升,这凸显了现有模型的局限性。
  4. 通用细胞嵌入和基础模型 [11, 12, 30-40]: 发展通用细胞嵌入(universal cell embeddings)或单细胞基础模型(single-cell foundation models)是领域内的一个重要方向,旨在学习在多样化数据集和实验中鲁棒的细胞表示,以应对技术变异并优化扰动效应检测。

3.3. 技术演进

该领域的技术演进经历了从相对简单的统计模型到复杂的深度学习架构的转变:

  1. 早期方法(简单映射/统计):最初的方法倾向于简化细胞异质性,假设群体内部差异可忽略,或者通过简单的统计方法(如平均值)来描述细胞状态。这种方法在处理复杂或微妙的扰动时效果不佳。

  2. 基于分布模型(VAEs/Disentanglement):随着变分自编码器(VAEs)等生成模型的发展,研究者开始尝试将细胞群体视为概率分布,并试图学习这些分布的参数或解耦不同变异来源(如生物变异与技术变异)。然而,这些模型在实际应用中并未总是显著超越更简单的方法。

  3. 基于最优传输(Optimal Transport):最优传输提供了一个理论上严谨的框架来比较和转换概率分布,被应用于将未扰动细胞分布映射到扰动细胞分布。但其计算复杂性和对假设的敏感性限制了其大规模应用。

  4. 深度学习和Transformer的兴起:近年来,深度学习,特别是Transformer架构,在处理复杂数据和捕捉长距离依赖方面展现出强大能力。一些工作开始探索使用Transformer构建单细胞基础模型(如 scGPT [11]),或者将其应用于扰动预测任务。

    本文的工作 STATE 正是处于这一演进的最前沿,它结合了 Transformer 强大的序列建模能力(通过自注意力机制捕捉细胞集合内的异质性),并引入了大规模预训练的细胞嵌入(SE)来解决跨数据集的泛化和技术变异问题。同时,它通过 MMD 损失函数隐式地与最优传输的思想建立联系,但在实践中避免了传统 OT 方法的计算复杂性。

3.4. 差异化分析

STATE 模型与现有方法相比,其核心区别和创新点在于:

  1. 对细胞异质性的显式建模: 现有方法要么忽略群体内异质性(如简单的映射方法),要么试图通过显式分布假设来建模(如 scVI, CPA),但往往效果不理想。STATE 中的 ST 模型通过在细胞集合上应用 Transformer 的自注意力机制,能够灵活地捕捉残余异质性(residual heterogeneity),即那些未被已知协变量(如细胞类型、扰动标签)解释的异质性。这种基于集合的建模方式,避免了对精确个体细胞扰动响应的不可行推断,并利用群体信息来推断扰动效应。

  2. 利用大规模观测数据进行鲁棒的细胞嵌入: STATE 引入的 SE 模型通过对1.67亿个观测细胞进行自监督预训练,学习到高质量、对技术变异鲁棒的细胞嵌入。这使得模型能够将不同实验和平台的数据统一到共享的表示空间中,从而极大地提高了跨数据集的泛化能力和迁移学习能力。相比之下,许多现有方法直接在原始基因表达数据上操作,容易受到技术噪声和批次效应的影响,导致泛化性差。

  3. 数据规模利用效率更高: 论文指出,随着数据规模的增加,STATE 的性能提升远超其他基线模型。这表明 STATE 的架构(特别是 Transformer 和大规模预训练)能够更好地利用海量数据,从而解决了现有扰动模型可能处于数据稀疏状态的问题。

  4. 零样本(Zero-shot)泛化能力: 结合 SE 学习到的通用细胞嵌入,STATE 能够成功预测在训练期间从未见过的全新细胞背景中的扰动效应。这对于药物发现和疾病机制研究具有重要意义,因为它允许研究人员在没有大量特定背景数据的情况下进行预测。

  5. 不依赖显式分布假设: ST 通过 MMD 损失函数在分布层面上进行优化,但它不强制性地依赖于特定的参数化数据生成分布假设(如高斯分布),这使得它在处理复杂和非高斯分布的生物数据时更具灵活性。

    总之,STATE 的核心创新在于其多尺度的 Transformer 架构,通过 STSE 模块的协同作用,有效地解决了现有模型在处理细胞异质性、跨数据集泛化以及利用大规模数据方面的局限性。


4. 方法论

STATE 是一种多尺度机器学习架构,由两个互补模块组成:State Transition (ST) 模型和 State Embedding (SE) 模型。ST 模型是一个 Transformer,它使用自注意力机制来建模细胞集合中的扰动效应。SE 模型则是一个预训练的编码器-解码器模型,用于学习鲁棒的细胞嵌入,以应对跨数据集的技术变异。

4.1. 方法原理

STATE 的核心思想是通过组合 STSE 模块,既能处理细胞群体内部的复杂异质性,又能有效地跨越不同的实验和细胞背景进行泛化。

  • ST 模型(状态转换模型) 旨在学习从控制细胞群体到受扰细胞群体的转录组响应。它通过将细胞分组为集合,并在这些集合上应用 Transformer 架构,利用自注意力机制来捕捉细胞内部未被标注的异质性,并预测扰动后的转录组状态。这种方法避免了对单个细胞扰动前后状态的直接观察(这是不可能的),而是通过群体层面的转换来推断效应。其训练目标是最小化预测群体与真实群体之间的 MMD 距离。
  • SE 模型(状态嵌入模型) 旨在解决跨数据集的技术变异和泛化问题。它通过自监督学习,从大规模观测单细胞数据中学习高质量的细胞嵌入。这些嵌入能够捕获细胞的生物学信号,同时减少技术伪影,为 ST 提供更鲁棒和统一的细胞表示。SE 的预训练目标包括基因表达预测和辅助数据集分类任务。

4.2. 核心方法详解

4.2.1. STATE的整体架构

以下是原文 Figure 1 的示意图,展示了 STATE 的多尺度机器学习架构:

Figure 1: StATE is a multi-scale machine learning architecture that operates across genes, individual cells, and cell populations. A) The State Transition model (ST) learns perturbation effects by tr… 该图像是论文中图1的示意图,展示了STATE模型的多尺度机器学习架构,包括状态转换模型(ST)对细胞扰动的预测过程、不同细胞集合大小对模型性能的影响,以及状态嵌入模型(SE)的编码解码结构,用于处理跨数据集的异质性和细胞表达重构。

  • A) 状态转换模型(ST): 通过在细胞集合上训练来学习扰动效应。输入可以是原始基因表达谱或 SE 学习的细胞表示。ST 利用自注意力机制处理细胞集合,捕获生物异质性。
  • B) 细胞集合大小的影响:Tahoe-100M 数据集上,当协变量匹配的组被分成256个细胞的集合时,ST模型的表现最佳。完整的 ST 模型显著优于伪批量模型(STATE w/ mean-pooling)和单细胞变体(STATEwithsetsize=1STATE with set size = 1)。移除自注意力机制(STATE w/o self-attention)会大大降低性能。
  • C) 状态嵌入模型(SE): 是一个编码器-解码器模型,在大量观测单细胞组学数据上训练。Transformer 编码器构建密集的细胞嵌入,MLP 解码器从嵌入中重建基因表达。ST 可以直接对基因表达谱进行操作,或者从 SE 中获取细胞表示。当使用 SE 训练时,一个单独的 MLP 从预测嵌入中解码出扰动后的基因表达。

4.2.2. State Transition Model (ST)

ST 的核心动机是建模超出已知协变量(如细胞类型和扰动标签)的细胞异质性,以改进扰动响应预测。

数据表示与细胞集合的形成

假设我们有一个单细胞 RNA 测序数据集 D={(x(i),pi,i,bi)}i=1N\mathcal{D} = \big\{ \big( \mathbf{x}^{(i)}, p_i, \ell_i, b_i \big) \big\}_{i=1}^N,其中:

  • x(i)RG\mathbf{x}^{(i)} \in \mathbb{R}^G 表示细胞 ii 的归一化基因表达向量,GG 是基因数量。原始计数数据经过 log1p 转换。
  • pi{1,,P}{ctrl}p_i \in \{1, \ldots, P\} \cup \{\mathsf{ctrl}\} 是扰动标签(ctrl 表示控制组)。
  • i{1,,L}\ell_i \in \{1, \ldots, L\} 是生物背景或细胞系标签。
  • bi{1,,B}b_i \in \{1, \ldots, B\} 是可选的批次效应标签。

细胞集合的形成 (Formation of Cell Sets, Section 5.3.1) 细胞根据其生物背景 \ell、扰动 pp 和批次 bb 进行分组: C,p,b = {x(i)D  i=, pi=p, bi=b}. \mathcal{C}_{\ell, p, b} \ = \ \big\{ \mathbf{x}^{(i)} \in \mathcal{D} \ \big| \ \ell_i = \ell, \ p_i = p, \ b_i = b \big\}.

  • 符号解释:
    • C,p,b\mathcal{C}_{\ell, p, b}:给定细胞系 \ell、扰动 pp 和批次 bb 的细胞集合。

    • x(i)\mathbf{x}^{(i)}:细胞 ii 的基因表达向量。

    • D\mathcal{D}:整个单细胞数据集。

    • i,pi,bi\ell_i, p_i, b_i:细胞 ii 的细胞系、扰动和批次标签。

      N,p,b=C,p,bN_{\ell, p, b} = |\mathcal{C}_{\ell, p, b}| 是这个集合中的细胞数量。为了训练,将 C,p,b\mathcal{C}_{\ell, p, b} 分块成大小为 SS 的细胞集合 S,p,b(k)RS×GS_{\ell, p, b}^{(k)} \in \mathbb{R}^{S \times G},其中 k{1,,N,p,bS}k \in \{1, \dots, \lfloor \frac{N_{\ell, p, b}}{S} \rfloor \}。如果 N,p,bN_{\ell, p, b} 不能被 SS 整除,剩余的细胞会形成一个更小的集合,并通过从自身中替换采样额外的细胞来填充到大小 SS

在细胞集合上进行训练 (Training on Cell Sets, Section 5.3.2) 在训练期间,每个扰动细胞集合 S,p,b(k)S_{\ell, p, b}^{(k)} (其中 p{1,,P}{ctrl}p \in \{1, \dots, P\} \cup \{\mathsf{ctrl}\})都与一个对应的对照细胞集合配对。对照细胞集合是通过 map 函数从同一细胞系 \ell 和批次 bb 的控制细胞中抽样 SS 个细胞形成的。 map(S,p,b(k))=stack([x(i)]x(i)C,ctrl,b)RS×G \mathtt{map}(S_{\ell, p, b}^{(k)}) = \mathtt{stack}( [ \mathbf{x}^{(i)} ]_{\mathbf{x}^{(i)} \sim \mathcal{C}_{\ell, \mathrm{ctrl}, b}} ) \in \mathbb{R}^{S \times G}

  • 符号解释:
    • map(S,p,b(k))\mathtt{map}(S_{\ell, p, b}^{(k)}):给定扰动细胞集合 S,p,b(k)S_{\ell, p, b}^{(k)} 的映射对照细胞集合。

    • stack()\mathtt{stack}(\cdot):将向量堆叠成矩阵的操作。

    • x(i)C,ctrl,b\mathbf{x}^{(i)} \sim \mathcal{C}_{\ell, \mathrm{ctrl}, b}:从给定细胞系 \ell 和批次 bb 的控制细胞集合中采样的细胞。

    • RS×G\mathbb{R}^{S \times G}:维度为 SS(细胞集合大小)乘以 GG(基因数量)的实数矩阵。

      这种 map 函数有效地决定了哪些变异源被显式控制。通过条件化特定的协变量(如细胞系 \ell、批次 bb),它减少了可能混淆真实扰动信号的已知异质性来源。

为了形成 BB 个这样的集合对,{(Si,pi,bi(ki),map(Si,pi,bi(ki)))}i=1,B\{(S_{\ell_i, p_i, b_i}^{(k_i)}, \mathfrak{map}(S_{\ell_i, p_i, b_i}^{(k_i)}))\}_{i=1, \dots B},这些对可能来源于细胞系、扰动和(可选)批次的不同组合。它们被排列成以下张量: Xtarget=stack([S1,p1,b1(k1),,SB,pB,bB(kB)])RB×S×GXctrl=stack([mar(S1,p1,b1(k1)),,map(SB,pB,bB(kB))])RB×S×GZpertRB×S×Dpert(perturbation embeddings)ZbatchRB×S×Dbatch(optional batch covariates) \begin{array}{r l} & \mathbf{X}_{\mathrm{target}} = \mathsf{stack}( [ S_{\ell_1, p_1, b_1}^{(k_1)}, \ldots, S_{\ell_B, p_B, b_B}^{(k_B)} ] ) \in \mathbb{R}^{B \times S \times G} \\ & \quad \mathbf{X}_{\mathrm{ctrl}} = \mathsf{stack}( [ \mathsf{mar}(S_{\ell_1, p_1, b_1}^{(k_1)}), \ldots, \mathsf{map}(S_{\ell_B, p_B, b_B}^{(k_B)}) ] ) \in \mathbb{R}^{B \times S \times G} \\ & \quad \mathbf{Z}_{\mathrm{pert}} \in \mathbb{R}^{B \times S \times D_{\mathrm{pert}}} \quad \mathrm{(perturbation~embeddings)} \\ & \quad \mathbf{Z}_{\mathrm{batch}} \in \mathbb{R}^{B \times S \times D_{\mathrm{batch}}} \quad \mathrm{(optional~batch~covariates)} \end{array}

  • 符号解释:
    • Xtarget\mathbf{X}_{\mathrm{target}}:目标扰动细胞集合的批次张量,维度为 BB (批次大小) ×\times SS (细胞集合大小) ×\times GG (基因数量)。

    • Xctrl\mathbf{X}_{\mathrm{ctrl}}:对照细胞集合的批次张量,维度与 Xtarget\mathbf{X}_{\mathrm{target}} 相同。

    • Zpert\mathbf{Z}_{\mathrm{pert}}:扰动嵌入的批次张量,维度为 BB ×\times SS ×\times DpertD_{\mathrm{pert}}

    • Zbatch\mathbf{Z}_{\mathrm{batch}}:可选批次协变量的批次张量,维度为 BB ×\times SS ×\times DbatchD_{\mathrm{batch}}

    • DpertD_{\mathrm{pert}}:扰动嵌入的维度。对于 STATE,使用 one-hot 编码,所以 DpertD_{\mathrm{pert}} 等于数据中唯一扰动的数量。

    • DbatchD_{\mathrm{batch}}:批次嵌入的维度。对于 STATE,使用 one-hot 编码,所以 DbatchD_{\mathrm{batch}} 等于唯一批次标签的数量。

      ST 接收 Xctrl\mathbf{X}_{\mathrm{ctrl}} 作为输入,连同扰动嵌入 Zpert\mathbf{Z}_{\mathrm{pert}},并学习预测 Xtarget\mathbf{X}_{\mathrm{target}} 作为输出,从而学习将控制细胞群体转换为相应的扰动状态。

神经网络模块 (Neural Network Modules, Section 5.3.3) ST 使用专门的编码器,将细胞表达谱、扰动标签和可选的批次标签映射到共享的隐藏维度 dhd_h,作为 Transformer 的输入。

  1. 控制细胞编码器 (Control Cell Encoder) 每个 log-normalized 表达向量 x(i)RG\mathbf{x}^{(i)} \in \mathbb{R}^G 通过一个4层 MLP (fcellf_{\mathrm{cell}}) 映射到嵌入空间,使用 GELU 激活函数。这个 MLP 独立应用于整个对照张量中的每个细胞: Hcell=fcell(Xctrl)RB×S×dh \mathbf{H}_{\mathrm{cell}} = f_{\mathrm{cell}}(\mathbf{X}_{\mathrm{ctrl}}) \in \mathbb{R}^{B \times S \times d_h}

    • 符号解释:
      • Hcell\mathbf{H}_{\mathrm{cell}}:编码后的控制细胞嵌入,维度为 BB ×\times SS ×\times dhd_h
      • fcellf_{\mathrm{cell}}:4层 MLP 编码器。
      • Xctrl\mathbf{X}_{\mathrm{ctrl}}:对照细胞集合的批次张量。
      • dhd_h:共享的隐藏维度。 这将输入形状从 (B×S×G)(B \times S \times G) 转换为 (B×S×dh)(B \times S \times d_h)
  2. 扰动编码器 (Perturbation Encoder) 扰动标签被编码到相同的嵌入维度 dhd_h 中。对于 one-hot 编码的扰动,输入向量通过一个4层 MLP (fpertf_{\mathrm{pert}}) 传递,使用 GELU 激活函数: Hpert = fpert(Zpert)  RB×S×dh \mathbf{H}_{\mathrm{pert}} \ = \ f_{\mathrm{pert}}(\mathbf{Z}_{\mathrm{pert}}) \ \in \ \mathbb{R}^{B \times S \times d_h}

    • 符号解释:
      • Hpert\mathbf{H}_{\mathrm{pert}}:编码后的扰动嵌入,维度为 BB ×\times SS ×\times dhd_h
      • fpertf_{\mathrm{pert}}:4层 MLP 编码器。
      • Zpert\mathbf{Z}_{\mathrm{pert}}:扰动嵌入的批次张量。 这将输入形状从 (B×S×Dpert)(B \times S \times D_{\mathrm{pert}}) 转换为 (B×S×dh)(B \times S \times d_h)。注意,在给定批次中,同一集合内的所有细胞的扰动嵌入是相同的。如果扰动由连续特征(如分子描述符)表示,这些嵌入将直接用作 Hpert\mathbf{H}_{\mathrm{pert}},此时 dh=Dpertd_h = D_{\mathrm{pert}}
  3. 批次编码器 (Batch Encoder) 为了考虑技术批次效应,批次标签 bi{1,,B}b_i \in \{1, \ldots, B\} 被编码成维度为 dhd_h 的嵌入: Hbatch=fbatch(Zbatch)RB×S×dh \mathbf{H}_{\mathrm{batch}} = f_{\mathrm{batch}}(\mathbf{Z}_{\mathrm{batch}}) \in \mathbb{R}^{B \times S \times d_h}

    • 符号解释:
      • Hbatch\mathbf{H}_{\mathrm{batch}}:编码后的批次嵌入,维度为 BB ×\times SS ×\times dhd_h
      • fbatchf_{\mathrm{batch}}:嵌入层。
      • Zbatch\mathbf{Z}_{\mathrm{batch}}:批次协变量的批次张量。 这将输入形状从 (B×S×Dbatch)(B \times S \times D_{\mathrm{batch}}) 转换为 (B×S×dh)(B \times S \times d_h)
  4. Transformer 输入和输出 (Transformer Inputs and Outputs) ST 的最终输入通过将控制细胞嵌入与扰动和批次嵌入相加来构建: H=Hcell+Hpert+Hbatch \mathbf{H} = \mathbf{H}_{\mathrm{cell}} + \mathbf{H}_{\mathrm{pert}} + \mathbf{H}_{\mathrm{batch}}

    • 符号解释:
      • H\mathbf{H}Transformer 的最终输入,是控制细胞、扰动和批次嵌入的复合表示。 这个复合表示被传递给 Transformer 主干网络 fSTf_{\mathrm{ST}},以建模细胞集合中的扰动效应。输出计算如下: O=H+fST(H) \mathbf{O} = \mathbf{H} + f_{\mathrm{ST}}(\mathbf{H})
    • 符号解释:
      • O\mathbf{O}:最终输出,维度为 RB×S×dh\mathbb{R}^{B \times S \times d_h}
      • fSTf_{\mathrm{ST}}Transformer 主干网络。 这种公式化鼓励 Transformer fSTf_{\mathrm{ST}} 将扰动效应学习为输入表示 H\mathbf{H} 的残差。
  5. 基因重构头 (Gene Reconstruction Head) 当直接在表达空间中操作时,基因重构头将 Transformer 的输出 O\mathbf{O} 映射回基因表达空间。这是通过一个线性投影层完成的,该层独立应用于批次中每个 token(细胞)的 dhd_h 维隐藏表示。对于批次 bbTransformer 输出 O(b)RS×dh\mathbf{O}^{(b)} \in \mathbb{R}^{S \times d_h},预测的目标基因表达 X^target(b)RS×G\hat{\mathbf{X}}_{\mathrm{target}}^{(b)} \in \mathbb{R}^{S \times G} 如下给出: X^target(b)=frecon(O(b))=O(b)Wrecon+brecon \hat{\mathbf{X}}_{\mathrm{target}}^{(b)} = f_{\mathrm{recon}}(\mathbf{O}^{(b)}) = \mathbf{O}^{(b)}\mathbf{W}_{\mathrm{recon}} + \mathbf{b}_{\mathrm{recon}}

    • 符号解释:
      • X^target(b)\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}:批次 bb 的预测目标基因表达。
      • freconf_{\mathrm{recon}}:基因重构头,一个线性投影层。
      • O(b)\mathbf{O}^{(b)}:批次 bbTransformer 输出。
      • WreconRdh×G\mathbf{W}_{\mathrm{recon}} \in \mathbb{R}^{d_h \times G}:可学习的权重参数。
      • breconRG\mathbf{b}_{\mathrm{recon}} \in \mathbb{R}^G: 可学习的偏置参数。 此操作将表示转换为重构的 log-transformed 基因表达值。

用最大均值差异 (MMD) 学习扰动效应 (Learning Perturbation Effects with Maximum Mean Discrepancy, Section 5.3.4) ST 模型的训练目标是最小化预测的转录组响应 X^target\hat{\mathbf{X}}_{\mathrm{target}} 与观测到的转录组响应 Xtarget\mathbf{X}_{\mathrm{target}} 之间的差异。这种差异通过最大均值差异(MMD)量化,MMD 是一种基于再生核希尔伯特空间(RKHS)中嵌入的统计距离度量。

  1. MMD平方的计算 对于每个小批量(mini-batch)元素 bb,我们考虑由 SS 个预测细胞表达向量组成的集合 X^target(b)={x^(i)RG}i=1S\hat{\mathbf{X}}_{\mathrm{target}}^{(b)} = \{ \hat{\mathbf{x}}^{(i)} \in \mathbb{R}^G \}_{i=1}^SSS 个观测细胞表达向量组成的集合 Xtarget(b)={x(i)RG}i=1S\mathbf{X}_{\mathrm{target}}^{(b)} = \{ \mathbf{x}^{(i)} \in \mathbb{R}^G \}_{i=1}^SMMD 衡量这些由有限向量集合隐式定义的两个分布之间的距离: MMD2(X^target(b),Xtarget(b))=1S2i=1Sj=1S[k(x^(i),x^(j))+k(x(i),x(j))2k(x^(i),x(j))] \mathrm{MMD}^2 \bigl( \hat{\mathbf{X}}_{\mathrm{target}}^{(b)}, \mathbf{X}_{\mathrm{target}}^{(b)} \bigr) = \frac{1}{S^2} \sum_{i=1}^{S} \sum_{j=1}^{S} \Big[ k(\hat{\mathbf{x}}^{(i)}, \hat{\mathbf{x}}^{(j)}) + k(\mathbf{x}^{(i)}, \mathbf{x}^{(j)}) - 2k(\hat{\mathbf{x}}^{(i)}, \mathbf{x}^{(j)}) \Big]

    • 符号解释:
      • MMD2(,)\mathrm{MMD}^2(\cdot, \cdot):平方的最大均值差异。
      • X^target(b)\hat{\mathbf{X}}_{\mathrm{target}}^{(b)}:批次 bb 中预测的细胞表达集合。
      • Xtarget(b)\mathbf{X}_{\mathrm{target}}^{(b)}:批次 bb 中观测的细胞表达集合。
      • SS:细胞集合的大小。
      • k(,)k(\cdot, \cdot):核函数(kernel function)。
      • x^(i),x^(j)\hat{\mathbf{x}}^{(i)}, \hat{\mathbf{x}}^{(j)}:预测集合中的细胞表达向量。
      • x(i),x(j)\mathbf{x}^{(i)}, \mathbf{x}^{(j)}:观测集合中的细胞表达向量。 等式中的三项分别对应于:(1) 预测集合内部的相似性,(2) 观测集合内部的相似性,(3) 预测集合与观测集合之间的交叉相似性。
  2. 能量距离核 (Energy Distance Kernel) 本文使用能量距离核: k(u,v)=uv2 k(\mathbf{u}, \mathbf{v}) = -\|\mathbf{u} - \mathbf{v}\|_2

    • 符号解释:
      • u,v\mathbf{u}, \mathbf{v}:两个输入向量。
      • 2\|\cdot\|_2: 欧几里得范数(L2范数)。
  3. 批次平均MMD损失 (Batch-averaged MMD Loss) 对于包含 BB 个细胞集合的训练小批量,批次平均 MMD 损失定义为这些 MMD2\mathrm{MMD}^2 值的平均值: LMMD(X^target,Xtarget)=1Bb=1BMMD2(X^target(b),Xtarget(b)) \mathcal{L}_{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}}) = \frac{1}{B} \sum_{b=1}^{B} \mathrm{MMD}^2 \big( \hat{\mathbf{X}}_{\mathrm{target}}^{(b)}, \mathbf{X}_{\mathrm{target}}^{(b)} \big)

    • 符号解释:
      • LMMD\mathcal{L}_{\mathrm{MMD}}:批次平均 MMD 损失。
      • BB:批次大小。 最小化此损失鼓励模型生成具有与扰动标签一致的统计特性,并与观测细胞集合匹配的扰动细胞表达集合。
  4. 总损失 (Total Loss) Ltotal=LMMD(X^target,Xtarget) \mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}})

在嵌入空间中训练ST (Training ST in Embedding Spaces, Section 5.3.5) ST 具有在基因表达空间或指定嵌入空间中训练的灵活性。当在嵌入空间中训练时,架构会进行修改,以包含一个额外的表达解码器。

  • EE 为嵌入空间的维度,通常 EGE \ll G

  • 张量 Xtarget\mathbf{X}_{\mathrm{target}}Xctrl\mathbf{X}_{\mathrm{ctrl}} 现在表示为 Xtargetemb\mathbf{X}_{\mathrm{target}}^{\mathrm{emb}}Xctrlemb\mathbf{X}_{\mathrm{ctrl}}^{\mathrm{emb}},维度为 B×S×EB \times S \times E

  • fcellf_{\mathrm{cell}} 被修改为将输入形状从 (B×S×E)(B \times S \times E) 转换为 (B×S×dh)(B \times S \times d_h)

  • freconf_{\mathrm{recon}} 被修改为将输出形状从 (B×S×dh)(B \times S \times d_h) 转换为 (B×S×E)(B \times S \times E)。这个输出记作 X^targetemb\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}

  • 其他编码器和 Transformer 本身保持不变。

    为了恢复目标细胞的原始基因表达 Xtarget\mathbf{X}_{\mathrm{target}},训练了一个额外的解码器头 fdecodef_{\mathrm{decode}}。这是一个带 dropout 的多层 MLP,将嵌入空间映射回完整的基因表达空间: X^target=fdecode(X^targetemb) \hat{\mathbf{X}}_{\mathrm{target}} = f_{\mathrm{decode}}(\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}})

  • 符号解释:

    • X^target\hat{\mathbf{X}}_{\mathrm{target}}:从嵌入空间解码得到的预测基因表达。
    • fdecodef_{\mathrm{decode}}:额外的解码器头(多层 MLP)。
    • X^targetemb\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}:预测的目标嵌入。 这将预测的嵌入从 (B×S×E)(B \times S \times E) 转换为 (B×S×G)(B \times S \times G),以恢复基因表达谱。

具体来说,模型使用嵌入空间和基因表达空间中 MMD 损失的加权组合进行训练: Ltotal=LMMD(X^targetemb,Xtargetemb)+0.1LMMD(X^target,Xtarget) \mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}, \mathbf{X}_{\mathrm{target}}^{\mathrm{emb}}) + 0.1 \cdot \mathcal{L}_{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}})

  • 符号解释:
    • Ltotal\mathcal{L}_{\mathrm{total}}:总损失。
    • LMMD(X^targetemb,Xtargetemb)\mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}^{\mathrm{emb}}, \mathbf{X}_{\mathrm{target}}^{\mathrm{emb}}):嵌入空间中的 MMD 损失。
    • LMMD(X^target,Xtarget)\mathcal{L}_{\mathrm{MMD}}(\hat{\mathbf{X}}_{\mathrm{target}}, \mathbf{X}_{\mathrm{target}}):基因表达空间中的 MMD 损失。 表达损失权重为0.1,以平衡两个项,避免压倒嵌入空间中的主要目标。

4.2.3. State Embedding Model (SE)

SE 模型是一个自监督模型,通过基因表达预测目标进行训练,以从单细胞 RNA 测序数据中学习细胞表示。SE 生成的嵌入作为 ST 的输入,从而实现跨数据集和生物背景的更鲁棒的迁移。

基因表示 (Gene Representation via Protein Language Models, Section 5.4.1) SE 模型通过预训练的蛋白质语言模型(如 ESM-2 [60])来编码基因特征。首先计算蛋白质编码转录本的嵌入(通过对每个转录本的氨基酸嵌入取平均),然后对基因中的所有转录本取平均。这种特征化捕获了基因之间的进化和功能关系。 每个基因嵌入 gjR5120\mathbf{g}_j \in \mathbb{R}^{5120} 通过一个可学习的编码器投影到模型的嵌入维度 hhg~j=SiLU(LayerNorm(gjWg+bg)) \tilde{\mathbf{g}}_j = \mathrm{SiLU}(\mathrm{LayerNorm}(\mathbf{g}_j \mathbf{W}_g + \mathbf{b}_g))

  • 符号解释:
    • g~j\tilde{\mathbf{g}}_j:投影后的基因嵌入,维度为 Rh\mathbb{R}^h
    • SiLU()\mathrm{SiLU}(\cdot):Sigmoid 线性单元激活函数 [63]。
    • LayerNorm()\mathrm{LayerNorm}(\cdot):层归一化。
    • gj\mathbf{g}_j:原始基因嵌入,维度为 R5120\mathbb{R}^{5120}
    • WgR5120×h\mathbf{W}_g \in \mathbb{R}^{5120 \times h}:可学习的权重参数。
    • bgRh\mathbf{b}_g \in \mathbb{R}^h: 可学习的偏置参数。

细胞表示 (Cell Representation, Section 5.4.2) 每个细胞 ii 被表示为其最高表达的 L=2048L = 2048 个基因的序列。这个“表达集合”通过两个特殊词元进行扩充: c~(i)=[zcls,g~1(i),g~2(i),\hdots,g~L(i),zds]R(L+2)×h \tilde{\mathbf{c}}^{(i)} = [\mathbf{z}_{\mathrm{cls}}, \tilde{\mathbf{g}}_1^{(i)}, \tilde{\mathbf{g}}_2^{(i)}, \hdots, \tilde{\mathbf{g}}_L^{(i)}, \mathbf{z}_{\mathrm{ds}}] \in \mathbb{R}^{(L+2) \times h}

  • 符号解释:
    • c~(i)\tilde{\mathbf{c}}^{(i)}:细胞 ii 的增强表达集合序列。
    • zclsRh\mathbf{z}_{\mathrm{cls}} \in \mathbb{R}^h:可学习的分类词元(classification token),用于聚合细胞级别信息。
    • g~(i)\tilde{\mathbf{g}}_\ell^{(i)}:细胞 ii 中第 \ell 个最高表达基因的投影嵌入。
    • zdsRh\mathbf{z}_{\mathrm{ds}} \in \mathbb{R}^h:可学习的数据集词元(dataset token),有助于解耦数据集特定的效应。 如果细胞表达的基因少于 LL 个,表达集合会通过从非表达基因池中随机采样来填充到长度 LL

表达感知嵌入 (Expression-Aware Embeddings, Section 5.4.3) 为了将表达值纳入模型,SE 采用了一种受软分箱(soft binning)启发的表达嵌入方案。对于细胞 ii 表达集合中第 \ell 个最高表达基因 j(i)j_{\ell}^{(i)},其表达值 xj(i)(i)\mathbf{x}_{j_{\ell}^{(i)}}^{(i)} 在细胞 ii 中被计算为一个软分箱分配: α(i)=Softmax(MLPcount(xj(i)(i)))R10e(i)=k=110α,k(i)bk \begin{array}{r l} & \boldsymbol{\alpha}_{\ell}^{(i)} = \mathrm{Softmax}(\mathrm{MLP}_{\mathrm{count}}(\mathbf{x}_{j_{\ell}^{(i)}}^{(i)})) \in \mathbb{R}^{10} \\ & \mathbf{e}_{\ell}^{(i)} = \displaystyle \sum_{k=1}^{10} \boldsymbol{\alpha}_{\ell,k}^{(i)} \mathbf{b}_k \end{array}

  • 符号解释:
    • α(i)\boldsymbol{\alpha}_{\ell}^{(i)}:基因 j(i)j_{\ell}^{(i)} 表达值的软分箱分配向量。
    • MLPcount\mathrm{MLP}_{\mathrm{count}}:将实数值输入(表达值)映射到 10 维的 MLP(两个线性层,维度 1512101 \to 512 \to 10,使用 LeakyReLU 激活)。
    • xj(i)(i)\mathbf{x}_{j_{\ell}^{(i)}}^{(i)}:细胞 ii 中基因 j(i)j_{\ell}^{(i)} 的表达值。
    • e(i)\mathbf{e}_{\ell}^{(i)}:基因 j(i)j_{\ell}^{(i)} 的表达嵌入。
    • {bk}k=110Rh\{\mathbf{b}_k\}_{k=1}^{10} \in \mathbb{R}^h:可学习的 bin 嵌入。 这些表达嵌入 e(i)\mathbf{e}_{\ell}^{(i)} 被添加到对应的基因标识嵌入 g~(i)\tilde{\mathbf{g}}_{\ell}^{(i)}g(i)=g~(i)+e(i) \mathbf{g}_{\ell}^{(i)} = \tilde{\mathbf{g}}_{\ell}^{(i)} + \mathbf{e}_{\ell}^{(i)}
    • 符号解释:
      • g(i)\mathbf{g}_{\ell}^{(i)}:最终的表达感知基因嵌入。

Transformer 编码 (Transformer Encoding, Section 5.4.4) 由表达感知基因嵌入和特殊词元组成的输入表达集合,通过 Transformer 编码器 fSEf_{\mathrm{SE}} 传递: E(i)=fSE([zcls,g1(i),g2(i),,gL(i),zds])R(L+2)×h \mathbf{E}^{(i)} = f_{\mathrm{SE}}([\mathbf{z}_{\mathrm{cls}}, \mathbf{g}_1^{(i)}, \mathbf{g}_2^{(i)}, \dots, \mathbf{g}_L^{(i)}, \mathbf{z}_{\mathrm{ds}}]) \in \mathbb{R}^{(L+2) \times h}

  • 符号解释:
    • E(i)\mathbf{E}^{(i)}:细胞 ii 的上下文嵌入序列。
    • fSEf_{\mathrm{SE}}Transformer 编码器。
    • g(i)\mathbf{g}_\ell^{(i)}:如上所示的表达感知基因嵌入。 输出是一个上下文嵌入序列。细胞嵌入从 [CLS] 词元(位置0)中提取并归一化: ecls(i)=LayerNorm(E0(i))Rh \mathbf{e}_{\mathrm{cls}}^{(i)} = \mathrm{LayerNorm}(\mathbf{E}_0^{(i)}) \in \mathbb{R}^h
    • 符号解释:
      • ecls(i)\mathbf{e}_{\mathrm{cls}}^{(i)}:细胞 ii 的分类(CLS)词元嵌入,作为细胞转录组状态的摘要表示。 同样,数据集表示从 [DS] 词元(位置 L+1L+1)中提取: eds(i)=LayerNorm(EL+1(i))Rh \mathbf{e}_{\mathrm{ds}}^{(i)} = \mathrm{LayerNorm}(\mathbf{E}_{L+1}^{(i)}) \in \mathbb{R}^h
    • 符号解释:
      • eds(i)\mathbf{e}_{\mathrm{ds}}^{(i)}:细胞 ii 的数据集词元嵌入,用于捕捉和解释数据集特定的效应。 最终的嵌入是这两个量的串联: zcell(i)=[ecls(i),fproj(eds(i))]Rh+10 \mathbf{z}_{\mathrm{cell}}^{(i)} = [\mathbf{e}_{\mathrm{cls}}^{(i)}, f_{\mathrm{proj}}(\mathbf{e}_{\mathrm{ds}}^{(i)})] \in \mathbb{R}^{h+10}
    • 符号解释:
      • zcell(i)\mathbf{z}_{\mathrm{cell}}^{(i)}:细胞 ii 的最终细胞嵌入,作为 ST 中个体细胞的输入表示。
      • fproj(eds(i))R10f_{\mathrm{proj}}(\mathbf{e}_{\mathrm{ds}}^{(i)}) \in \mathbb{R}^{10}:一个将数据集词元嵌入投影到10维的函数。

预训练目标 (Pretraining Objectives, Section 5.4.5) SE 通过自监督学习框架进行训练,包含两个互补的目标:(1) 基因表达预测任务,和 (2) 辅助数据集分类任务,有助于将技术批次效应与生物信号解耦。

  1. 基因表达预测 (Gene Expression Prediction) 模型接收完整的输入细胞表达集合(如 Eq. 23 所述),并任务预测每个细胞中选定1280个基因的表达值。目标基因来自三类:

    • P(i)\mathcal{P}^{(i)}:512个高表达基因(来自细胞 ii 表达集合中的前 LL 个基因)。
    • N(i)\mathcal{N}^{(i)}:512个未表达基因(从细胞 ii 的前 LL 个基因之外的基因池中随机采样)。
    • R\mathcal{R}:256个从完整基因集中随机采样的基因,批次中所有细胞共享。 这导致每个细胞需要预测1280个基因的表达值。

    表达预测解码器 (Expression Prediction Decoder) 使用 MLP 解码器 MLPdec\mathrm{MLP}_{\mathrm{dec}} 结合多种信息源进行预测: x^j(i)=MLPdec([zcell(i);g~j;r(i)]) \hat{\mathbf{x}}_j^{(i)} = \mathrm{MLP}_{\mathrm{dec}} \big( [\mathbf{z}_{\mathrm{cell}}^{(i)}; \tilde{\mathbf{g}}_j; r^{(i)}] \big)

    • 符号解释:
      • x^j(i)\hat{\mathbf{x}}_j^{(i)}:预测的细胞 ii 中基因 jj 的表达值。
      • MLPdec\mathrm{MLP}_{\mathrm{dec}}:表达预测解码器。
      • zcell(i)Rh+10\mathbf{z}_{\mathrm{cell}}^{(i)} \in \mathbb{R}^{h+10}:学习到的细胞嵌入(Eq. 30)。
      • g~jRh\tilde{\mathbf{g}}_j \in \mathbb{R}^h:目标基因 jj 的嵌入(Eq. 22)。
      • r(i)Rr^{(i)} \in \mathbb{R}:一个标量读深度指示器,计算为细胞 ii 输入表达集合中已表达基因的平均 log 表达。 这些被连接起来并通过 MLP 传递,该 MLP 包含两个跳跃连接块,后跟一个线性输出层,预测目标基因的 log 表达。

    对于训练批次中的每个细胞 ii,设 Y^(i)=[x^j(i)]jP(i)N(i)RR1×1280\hat{\mathbf{Y}}^{(i)} = [ \hat{\mathbf{x}}_j^{(i)} ]_{j \in \mathcal{P}^{(i)} \cup \mathcal{N}^{(i)} \cup \mathcal{R}} \in \mathbb{R}^{1 \times 1280} 表示预测的表达值行向量,Y(i)=[xj(i)]jP(i)N(i)RR1×1280\mathbf{Y}^{(i)} = [ \mathbf{x}_j^{(i)} ]_{j \in \mathcal{P}^{(i)} \cup \mathcal{N}^{(i)} \cup \mathcal{R}} \in \mathbb{R}^{1 \times 1280} 表示对应的真实表达值行向量。张量在批次中的细胞间堆叠:Y=stack([Y(i)]i=1B)\mathbf{Y} = \mathsf{stack}( [ \mathbf{Y}^{(i)} ]_{i=1}^B )Y^=stack([Y^(i)]i=1B)\hat{\mathbf{Y}} = \mathsf{stack}( [ \hat{\mathbf{Y}}^{(i)} ]_{i=1}^B ),形状为 (B×1×1280)(B \times 1 \times 1280)

    基因级别损失 (Gene-level Loss) Lgene=1Bb=1BY^(b)Y(b)2 \mathcal{L}_{\mathrm{gene}} = \frac{1}{B} \sum_{b=1}^{B} \|\hat{\mathbf{Y}}^{(b)} - \mathbf{Y}^{(b)}\|_2

    • 符号解释:
      • Lgene\mathcal{L}_{\mathrm{gene}}:基因级别损失,衡量每个细胞内预测与真实基因表达模式的相似性。

        为了捕捉小批量中细胞间基因表达的变异,还计算了一个细胞级别损失,使用共享子集 R\mathcal{R}。设 S^(i)=stack([x^j(i)]jR)R256\hat{\mathbf{S}}^{(i)} = \mathsf{stack}( [ \hat{\mathbf{x}}_j^{(i)} ]_{j \in \mathcal{R}} ) \in \mathbb{R}^{256}S(i)=stack([xj(i)]jR)R256\mathbf{S}^{(i)} = \mathsf{stack}( [ \mathbf{x}_j^{(i)} ]_{j \in \mathcal{R}} ) \in \mathbb{R}^{256} 表示细胞 ii 中共享基因 R\mathcal{R} 的预测和真实表达值。这些张量在批次中堆叠:S=transpose(stack([S(i)]i=1B))\mathbf{S}' = \mathtt{transpose}(\mathsf{stack}( [ \mathbf{S}^{(i)} ]_{i=1}^B ))S^=transpose(stack([S^(i)]i=1B))\hat{\mathbf{S}}' = \mathtt{transpose}(\mathsf{stack}( [ \hat{\mathbf{S}}^{(i)} ]_{i=1}^B )),形状为 (R×1×B)(|\mathcal{R}| \times 1 \times B)。然后计算每个基因在批次中所有细胞的连接预测值和目标值之间的距离。

    细胞级别损失 (Cell-level Loss) Lcell=1Rr=1RS^(r)S(r)2 \mathcal{L}_{\mathrm{cell}} = \frac{1}{|\mathcal{R}|} \sum_{r=1}^{|\mathcal{R}|} \|\hat{\mathbf{S}}^{\prime(r)} - \mathbf{S}^{\prime(r)}\|_2

    • 符号解释:
      • Lcell\mathcal{L}_{\mathrm{cell}}:细胞级别损失,衡量共享基因在批次中细胞间表达模式的一致性。

    总表达损失 (Total Expression Loss) 最终的表达预测训练损失结合了这两个轴: Lexpression=λ1Lgene+λ2Lcell \mathcal{L}_{\mathrm{expression}} = \lambda_1 \mathcal{L}_{\mathrm{gene}} + \lambda_2 \mathcal{L}_{\mathrm{cell}}

    • 符号解释:
      • λ1,λ2\lambda_1, \lambda_2:权重系数。 这种双轴重构损失捕捉了每个细胞内的基因重构保真度以及共享基因在细胞间的表达模式一致性。
  2. 数据集分类建模 (Dataset Classification Modeling) 为了将数据集特有的技术效应与生物变异解耦,引入了一个辅助数据集预测任务。使用 [DS] 词元嵌入,模型预测数据集的来源: d^(i)=MLPdataset(eds(i)) \hat{d}^{(i)} = \mathrm{MLP}_{\mathrm{dataset}}(\mathbf{e}_{\mathrm{ds}}^{(i)})

    • 符号解释:
      • d^(i)\hat{d}^{(i)}:预测的细胞 ii 的数据集标签。
      • MLPdataset\mathrm{MLP}_{\mathrm{dataset}}:数据集分类 MLP
      • eds(i)\mathbf{e}_{\mathrm{ds}}^{(i)}:数据集词元嵌入。

    数据集损失 (Dataset Loss) Ldataset=1Bb=1BCrossEntropy(d^(b),d(b)) \mathcal{L}_{\mathrm{dataset}} = \frac{1}{B} \sum_{b=1}^{B} \mathrm{CrossEntropy}(\hat{d}^{(b)}, d^{(b)})

    • 符号解释:
      • Ldataset\mathcal{L}_{\mathrm{dataset}}:数据集损失,使用交叉熵损失。
      • d(b)d^{(b)}:批次中细胞 bb 的真实数据集标签。
      • d^(b)\hat{d}^{(b)}:批次中细胞 bb 的预测数据集标签。 这个辅助目标鼓励模型将相关信息汇聚到这个词元位置,将其与真实的生物信号解耦。
  3. SE总损失 (Total Loss) SE 模型使用两个损失的组合进行训练: L=Lexpression+Ldataset \mathcal{L} = \mathcal{L}_{\mathrm{expression}} + \mathcal{L}_{\mathrm{dataset}}

4.2.4. ST与最优传输 (Optimal Transport, OT) 的理论关联

论文的理论部分(Section 6)深入探讨了 ST 在学习细胞分布间最优传输映射方面的渐近行为和解族。

ST渐近行为和解族 (ST Asymptotic Behavior and Solution Family, Section 6.1) ST 通过学习一个将未受扰细胞分布与受扰细胞分布对齐的转换来执行与神经最优传输(Neural OT)相关的任务。尽管 ST 未显式解决 OT 问题,但当细胞集合大小 SS 趋于无穷大时,它能够处理整个分布的信息。 在渐近设置下,ST 的解族包含细胞分布之间唯一的连续最优传输映射,前提是分布满足正则性假设。

  1. 引理1:经验MMD (empirical MMD) 为零意味着分布匹配 (Lemma 1) 当分布 D^pert,Dpert\hat{\mathcal{D}}_{\mathrm{pert}}, \mathcal{D}_{\mathrm{pert}} 的支持集有界且细胞集合大小 SS \to \infty 时,如果经验 MMD 为零,则两个分布相等(概率为1)。反之亦然。 经验 MMD 定义为: MMD^2(X^pert(b),Xpert(b))=1S2i=1Sj=1Sk(x^(i),x^(j))+k(x(i),x(j))2k(x^(i),x(j)) \widehat{\mathrm{MMD}}^2 (\hat{\mathbf{X}}_{\mathrm{pert}}^{(b)}, \mathbf{X}_{\mathrm{pert}}^{(b)}) = \frac{1}{S^2} \sum_{i=1}^{S} \sum_{j=1}^{S} k(\hat{\mathbf{x}}^{(i)}, \hat{\mathbf{x}}^{(j)}) + k(\mathbf{x}^{(i)}, \mathbf{x}^{(j)}) - 2k(\hat{\mathbf{x}}^{(i)}, \mathbf{x}^{(j)}) 理论 MMD 定义为: MMD2(D^pert,Dpert)=Ex,xD^pert[k(x,x)]+Ey,yDpert[k(y,y)]2ExD^pert,yDpert[k(x,y)] \mathrm{MMD}^2 (\hat{\mathcal{D}}_{\mathrm{pert}}, \mathcal{D}_{\mathrm{pert}}) = \mathbb{E}_{x, x' \sim \hat{\mathcal{D}}_{\mathrm{pert}}} [k(x, x')] + \mathbb{E}_{y, y' \sim \mathcal{D}_{\mathrm{pert}}} [k(y, y')] - 2\mathbb{E}_{x \sim \hat{\mathcal{D}}_{\mathrm{pert}}, y \sim \mathcal{D}_{\mathrm{pert}}} [k(x, y)] 对于能量核 k(x,y)=xy2k(x, y) = -\|x - y\|_2,理论 MMD 与能量距离 D2D^2 一致: MMD2(D^pert,Dpert)D2(D^pert,Dpert) \mathrm{MMD}^2 (\hat{\mathcal{D}}_{\mathrm{pert}}, \mathcal{D}_{\mathrm{pert}}) \equiv D^2 (\hat{\mathcal{D}}_{\mathrm{pert}}, \mathcal{D}_{\mathrm{pert}}) 能量距离为零意味着分布相等。 ST 的解族定义为: F^={FCMMD^(X^pert(b),Xpert(b))=0 and X^pert,s(b)=F(X^ctrl,s(b))},F={FCF(Dctrl)=Dpert}. \begin{array}{r l} & \hat{\mathcal{F}} = \{F \in \mathbb{C} | \widehat{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{pert}}^{(b)}, \mathbf{X}_{\mathrm{pert}}^{(b)}) = 0 \mathrm{~and~} \hat{\mathbf{X}}_{\mathrm{pert},s}^{(b)} = F (\hat{\mathbf{X}}_{\mathrm{ctrl},s}^{(b)}) \}, \\ & \mathcal{F} = \{F \in \mathbb{C} | F (\mathcal{D}_{\mathrm{ctrl}}) = \mathcal{D}_{\mathrm{pert}} \}. \end{array} 引理1表明 F^=F\hat{\mathcal{F}} = \mathcal{F} 概率为1。

  2. 定理2:最优传输映射在STATE的解族内 (Theorem 2) 假设分布 Dctrl,Dpert\mathcal{D}_{\mathrm{ctrl}}, \mathcal{D}_{\mathrm{pert}} 的密度是绝对连续和有界的,且支持集是严格凸且紧凑的,具有 C2\mathbb{C}^2 边界。那么,与平方距离成本 c(x,y)=xy2c(x, y) = \|x - y\|^2 相关的从 Dctrl\mathcal{D}_{\mathrm{ctrl}}Dpert\mathcal{D}_{\mathrm{pert}} 的连续最优传输映射 TT 满足 TF,F^T \in \mathcal{F}, \hat{\mathcal{F}} 的概率为1。 证明: 在这些正则条件下,根据 Caffarelli 定理 [82],存在从 Dctrl\mathcal{D}_{\mathrm{ctrl}}Dpert\mathcal{D}_{\mathrm{pert}} 的唯一连续可微的最优传输映射 TC1T \in \mathbb{C}^1。因此,根据定义,TFT \in \mathcal{F}。再根据引理1,F=F^\mathcal{F} = \hat{\mathcal{F}} 的概率为1,所以 TF^T \in \hat{\mathcal{F}}

  3. 定理3:受约束的ST模型用于唯一的OT映射 (Theorem 3) 在与定理2相同的假设下,考虑受约束的解族: F^={FC1MMD^(X^pert(b),Xpert(b))=0 and X^pert,s(b)=F(X^ctrl,s(b));JF={Fiαj}ij is symmetric and semipositive definite}. \begin{array}{r l} & \hat{\mathcal{F}}^* = \{F \in \mathbb{C}^1 | \widehat{\mathrm{MMD}} (\hat{\mathbf{X}}_{\mathrm{pert}}^{(b)}, \mathbf{X}_{\mathrm{pert}}^{(b)}) = 0 \mathrm{~and~} \hat{\mathbf{X}}_{\mathrm{pert},s}^{(b)} = F (\hat{\mathbf{X}}_{\mathrm{ctrl},s}^{(b)}) ; \\ & \qquad J_{\mathbf{F}} = \left\{ \frac{\partial \mathbf{F}_i}{\alpha_j} \right\}_{ij} \mathrm{~is~symmetric~and~semi-positive~definite} \}. \end{array} 那么 F^\hat{\mathcal{F}}^* 以概率1唯一包含连续 OT 映射 TTF^={T}\hat{\mathcal{F}}^* = \{T\}证明: 根据 Brenier 定理 [83],如果 T=ψT = \nabla \psiψ\psi 是凸函数),则 TT 是关于二次成本的唯一连续最优传输映射。由于 Caffarelli 定理已证明 TC1T \in \mathbb{C}^1,此条件等价于 TTJacobian 是对称且半正定的。结合引理1,TF^T \in \hat{\mathcal{F}}^* 的概率为1,且 TT 的唯一性意味着 F^={T}\hat{\mathcal{F}}^* = \{T\}

这些理论分析表明,在特定条件下,ST 的解族可以包含最优传输映射。通过在模型目标函数中施加对 Jacobian 的额外约束,理论上可以强制模型学习到唯一的 OT 映射。


5. 实验设置

5.1. 数据集

实验使用了多个大规模单细胞扰动数据集和观测数据集。

用于 ST 训练的数据集 (Datasets Used for ST Training, Section 5.5.1) 以下是原文 Table 1 的结果:

Dataset # of Cells # of Perturbations # of Contexts
Replogle-Nadig 624,158 1,677 4
Jiang 234,845 24 30
Srivatsan 762,795 189 3
Mcfaline-Figueroa 354,758 122 3
Tahoe-100M 100,648,790 1138 50
Parse-PBMC 9,697,974 90 12/18 (Donors)/(Cell Types)
  • 来源与特点:
    • Tahoe-100M 数据集 [30]:一个包含超过1亿细胞的大规模药物扰动数据集。
    • Replogle-Nadig 数据集 [4, 43]:遗传扰动数据集。
    • Parse-PBMC 数据集 [42]:包含约970万 PBMC 细胞的细胞因子信号扰动数据集。
    • Jiang 数据集 [64]:未详细说明类型,但用于 ST 训练。
    • McFaline 数据集 [65]:未详细说明类型,但用于 ST 训练。
    • Srivatsan 数据集 [41]:未详细说明类型,但用于 ST 训练。
  • 预处理:
    • 所有数据集都过滤为保留19,790个人类蛋白质编码 Ensembl 基因。
    • 归一化到总 UMI 深度为10,000。
    • 原始计数数据使用 scanpy.pp.log1p 进行 log 转换。
    • 对于高度可变基因(HVGs)的分析,每个数据集使用 scanpy.pp.highly_variable_genes 识别前2,000个 HVGs。这些 HVGslog 转换表达值用作基因级别特征。
    • 细胞的 PCA 嵌入使用 scanpy.pp.pca 计算。
  • 遗传扰动数据集的额外预处理 (Section 5.5.2):
    • 扰动级别过滤: 只保留平均敲低效率达到最小阈值(残余表达 0.30\leq 0.30)的扰动(对照组除外)。
    • 细胞级别过滤: 在选定的扰动中,只保留个体敲低阈值更严格(残余表达 0.50\leq 0.50)的细胞。
    • 最小细胞计数: 剔除细胞数量少于30个的扰动,但始终保留对照细胞。

用于 SE 训练的数据集 (Datasets Used for SE Training, Section 5.5.3) 以下是原文 Table 2 的结果:

Dataset # Training Cells # Validation Cells
Arc scBaseCount [32] 71,676,369 4,137,674
CZ CellXGene [31] 59,233,790 6,500,519
Tahoe-100M [30] 36,157,383 2,780,587
  • SE 模型在总计1.67亿个人类细胞上进行了训练,这些细胞来自 Arc scBaseCount [32]、CZ CELLxGENE [31] 和 Tahoe-100M [30] 数据集。
  • 为了避免在上下文泛化基准测试中数据泄露,Tahoe-100M 数据集中用于训练 SE 的细胞系与评估 ST 的5个保留细胞系是分开的。
  • 预处理: scBaseCount 数据过滤掉至少1,000个非零表达和每个细胞少于2,000个 UMI 的细胞。一部分 AnnData 文件被保留用于计算验证损失。

5.2. 评估指标

对扰动预测模型进行评估是其性能验证的关键。CELL-EVAL 框架用于评估模型。

5.2.1. 扰动区分分数 (Perturbation Discrimination Score, PDiscNorm)

  • 概念定义: 这个指标衡量模型区分不同扰动效应的能力。它通过比较预测的扰动伪批量(pseudobulk)与所有真实扰动伪批量之间的距离来确定预测的准确性。一个好的模型应该使得其预测的特定扰动伪批量与该扰动的真实伪批量最接近,而远离其他扰动的真实伪批量。

  • 数学公式: 对于任何扰动 tt,设 pˉt\bar{p}_tptp_t 分别表示预测和观测到的伪批量表达。使用距离度量 d(,)d(\cdot, \cdot)(例如,曼哈顿距离或欧氏距离),定义 rtr_t 为真实伪批量表达中,有多少个其他扰动 tt' 的真实伪批量 ptp_{t'} 比正确的扰动 ptp_t 更接近预测的 pˉt\bar{p}_trt = tt1{d(pˉt,pt) < d(pˉt,pt)}, r_t \ = \ \sum_{t' \neq t} \mathbf{1} \{ d(\bar{p}_t, p_{t'}) \ < \ d(\bar{p}_t, p_t) \},

    • 符号解释:
      • rtr_t:对于扰动 tt,其他扰动的真实伪批量比正确扰动的真实伪批量更接近预测伪批量的数量。

      • 1{}\mathbf{1}\{\cdot\}:指示函数,当括号内条件为真时取1,否则取0。

      • d(,)d(\cdot, \cdot):距离函数(例如曼哈顿距离或欧氏距离)。

      • pˉt\bar{p}_t:模型预测的扰动 tt 的伪批量表达。

      • ptp_t:真实观测到的扰动 tt 的伪批量表达。

      • ptp_{t'}:真实观测到的扰动 tt' 的伪批量表达(ttt' \neq t)。

        每个扰动的分数为: PDisct=rtT, \mathrm{PDisc}_t = \frac{r_t}{T},

    • 符号解释:
      • PDisct\mathrm{PDisc}_t:扰动 tt 的扰动区分分数(归一化秩)。
      • TT:数据集中不同扰动的总数。 该分数的范围是 [0, 1),其中 PDisct=0\mathrm{PDisc}_t = 0 表示完美匹配(没有更接近的廓线),接近1的值表示区分能力差。

    总体分数是所有扰动的平均值: PDisc=1Tt=1TPDisct. \mathrm{PDisc} = \frac{1}{T} \sum_{t=1}^{T} \mathrm{PDisc}_t.

    • 符号解释:
      • PDisc\mathrm{PDisc}:平均扰动区分分数。

        为了方便解释,报告的是归一化的逆扰动区分分数: PDiscNorm=12PDisc. \mathrm{PDiscNorm} = 1 - 2\mathrm{PDisc}.

    • 符号解释:
      • PDiscNorm\mathrm{PDiscNorm}:归一化的逆扰动区分分数。 这样,随机预测器得分为0.0,完美预测器得分为1.0。

5.2.2. 差分表达 (Differential Expression)

CELL-EVAL 使用 Wilcoxon 秩和检验(Wilcoxon rank-sum test)进行差分表达分析,并使用 Benjamini-Hochberg 程序进行多重假设校正,应用于测试细胞系中的观测值和模型预测。

  1. DEG重叠准确率 (DE Overlap Accuracy)

    • 概念定义: 评估模型预测的差分表达基因(DEGs)与真实观测到的 DEGs 之间的重叠程度。对于每个扰动,会识别出前 kkDEGs(根据调整后的 pp 值和绝对 log 倍数变化进行过滤和排序),然后计算预测集合与真实集合之间的交集占 kk 的比例。
    • 数学公式: Overlapt,k = Gt,true(k)Gt,pred(k)k. \mathrm{Overlap}_{t,k} \ = \ \frac{|G_{t,\mathrm{true}}^{(k)} \cap G_{t,\mathrm{pred}}^{(k)}|}{k}.
      • 符号解释:
        • Overlapt,k\mathrm{Overlap}_{t,k}:扰动 tt 的前 kkDEGs 的重叠准确率。
        • Gt,true(k)G_{t,\mathrm{true}}^{(k)}:扰动 tt 的真实前 kkDEGs 集合。
        • Gt,pred(k)G_{t,\mathrm{pred}}^{(k)}:扰动 tt 的预测前 kkDEGs 集合。
        • kk:要考虑的 DEGs 数量(例如50, 100, 200)。当 k=Nk=N 时,表示使用所有被识别为 DEGs 的基因,此时 NN 因扰动而异。 如果未指定 kk,通常指 NNDE 重叠(即所有 DEGs)。
  2. 效应大小 (Effect Sizes) / Spearman相关性 (Spearman correlation)

    • 概念定义: 比较模型预测的扰动效应大小与真实效应大小的相对关系。通过计算每个扰动的差分表达基因数量(调整 pp 值 < 0.05)的 Spearman 秩相关系数来衡量。
    • 数学公式:nt=Gt,true(DE)n_t = |G_{t,\mathrm{true}}^{(\mathrm{DE})}| 是扰动 tt 的真实 DEGs 数量,n^t=Gt,pred(DE)\hat{n}_t = |G_{t,\mathrm{pred}}^{(\mathrm{DE})}| 是预测的 DEGs 数量。 SizeCorr=ρrank((nt)t=1T,(n^t)t=1T). \mathrm{SizeCorr} = \rho_{\mathrm{rank}} \big( (n_t)_{t=1}^T, (\hat{n}_t)_{t=1}^T \big).
      • 符号解释:
        • SizeCorr\mathrm{SizeCorr}:效应大小的 Spearman 相关系数。
        • ρrank\rho_{\mathrm{rank}}Spearman 秩相关函数。
        • (nt)t=1T(n_t)_{t=1}^T:所有扰动的真实 DEGs 数量序列。
        • (n^t)t=1T(\hat{n}_t)_{t=1}^T:所有扰动的预测 DEGs 数量序列。

5.2.3. 细胞嵌入评估指标 (Cell Embedding Evaluation Metrics)

用于评估 SE 模型生成的细胞嵌入的质量和效用。

  1. 内在评估 (Intrinsic Evaluation)

    • 概念定义: 评估嵌入捕获扰动特异性信息的能力。通过训练一个单隐藏层的多层感知器(MLP)来从细胞嵌入中分类扰动标签,并测量 AUROC 和准确率。
    • 方法: 对于每个细胞系单独进行分类任务,数据以20%的细胞随机分配到测试集,80%到训练集。MLP 使用交叉熵损失训练。
    • 意义: 高分类性能表明嵌入保留了扰动诱导状态的独特表示。
  2. 外在评估 (Extrinsic Evaluation)

    • 概念定义: 评估不同细胞嵌入空间对 ST 下游任务(扰动效应建模)的实用性。
    • 方法: 比较 ST 在不同嵌入空间训练时的性能。
    • 意义: 理想的嵌入空间应平衡扰动特异性信息的保留和对扰动诱导状态转换的准确预测能力。

5.2.4. 其他评估 (Other Evaluations)

  • 细胞集合大小缩放 (Cell Set Scaling, Section 5.7.4): 评估 STATE 中每个集合细胞数量的影响。通过消融实验,在保持 batch size 乘以 cell set size 恒定(16,384)的情况下,测量验证损失作为浮点运算(FLOPs)的函数。这些模型还与伪批量模型(用均值池化替代自注意力)和掩蔽注意力版本的 STATE(每个细胞只能关注自身)进行比较。

5.3. 对比基线

论文将 STATE 与以下基线模型进行了比较:

  1. 扰动均值基线 (Perturbation Mean Baseline, Section 5.7.5):

    • 概念: 预测扰动后的表达谱为该细胞背景的控制组均值,加上从训练数据中学习到的一个全局扰动偏移量。
    • 数学公式: 首先计算每个细胞类型 cc 和扰动 pp 的细胞类型特定均值: μcctrl=1CciCcx(i),μc,ppert=1Pc,piPc,px(i), \boldsymbol{\mu}_c^{\mathrm{ctrl}} = \frac{1}{|\mathcal{C}_c|} \sum_{i \in \mathcal{C}_c} \mathbf{x}^{(i)}, \qquad \boldsymbol{\mu}_{c,p}^{\mathrm{pert}} = \frac{1}{|\mathcal{P}_{c,p}|} \sum_{i \in \mathcal{P}_{c,p}} \mathbf{x}^{(i)},
      • 符号解释:
        • μcctrl\boldsymbol{\mu}_c^{\mathrm{ctrl}}:细胞类型 cc 的控制细胞的平均表达。
        • μc,ppert\boldsymbol{\mu}_{c,p}^{\mathrm{pert}}:细胞类型 cc 接受扰动 pp 的受扰细胞的平均表达。
        • Cc\mathcal{C}_c:细胞类型 cc 的控制细胞集合。
        • Pc,p\mathcal{P}_{c,p}:细胞类型 cc 接受扰动 pp 的受扰细胞集合。 然后,计算细胞类型特定偏移量 δc,p=μc,ppertμcctrl\boldsymbol{\delta}_{c,p} = \boldsymbol{\mu}_{c,p}^{\mathrm{pert}} - \boldsymbol{\mu}_c^{\mathrm{ctrl}},并对所有扰动 pp 上的细胞类型求平均,得到一个全局偏移量: δp = 1CpcCpδc,p,Cp={cPc,p>0}, \boldsymbol{\delta}_p \ = \ \frac{1}{|\mathcal{C}_p|} \sum_{c \in \mathcal{C}_p} \boldsymbol{\delta}_{c,p}, \qquad \mathcal{C}_p = \{c \mid |\mathcal{P}_{c,p}| > 0 \},
      • 符号解释:
        • δp\boldsymbol{\delta}_p:扰动 pp 的全局平均偏移量。
        • Cp\mathcal{C}_p:至少包含扰动 pp 的细胞类型集合。 给定测试细胞类型 tt 和扰动标签 pp,模型输出: x^=μtctrl+δp,δctrl0. \hat{\mathbf{x}} = \boldsymbol{\mu}_t^{\mathrm{ctrl}} + \boldsymbol{\delta}_p, \qquad \boldsymbol{\delta}_{\mathrm{ctrl}} \equiv \mathbf{0}.
      • 符号解释:
        • x^\hat{\mathbf{x}}:预测的扰动后表达。
        • μtctrl\boldsymbol{\mu}_t^{\mathrm{ctrl}}:测试细胞类型 tt 的控制细胞的平均表达。 这意味着控制细胞被精确复制,而每个非控制扰动接收相同的全局偏移量。
  2. 背景均值基线 (Context Mean Baseline, Section 5.7.5):

    • 概念: 通过返回训练集中观测到的相同细胞类型受扰细胞的平均表达来预测细胞的扰动后表达。
    • 数学公式: 对于每个细胞类型 cc,收集所有非控制扰动的训练细胞,形成伪批量均值: μc=1TciTcx(i),Tc={icell_type(i)=c,p(i)ctrl}. \boldsymbol{\mu}_c = \frac{1}{|\mathcal{T}_c|} \sum_{i \in \mathcal{T}_c} \mathbf{x}^{(i)}, \qquad \mathcal{T}_c = \{i \big| \mathrm{cell}\_\mathrm{type}(i) = c, p^{(i)} \neq \mathrm{ctrl} \}.
      • 符号解释:
        • μc\boldsymbol{\mu}_c:细胞类型 cc 的非控制扰动训练细胞的伪批量均值。
        • Tc\mathcal{T}_c:细胞类型为 cc 且非控制扰动的训练细胞集合。 在推理时,对于细胞类型为 c(i)c^{(i)} 和扰动标签为 p(i)p^{(i)} 的测试细胞 ii,预测为: x^(i)={x(i)p(i)=ctrl,μc(i)p(i)ctrl. \hat{\mathbf{x}}^{(i)} = \left\{ \begin{array}{l l} \mathbf{x}^{(i)} & p^{(i)} = \mathrm{ctrl}, \\ \boldsymbol{\mu}_{c^{(i)}} & p^{(i)} \neq \mathrm{ctrl}. \end{array} \right.
      • 符号解释:
        • x^(i)\hat{\mathbf{x}}^{(i)}:预测的细胞 ii 的扰动后表达。 即控制细胞保持不变,而受扰细胞继承其细胞类型均值。
  3. 线性基线 (Linear Baseline, Section 5.7.5):

    • 概念: 将扰动视为低秩、基因范围的线性位移,叠加到细胞的控制表达上。
    • 数学公式:GRG×dgG \in \mathbb{R}^{G \times d_g} 是固定的基因嵌入矩阵,PRP×dpP \in \mathbb{R}^{P \times d_p} 是固定的扰动嵌入矩阵。从训练集首先构建一个“表达变化”伪批量 YRG×PY \in \mathbb{R}^{G \times P}Yg,p=1PpiPp(xgpert,(i)xgctrl,(i)),Pp={ip(i)=p}, Y_{g,p} = \frac{1}{|\mathcal{P}_p|} \sum_{i \in \mathcal{P}_p} \bigl( x_g^{\mathrm{pert},(i)} - x_g^{\mathrm{ctrl},(i)} \bigr), \qquad \mathcal{P}_p = \bigl\{ i \mid p^{(i)} = p \bigr\},
      • 符号解释:
        • Yg,pY_{g,p}:基因 gg 和扰动 pp 的平均表达变化。
        • Pp\mathcal{P}_p:接受扰动 pp 的细胞集合。
        • xgpert,(i)x_g^{\mathrm{pert},(i)}:细胞 ii 扰动后基因 gg 的表达。
        • xgctrl,(i)x_g^{\mathrm{ctrl},(i)}:细胞 ii 扰动前(控制)基因 gg 的表达。 模型寻找一个低秩映射 KRdg×dpK \in \mathbb{R}^{d_g \times d_p} 和一个基因范围的偏置 bRG\mathbf{b} \in \mathbb{R}^G 使得 YGKP+b1Y \approx GKP^\top + \mathbf{b}\mathbf{1}^\topKK 通过求解岭正则化最小二乘问题获得: minKYGKPb1F2+λKF2,b=1PY1, \operatorname*{min}_{K} \left\| Y - GKP^\top - \mathbf{b}\mathbf{1}^\top \right\|_F^2 + \lambda \|K\|_F^2, \qquad \mathbf{b} = \frac{1}{P} Y \mathbf{1}, 其闭式解为: K=(GG+λI)1GYP(PP+λI)1. K = \big( G^\top G + \lambda I \big)^{-1} G^\top Y P \big( P^\top P + \lambda I \big)^{-1}.
      • 符号解释:
        • F\| \cdot \|_F:Frobenius 范数。
        • λ\lambda:岭正则化参数。
        • II:单位矩阵。
        • 1\mathbf{1}:全1向量。 对于测试细胞 ii 及其控制表达谱 xctrl,(i)\mathbf{x}^{\mathrm{ctrl},(i)} 和扰动标签 p(i)p^{(i)},预测为: x^(i)={xctrl,(i)p(i)=ctrl,xctrl,(i)+GKPp(i)+bp(i)ctrl, \begin{array}{r l r} \hat{\mathbf{x}}^{(i)} & = & \left\{ \begin{array}{l l} \mathbf{x}^{\mathrm{ctrl},(i)} & p^{(i)} = \mathrm{ctrl}, \\ \mathbf{x}^{\mathrm{ctrl},(i)} + GKP_{p^{(i)}} + \mathbf{b} & p^{(i)} \neq \mathrm{ctrl}, \end{array} \right. \end{array}
      • 符号解释:
        • Pp(i)P_{p^{(i)}}:对应扰动 p(i)p^{(i)}PP 矩阵的行。 每个预测保留细胞的基础状态,并添加从训练数据中学习到的扰动特定、低秩的偏移量。
  4. 深度学习基线 (Deep Learning Baselines, Section 5.7.5):

    • scVI [25]: 一种基于变分自编码器(VAE)的模型,能够建模基因表达分布,同时考虑技术噪声和批次效应。
    • CPA [47]: 一种组合扰动自编码器(Compositional Perturbation Autoencoder),学习一个组合潜在空间,捕捉扰动、剂量和细胞类型的加性效应。
    • scGPT [11]: 一个基于 Transformer 的单细胞基础模型,通过在超过3300万细胞上进行生成式预训练,支持包括扰动预测在内的零样本任务泛化。

5.4. 训练

所有模型均使用 PyTorch Lightning 和分布式数据并行(DDP)训练实现。使用 PyTorch 的自动混合精度(AMP)以减少内存使用和加速训练。

5.4.1. ST超参数 (ST Hyperparameters, Section 5.6.1)

以下是原文 Table 3 的结果:

Dataset cell_ set_size hidden_dim n_encoder layers n_decoder layers batch encoder transformer backbone _key attn_heads params
Tahoe-100M 256 1488 4 4 false LLaMA 12 244M
Parse-PBMC 512 1440 4 true LLaMA 12 244M
Replogle-Nadig 32 128 4 4 false GPT2 8 10M
  • ST 架构在模块间使用共享隐藏维度 hh

  • 核心 Transformer 模块 fSTf_{\mathrm{ST}} 基于 LLaMA [66] 或 GPT2 [67] 主干网络。对于稀疏数据集(如 Replogle-Nadig),使用 GPT2 主干网络。

  • 所有模型都修改为使用双向注意力。

  • 由于集合内细胞顺序是任意的,不使用位置编码。Transformer 内不应用 dropout

  • 模型参数初始化:除 Transformer 主干网络(从 N(0,0.022)\mathcal{N}(0, 0.02^2) 初始化)外,其他部分使用 Kaiming Uniform [69] 初始化。

    以下是原文 Table 4 的结果:

    Component Architecture Layer Dimensions Activation Normalization Dropout
    fcell 4-layer MLP (G or E) → h → h → h → h GELU LayerNorm None
    fpert 4-layer MLP Dpert → h → h → h → h GELU LayerNorm None
    fbatch Embedding Lyr. Dbatch → h N/A N/A None
    fsT LLaMA Transf. h (input to each of 4 layers) SwiGLU RMSNorm None
    recon Linear Layer h → G or E N/A N/A None
    fdecode 3-layer MLP h → 1024 → 512 → G GELU LayerNorm 0.1
    fconf 3-layer MLP h → h/2 → h/4 → 1 GELU LayerNorm None

5.4.2. ST训练细节 (ST Training Details, Section 5.6.2)

  • 所有组件(控制细胞编码器 fcellf_{\mathrm{cell}}、扰动编码器 fpertf_{\mathrm{pert}}、可选批次编码器 fbatchf_{\mathrm{batch}}Transformer 主干网络 fSTf_{\mathrm{ST}}、重构和解码头 freconf_{\mathrm{recon}}fdecodef_{\mathrm{decode}})都使用前面描述的目标进行端到端训练。
  • 微调任务: 对于微调任务,使用预训练权重初始化新的 ST 模型,同时选择性地重新初始化特定组件。
    • 扰动编码器 fpertf_{\mathrm{pert}} 重新初始化,以实现跨扰动模态的迁移。
    • 如果 ST 使用细胞嵌入进行训练,基因解码器 fdecodef_{\mathrm{decode}} 也重新初始化,以适应数据集之间基因覆盖的差异。
    • 其他组件(fcellf_{\mathrm{cell}}fbatchf_{\mathrm{batch}}fSTf_{\mathrm{ST}})保留其预训练权重,并在目标数据集上进行微调。

5.4.3. SE超参数 (SE Hyperparameters, Section 5.6.3)

  • SE 是一个6亿参数的编码器-解码器模型。
  • 编码器包含16个 Transformer 层,每个层有16个注意力头,隐藏维度 h=2048h = 2048
  • 每个层使用预归一化(pre-normalization),前馈网络扩展到 3×h3 \times h 维度,并使用 GELU 激活。
  • 注意力层和前馈层均应用0.1的 dropout 概率。
  • 解码器是一个多层感知器(MLP),训练用于在给定学习到的细胞嵌入和目标基因嵌入的情况下恢复基因表达。
  • 优化器: 使用 AdamW 优化器 [70],最大学习率为 10510^{-5},权重衰减为0.01,梯度裁剪使用 zclip [71]。
  • 学习率调度: 线性预热(linear warmup)占总步数的3%,随后是余弦退火(cosine annealing)到最大学习率的30%。
  • 权重初始化: 训练前,所有 SE 权重均从 Kaiming Uniform 采样初始化。

5.4.4. SE训练细节 (SE Training Details, Section 5.6.4)

  • SE 在包含14,420个 AnnData 文件、跨1.67亿人类细胞的大规模语料库上进行了训练,这些数据来自 Arc scBaseCount [32]、CZ CELLxGENE [31] 和 Tahoe-100M [30] 数据集,共训练4个 epochs

  • 为避免数据泄露,数据集在数据集级别被拆分为独立的训练集和验证集。

  • 高效训练: 使用 Flash Attention 2 [72] 和混合精度(bf16)训练 [73]。

  • 分布式训练: 训练分布在4个计算节点上,每个节点配备8个 NVIDIA H100 GPU。

  • 批次大小: 使用有效批次大小3,072,每设备批次大小为24,并进行4步梯度累积。


6. 实验结果与分析

6.1. 核心结果分析

以下是原文 Figure 2 的示意图,展示了 STATE 模型在不同细胞背景下预测扰动效果的性能提升:

Figure 2: STATE improves perturbation prediction in context generalization, leverages data scale, and enables cross-dataset transer learning. (A) Underrepresented context generalization task.Models w… 该图像是论文中图2,属于多部分图表,展示了STATE模型在不同细胞背景下预测扰动效果的性能提升,包括数据规模利用、跨数据集迁移学习和零样本扰动预测,图中涉及模型训练策略、评估流程及多个性能指标的对比。

  • A) 代表性不足的上下文泛化任务: 模型在部分扰动被大量保留在独立测试上下文中的情况下进行训练和评估。
  • B) 评估过程: 模型在30%的数据上训练和评估。
  • C) 扰动评估指标: 展示了 STATE 和基线模型在扰动区分、预测与真实 DEGs 重叠以及扰动效应大小的 Spearman 相关性方面的性能。
  • D) 细胞嵌入评估: 左侧(内在性能)衡量在观测数据生成的嵌入中扰动细胞的分类准确率;右侧(外在性能)衡量在 ST 基于细胞嵌入训练的预测嵌入中分类准确率。
  • E) 零样本扰动预测:Tahoe-100M 上预训练后,STATE 在零样本任务上的表现。

6.1.1. STATE 在上下文泛化上的改进和对数据规模的利用 (Section 2.2, Figure 2C)

  • 任务设置: 模型在代表性不足的上下文泛化任务上进行评估。每个测试细胞上下文贡献其30%的扰动用于训练,其余部分作为测试集。
  • 扰动区分分数 (Perturbation Discrimination Score):
    • Tahoe-100M 数据集上,STATE 实现了54%的绝对提升。
    • Parse-PBMC 数据集上,STATE 实现了29%的绝对提升。
    • 在遗传扰动数据集上,STATE 的表现与次优基线相当。
  • DEG 重叠准确率 (DEG Overlap Accuracy):
    • Tahoe-100M 数据集上,STATE 的表现比次优基线好两倍。
    • Parse-PBMC 数据集上,STATE 的表现比次优基线好43%。
    • 在遗传扰动数据集上,STATE 是第二好的模型。
  • 效应大小的 Spearman 相关性 (Spearman Correlation for Effect Sizes):
    • STATE 准确地根据相对效应大小对扰动进行排序。
    • Parse-PBMC 上比基线高53%。
    • ReplogleNadig 上比基线高22%。
    • Tahoe-100M 上比基线高70%,接近0.8的绝对相关性。
  • 结论:
    • STATE 在数据规模增加时(从遗传扰动到信号扰动再到药物扰动,数据量增加一到两个数量级)性能提升最为显著。
    • 这表明当前扰动模型可能处于数据稀疏状态 [49, 50],而 STATE 架构能更好地利用大规模数据。

6.1.2. 共享细胞嵌入 (SE) 实现跨数据集迁移 (Section 2.3, Figure 2D, 2E)

  • 内在评估 (Intrinsic Evaluation, Figure 2D):
    • 通过训练浅层 MLP 分类器来预测扰动标签,评估 SE 嵌入的信息含量。
    • STATE 嵌入在内在性能上显著优于其他细胞基础模型,甚至优于使用高度可变基因的数据集特定表示。这表明 SE 嵌入有效地捕获了扰动特异性信息。
  • 迁移学习能力 (Transfer Learning, Figure 2E):
    • 为了评估在 SE 潜在空间中训练扰动模型所实现的迁移学习能力,ST+SEST+SE 模型在 Tahoe-100M 上进行预训练,然后在几个较小数据集上进行微调。
    • 在所有测试数据集上,ST+SEST+SE 都实现了比 ST+HVGST+HVG 更好的迁移,并优于其他基线方法。
    • 这证明 SE 能够统一不同数据集的细胞表示,实现更鲁棒的跨数据集迁移。

6.1.3. 细胞集合大小的影响 (Cell Set Scaling, Figure 1B)

  • 结果: 随着细胞集合大小的增加(直到256个细胞),验证损失显著降低。
  • 对比:
    • 完整的 ST 模型显著优于伪批量模型(STATE w/ mean-pooling,用均值池化代替自注意力)。
    • 完整的 ST 模型显著优于单细胞变体(STATEwithsetsize=1STATE with set size = 1)。
    • 移除自注意力机制(STATE w/o self-attention)会导致性能大幅下降。
  • 结论: ST 中自注意力机制在处理细胞集合时,能够有效地捕捉和利用群体内的异质性信息,从而提升模型性能。存在一个最佳的细胞集合大小,平衡了信息捕获和计算效率。

6.2. 数据呈现 (表格)

本部分将再次呈现实验设置中已经列出的表格,因为原文的实验结果部分主要通过图表而非直接的数值表格来展示关键发现。图表结果已经在上一节的核心结果分析中进行了详细的文字描述。

6.2.1. 用于ST训练的数据集 (Table 1)

以下是原文 Table 1 的结果:

Dataset # of Cells # of Perturbations # of Contexts
Replogle-Nadig 624,158 1,677 4
Jiang 234,845 24 30
Srivatsan 762,795 189 3
Mcfaline-Figueroa 354,758 122 3
Tahoe-100M 100,648,790 1138 50
Parse-PBMC 9,697,974 90 12/18 (Donors)/(Cell Types)

6.2.2. 用于SE训练的数据集 (Table 2)

以下是原文 Table 2 的结果:

Dataset # Training Cells # Validation Cells
Arc scBaseCount [32] 71,676,369 4,137,674
CZ CellXGene [31] 59,233,790 6,500,519
Tahoe-100M [30] 36,157,383 2,780,587

6.2.3. ST关键模型超参数 (Table 3)

以下是原文 Table 3 的结果:

Dataset cell_ set_size hidden_dim n_encoder layers n_decoder layers batch encoder transformer backbone _key attn_heads params
Tahoe-100M 256 1488 4 4 false LLaMA 12 244M
Parse-PBMC 512 1440 4 true LLaMA 12 244M
Replogle-Nadig 32 128 4 4 false GPT2 8 10M

6.2.4. ST组件架构详情 (Table 4)

以下是原文 Table 4 的结果:

Component Architecture Layer Dimensions Activation Normalization Dropout
fcell 4-layer MLP (G or E) → h → h → h → h GELU LayerNorm None
fpert 4-layer MLP Dpert → h → h → h → h GELU LayerNorm None
fbatch Embedding Lyr. Dbatch → h N/A N/A None
fsT LLaMA Transf. h (input to each of 4 layers) SwiGLU RMSNorm None
recon Linear Layer h → G or E N/A N/A None
fdecode 3-layer MLP h → 1024 → 512 → G GELU LayerNorm 0.1
fconf 3-layer MLP h → h/2 → h/4 → 1 GELU LayerNorm None

6.3. 消融实验/参数分析

论文通过以下实验验证了模型各组件的有效性和参数选择的合理性:

  1. 细胞集合大小的消融实验 (Figure 1B):
    • 目的: 评估 STATE 中每个细胞集合大小对性能的影响。
    • 方法: 通过改变每个集合中的细胞数量(cell_set_size),同时保持 batch_size 乘以 cell_set_size 恒定为16,384,来衡量验证损失作为浮点运算(FLOPs)的函数。
    • 结果: 随着细胞集合大小的增加,验证损失显著降低,直到在256个细胞时达到最佳性能。
    • 结论: 存在一个最优的细胞集合大小,使得模型能够有效捕捉异质性而不引入过多噪声或计算负担。
  2. ST 模型核心组件的消融:
    • STATE 与伪批量模型(STATE w/ mean-pooling)对比: 伪批量模型用均值池化(mean-pooling)代替了 ST 中的自注意力机制。结果显示,完整的 ST 模型显著优于伪批量模型。

    • STATE 与单细胞变体(STATEwithsetsize=1STATE with set size = 1)对比: 单细胞变体将集合大小设为1。结果显示,完整的 ST 模型显著优于单细胞变体。

    • STATE 移除自注意力机制(STATE w/o self-attention)对比: 移除自注意力机制后,性能大幅下降。

    • 结论: 这些消融实验强有力地证明了 ST 中使用自注意力机制处理细胞集合的重要性。自注意力机制能够有效地建模和利用细胞群体内部的异质性,而简单地对细胞进行平均(伪批量)或独立处理(单细胞)则会丢失关键信息,导致性能下降。

      这些消融实验验证了 STATE 架构中关键设计选择的合理性,特别是基于集合的自注意力机制和对最优集合大小的探索。


7. 总结与思考

7.1. 结论总结

本文介绍了 STATE 模型,一个基于 Transformer 的多尺度机器学习架构,旨在解决现有模型在预测细胞对扰动响应时跨多样化细胞背景泛化能力不足的问题。STATE 由两个核心模块组成:State Transition (ST) 模型和 State Embedding (SE) 模型。ST 利用自注意力机制在细胞集合上学习扰动效应,从而能够灵活地捕捉细胞群体内部和跨实验的异质性。SE 则是一个预训练的编码器-解码器模型,通过学习大规模观测数据中的基因表达变异来生成鲁棒的细胞嵌入,有效应对技术变异并优化扰动效应的检测。

实验结果表明,STATE 在多个大规模数据集上取得了显著的性能提升。它在扰动区分度、差分表达基因识别的准确性以及扰动效应大小的排序方面均超越了现有的基线模型,尤其是在数据规模增大时,其优势更为明显。STATE 结合 SE 提供的通用细胞嵌入,展现了在训练期间未见的新细胞背景中预测强扰动的零样本泛化能力。此外,模型对细胞集合大小的消融实验证实了自注意力机制在处理细胞异质性方面的关键作用。

总而言之,STATE 提供了一种可扩展、高性能且灵活的方法,用于构建细胞状态和行为的基础模型,这为开发用于多样化生物应用(包括虚拟细胞模型和指导自主实验设计)的下一代工具奠定了基础。

7.2. 局限性与未来工作

局限性:

  1. 数据依赖性: 尽管 STATE 在大规模数据集上表现出色,但论文也指出“在遗传扰动等效应较弱且数据量较少的情况下,STATE 的表现虽好,但往往仅能与次优基线持平或略优”。这暗示在数据量稀疏或扰动效应不显著的场景下,模型的优势可能不那么突出。
  2. 理论与实践的差距: 理论分析表明 ST 的解族包含最优传输映射,但并不保证模型能学到唯一的 OT 解,因为解族可能包含无限多的“过拟合”解决方案。虽然隐式偏差可能引导模型学习到“最小整体位移”的解,但这并未得到严格证明,且额外的 Jacobian 约束在实际模型中如何高效实现仍是开放问题。
  3. 计算资源需求: STATE 的训练涉及超过1亿的扰动细胞和1.67亿的观测细胞,以及一个6亿参数的 SE 模型,需要多 GPU 集群进行分布式训练,这对于小型研究团队来说可能是一个巨大的计算负担。
  4. 模型复杂度与可解释性: Transformer 模型虽然强大,但其内部机制复杂,可能降低了模型的可解释性。理解 STATE 如何捕捉和利用细胞异质性以及其决策过程,可能仍是一个挑战。
  5. 基因选择的潜在局限: SE 模型通过选择 topL=2048top L=2048 个高表达基因来构建细胞表示,这可能忽略了低表达但对细胞状态或扰动响应至关重要的基因。

未来工作:

  1. 实现虚拟细胞模型:STATE 进一步发展为能够探索细胞状态空间并指导自主实验设计的虚拟细胞模型 [51, 52, 53-56]。
  2. 表征隐式偏差: 深入研究和理论表征 ST 模型优化中的隐式偏差 [85-87],以更好地理解梯度下降如何引导模型学习到类似最优传输的映射,并可能发现某种“未知成本函数”下的最优传输。
  3. 架构改进以强制 OT 映射: 探索新的架构设计,能够直接在模型中实现 Jacobian 约束,从而强制模型学习到唯一的连续最优传输映射,如 Input Convex Neural Networks (ICNNs) [9, 84]。
  4. 拓展应用场景:STATE 应用于更广泛的生物学问题,如疾病进展预测、个性化治疗推荐等。

7.3. 个人启发与批判

个人启发:

  1. Transformer在生物学领域的潜力: 这篇论文进一步印证了 Transformer 架构在处理复杂生物学数据方面的强大能力。它不仅限于序列数据(如基因序列),还能有效地建模细胞群体这种“无序集合”的复杂相互作用和异质性。
  2. 预训练范式的成功迁移: SE 模块通过在大规模观测数据上进行自监督预训练,极大地提高了模型的泛化能力和数据效率。这种“预训练-微调”的范式已经在大语言模型(LLMs)中取得了巨大成功,现在看来,它同样适用于解决生物学领域中跨数据集异质性和数据稀疏性(对于特定任务)的挑战。
  3. 异质性建模的关键性: 论文深刻地揭示了细胞异质性在扰动响应预测中的重要性。通过 ST 模型对细胞集合进行建模,并利用自注意力机制捕捉残余异质性,STATE 提供了一个强大的工具来超越传统方法对细胞同质性的简化假设。
  4. 理论与实践的结合: 论文不仅提出了一个高性能的模型,还进行了深入的理论分析,探讨了 ST 与最优传输的关联。这种理论指导的建模方式有助于理解模型的内在机制和潜在能力,也为未来的改进指明了方向。

批判性思考与潜在改进:

  1. 计算成本与可及性: STATE 的训练需要极其庞大的数据集(1亿+细胞)和高性能计算资源(多 H100 GPU)。这使得模型的复现和进一步研究对大多数实验室而言是极大的挑战。未来工作可以探索更计算高效的架构或知识蒸馏(knowledge distillation)方法,将 STATE 的能力转移到更小的模型中,以提高其可及性。
  2. 可解释性挑战: Transformer 模型固有的黑箱特性使得理解模型为何做出特定预测变得困难。在生物学领域,理解“为什么”一个扰动会导致某种响应与“预测”这种响应同样重要。未来可以结合可解释性 AI(XAI)技术,例如注意力权重分析、特征归因方法,来深入理解 STATE 模型的决策过程,揭示潜在的生物学机制。
  3. 基因表达排序的敏感性: SE 模块依赖于选择 top L 个高表达基因来构建细胞表示。这种策略可能导致对低表达但具有关键调控作用的基因的忽略。未来的研究可以探索不依赖表达水平排序的基因表示方法,或者通过多模态数据集成来纳入更多维度的基因信息。
  4. MMD 损失的优化挑战: MMD 损失在优化非凸问题时可能面临收敛到局部最优的风险。尽管 STATE 表现良好,但对其优化景观的更深入理解,以及探索结合其他损失函数(例如 Wasserstein 距离或其他正则化技术)以改善优化过程,可能会进一步提升模型性能和鲁棒性。
  5. 因果推断的强化: 尽管 STATE 提高了扰动预测的准确性,但它主要侧重于预测扰动“结果”。未来可以在模型中融入更强的因果推断框架,例如通过引入贝叶斯网络或结构因果模型,以更好地理解扰动与响应之间的因果关系,而不仅仅是相关性。理论分析中提及的 Jacobian 约束,正是朝着这一方向的初步探索。

相似论文推荐

基于向量语义检索推荐的相关论文。

暂时没有找到相似论文。