Conditional out-of-sample generation for unpaired data using trVAE
TL;DR 精炼摘要
本文提出trVAE,通过在解码器层引入最大均值差异(MMD)正则化,实现不同条件间分布匹配,解决条件变分自编码器在样本外生成中的泛化不足问题。trVAE在高维图像和单细胞基因表达数据上表现出更优的鲁棒性和预测准确性。
摘要
While generative models have shown great success in generating high-dimensional samples conditional on low-dimensional descriptors (learning e.g. stroke thickness in MNIST, hair color in CelebA, or speaker identity in Wavenet), their generation out-of-sample poses fundamental problems. The conditional variational autoencoder (CVAE) as a simple conditional generative model does not explicitly relate conditions during training and, hence, has no incentive of learning a compact joint distribution across conditions. We overcome this limitation by matching their distributions using maximum mean discrepancy (MMD) in the decoder layer that follows the bottleneck. This introduces a strong regularization both for reconstructing samples within the same condition and for transforming samples across conditions, resulting in much improved generalization. We refer to the architecture as \emph{transformer} VAE (trVAE). Benchmarking trVAE on high-dimensional image and tabular data, we demonstrate higher robustness and higher accuracy than existing approaches. In particular, we show qualitatively improved predictions for cellular perturbation response to treatment and disease based on high-dimensional single-cell gene expression data, by tackling previously problematic minority classes and multiple conditions. For generic tasks, we improve Pearson correlations of high-dimensional estimated means and variances with their ground truths from 0.89 to 0.97 and 0.75 to 0.87, respectively.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
Conditional out-of-sample generation for unpaired data using trVAE
1.2. 作者
M. Lotfollahi, Mohsen Naghipourfar, Fabian J. Theis, F. Alexander Wolf 等。 作者主要来自德国慕尼黑亥姆霍兹中心计算生物学研究所、慕尼黑工业大学生命科学魏恩施蒂芬学院、慕尼黑工业大学数学系,以及伊朗沙里夫理工大学计算机工程系。
1.3. 发表期刊/会议
该论文发布在 arXiv 预印本平台,发布时间为 2019 年 10 月 4 日。虽然未在具体期刊或会议上正式发表,但 arXiv 是学术研究成果早期传播的重要平台。
1.4. 发表年份
2019 年
1.5. 摘要
生成模型 (generative models) 在根据低维描述符生成高维样本方面取得了巨大成功,例如在 MNIST 数据集中学习笔画粗细,在 CelebA 数据集中学习发色,或在 Wavenet 中学习说话者身份。然而,其样本外生成 (out-of-sample generation) 却存在根本性问题。条件变分自编码器 (Conditional Variational Autoencoder, CVAE) 作为一种简单的条件生成模型,在训练过程中没有明确关联不同条件,因此缺乏学习跨条件紧凑联合分布的动力。本文通过在瓶颈层 (bottleneck layer) 之后的解码器层 (decoder layer) 使用最大均值差异 (Maximum Mean Discrepancy, MMD) 来匹配它们的分布,从而克服了这一限制。这引入了强大的正则化 (regularization),既用于重建同一条件下的样本,也用于在不同条件间转换样本,从而显著提高了泛化能力 (generalization)。该架构被称为 transformer VAE (trVAE)。通过在高维图像和表格数据上对 trVAE 进行基准测试,我们证明了其比现有方法具有更高的鲁棒性 (robustness) 和准确性 (accuracy)。特别是,我们展示了基于高维单细胞基因表达数据,对细胞扰动响应(对治疗和疾病)进行了定性改进的预测,解决了以前棘手的少数类 (minority classes) 和多条件问题。对于通用任务,高维估计均值和方差与其真实值 (ground truths) 的皮尔逊相关系数 (Pearson correlations) 分别从 0.89 提高到 0.97 和从 0.75 提高到 0.87。
1.6. 原文链接
https://arxiv.org/abs/1910.01791v2 PDF 链接: https://arxiv.org/pdf/1910.01791v2.pdf 发布状态:预印本 (Preprint)
2. 整体概括
2.1. 研究背景与动机
2.1.1. 核心问题
论文旨在解决生成模型在进行样本外生成 (out-of-sample generation) 时所面临的根本性挑战。具体来说,当训练数据中不包含目标条件 (target condition) 的特定样本时,如何准确地生成这些样本。例如,如果训练数据中只有黑发男性、金发女性和黑发女性,但没有金发男性,模型是否能预测一个黑发男性在拥有金发时会是什么样子。在生物学领域,这尤其重要,例如根据体外 (in vitro) 或小鼠实验数据,预测未治疗的人类如何响应药物治疗。
2.1.2. 问题重要性与现有挑战
- 传统条件生成模型 (Conditional Generative Models) 的局限性: 传统的条件变分自编码器 (CVAE) 等模型虽然可以根据低维描述符生成高维样本,但在训练时并未明确关联不同条件。这意味着 CVAE 缺乏学习跨条件紧凑联合分布 (compact joint distribution across conditions) 的内在机制,导致其在处理训练数据中未见过的条件组合时表现不佳。
- 样本外泛化需求: 许多实际应用,特别是科学发现领域(如药物研发、细胞扰动预测),都要求模型能够从有限的、有偏的或不完全的数据中,泛化到新的、未见的条件,进行准确预测或生成。
- 生物学应用的迫切性: 在单细胞基因表达数据分析中,预测细胞对治疗和疾病的反应,特别是对于少数细胞类型或多种扰动条件下的反应,是理解生物学机制的关键。现有方法往往难以处理这些复杂场景,尤其是当训练数据中缺乏特定扰动条件下的细胞类型样本时。
2.1.3. 论文的切入点与创新思路
论文的创新点在于,通过在条件变分自编码器 (CVAE) 的解码器层 (decoder layer) 中引入最大均值差异 (Maximum Mean Discrepancy, MMD) 正则化,来强制匹配不同条件下的中间表示分布。
- MMD 的引入位置: 与以往在潜在空间 (latent space) 或瓶颈层 (bottleneck layer) 应用 MMD 的方法不同,本文将其置于解码器的第一层(即瓶颈层之后)。作者认为, 空间倾向于学习条件无关的特征,而条件信息 在 层(解码器第一层)才开始强烈影响表示。因此,在 层进行 MMD 正则化能够更有效地促使模型学习跨条件的共同特征,从而实现更准确的样本外转换。
- 数据驱动的端到端方法: 这种方法提供了一个数据驱动的、端到端 (end-to-end) 的解决方案,避免了硬编码 (hard-coded) 元素和对向量算术的依赖,并且能够泛化到多于两个的条件。
2.2. 核心贡献/主要发现
- 提出了 trVAE (transformer VAE) 模型: 一种基于 CVAE 并结合 MMD 正则化的新型生成模型,专门用于处理无配对数据的条件样本外生成任务。
- MMD 正则化在解码器层的有效性: 论文通过实验证明,将 MMD 应用于解码器第一层(而不是传统的瓶颈层)能够有效正则化跨条件的表示分布,促使模型学习更紧凑、更具泛化能力的跨条件特征。
- 在图像和表格数据上的优越性能:
- 图像风格迁移: 在 Morpho-MNIST 和 CelebA 数据集上,trVAE 成功实现了高质量的样本外风格迁移,即使目标条件(如金发男性、微笑女性)在训练数据中完全缺失。
- 生物学扰动预测: 在高维单细胞基因表达数据(肠道细胞感染响应、PBMC 刺激响应)上,trVAE 能够准确预测细胞对治疗和疾病的反应,显著优于现有方法,并能有效处理少数类和多条件场景。
- 量化性能提升: 在通用任务中,trVAE 将高维估计均值和方差与真实值之间的皮尔逊相关系数分别从 0.89 提高到 0.97,从 0.75 提高到 0.87,显示出显著的性能提升。
- 更高的鲁棒性和准确性: 基准测试表明 trVAE 比包括 CVAE、MMD-CVAE、CycleGAN、scGen 和 scVI 在内的现有方法具有更高的鲁棒性和准确性。
- 数据驱动和端到端方法: trVAE 提供了一个不涉及硬编码元素,且能够泛化到多个条件的数据驱动、端到端的方法。
3. 预备知识与相关工作
3.1. 基础概念
3.1.1. 变分自编码器 (Variational Autoencoder, VAE)
变分自编码器 (VAE) 是一种生成模型,旨在学习数据 的潜在表示 ,并能够从这些潜在表示中生成新的、类似原始数据的样本。其核心思想是利用变分推断 (variational inference) 来近似复杂分布,并通过神经网络实现编码器 (encoder) 和解码器 (decoder) 功能。
- 编码器 (Encoder): 将输入数据 映射到潜在空间 的参数化分布 ,通常假设为高斯分布,从而输出潜在变量的均值和方差。
- 解码器 (Decoder): 从潜在空间中的样本 重建原始数据 的参数化分布 。
- 目标函数 (Objective Function) - ELBO (Evidence Lower Bound): VAE 的训练目标是最大化数据的对数似然 (log-likelihood) 。由于直接计算这个似然很困难,VAE 优化的是其下界,即证据下界 (ELBO)。ELBO 包含两部分:
- 重建损失 (Reconstruction Loss): 。衡量解码器从潜在表示重建原始数据的质量。
- KL 散度损失 (KL Divergence Loss): 。衡量编码器输出的潜在分布 与预设的先验分布 (通常是标准正态分布)之间的差异。这部分损失鼓励潜在空间具有良好的结构和泛化能力。
3.1.2. 条件变分自编码器 (Conditional Variational Autoencoder, CVAE)
条件变分自编码器 (CVAE) 是 VAE 的直接扩展,它允许模型在生成样本时考虑额外的条件信息 。在 CVAE 中,条件 被整合到编码器和解码器中。
- 编码器: ,潜在表示 现在不仅取决于输入 ,还取决于条件 。
- 解码器: ,生成的数据 同样受到潜在表示 和条件 的共同影响。
- 目标函数: CVAE 的目标函数与 VAE 类似,但所有分布都以条件 为条件。
3.1.3. 最大均值差异 (Maximum Mean Discrepancy, MMD)
最大均值差异 (MMD) 是一种衡量两个分布之间距离的度量。它通过将分布映射到一个再生核希尔伯特空间 (Reproducing Kernel Hilbert Space, RKHS) 中,然后计算它们在该空间中均值嵌入 (mean embeddings) 的距离来比较分布。
- 再生核希尔伯特空间 (RKHS): 这是一个特殊的函数空间,其中的函数评估是连续的,并且可以通过核函数 (kernel function) 定义内积。核函数
k(x, x')衡量两个样本 和 之间的相似性。 - MMD 的性质: 对于一个“通用核” (universal kernel,如高斯核),当且仅当两个分布完全相同时,MMD 的值为 0。这使得 MMD 成为一个强大的统计检验和损失函数。
- MMD 公式: 论文中给出的 MMD 损失函数 用于估计两个样本集 和 之间的距离:
k(x, x'): 核函数,衡量 和 之间的相似性。- : 样本集 的大小。
- : 样本集 的大小。
- : 样本集 内部所有样本对的核函数和。
- : 样本集 内部所有样本对的核函数和。
- : 样本集 和 之间所有样本对的核函数和。 论文中使用了多尺度径向基函数 (RBF) 核: 其中 , 是超参数。这种多尺度核有助于捕捉不同尺度上的相似性。
3.1.4. UMAP (Uniform Manifold Approximation and Projection)
UMAP 是一种降维技术,类似于 t-SNE,旨在将高维数据映射到低维空间,同时尽可能保留数据的局部和全局结构。它在可视化高维数据聚类和流形结构方面非常有效。
3.1.5. 皮尔逊相关系数 (Pearson Correlation Coefficient, )
皮尔逊相关系数是一种衡量两个变量之间线性关系强度的统计量。其值介于 -1 和 1 之间,1 表示完全正线性相关,-1 表示完全负线性相关,0 表示没有线性相关。在本文中, 通常指决定系数,它表示因变量中可由自变量解释的方差比例,常用于评估模型预测的准确性。
3.2. 前人工作
- 条件生成模型 (Conditional Generative Models):
- Mirza & Osindero (2014) 和 Ren et al. (2016) 等工作建立了根据潜在向量 和分类变量 生成高维样本的基础。CVAE (Sohn et al., 2015) 是这些模型中的一个简单而有效的分支。
- MMD 在 VAE 和域适应 (Domain Adaptation) 中的应用:
- Louizos et al. (2015) 提出的“变分公平自编码器” (Variational Fair Autoencoder, VFAE) 利用 MMD 来匹配潜在分布 和 ,从而实现无监督域适应 (unsupervised domain adaptation)。
- Lopez et al. (2018b) 将 MMD 用于学习统计独立的潜在维度。
- Long et al. (2015) 和 Tzeng et al. (2014) 等在监督域适应方法中也展示了 MMD 正则化在学习去除域特定信息、保留标签预测特征方面的有效性。
- MMD 在因果推断 (Causal Inference) 中的应用:
- Johansson et al. (2016) 探讨了如何通过学习强制处理组和对照组之间相似性的表示来改进反事实推断 (counterfactual inference),并提及 MMD 作为替代度量。
- 其他样本外转换方法:
- Lotfollahi et al. (2019) 和 Amodio et al. (2018) 曾通过硬编码的潜在空间向量算术 (latent space vector arithmetics) 和直方图匹配 (histogram matching) 来解决样本外转换问题。
3.3. 技术演进
生成模型从最初的非条件生成(如 GANs、VAE)发展到条件生成(如 CVAE、Conditional GANs),允许用户通过指定条件来控制生成内容。然而,这些模型在处理“样本外”条件(即训练集中未直接观察到的条件组合)时表现不佳。MMD 作为一种强大的分布距离度量,被引入到生成模型中,最初常用于潜在空间,以实现域适应或解耦学习。
本文的工作可以看作是这一演进链条中的一个重要创新:它认识到 CVAE 中,虽然潜在空间 倾向于条件无关,但解码器第一层( 层)的表示却会强烈地随条件 变化。通过将 MMD 正则化应用到 层,trVAE 有效地“拉近”了不同条件下 的分布,从而鼓励模型学习更通用的、跨条件的特征转换能力,显著提升了样本外生成性能。这是一种对 CVAE 架构和正则化策略的精巧改进。
3.4. 差异化分析
trVAE 与现有方法的核心区别和创新点如下:
- CVAE (Sohn et al., 2015) 及其变体:
- 核心区别: 传统的 CVAE 没有明确的机制来关联训练过程中的不同条件,导致其在解码器第一层( 空间)的表示在不同条件之间高度分散且不紧凑。trVAE 通过在解码器第一层( 层)引入 MMD 正则化,强制不同条件下的 表示分布趋于一致,从而解决了 CVAE 泛化能力不足的问题。
- MMD-CVAE (MMD on bottleneck, 类似于 VFAE, Louizos et al., 2015):
- 核心区别: VFAE 等方法将 MMD 应用于瓶颈层 (bottleneck layer) 来匹配潜在分布。本文作者通过实验(并在图 5e 和 6e 中展示)发现,在 层进行 MMD 正则化并不能有效改善性能。trVAE 的关键在于将 MMD 应用于解码器第一层 。作者认为 层已经倾向于条件无关,而 层才开始引入条件信息,因此在 层进行正则化更能促进跨条件特征的学习。
- MMD-regularized autoencoder (Dziugaite et al., 2015b; Amodio et al., 2019):
- 核心区别: 这些通常是自编码器 (Autoencoder, AE) 框架下的 MMD 应用,可能不具备 VAE 的变分推断和生成能力,或者其正则化策略与 trVAE 不同。trVAE 结合了 CVAE 的生成能力和特定位置的 MMD 正则化。
- CycleGAN (Zhu et al., 2017):
- 核心区别: CycleGAN 是一种对抗性生成网络 (Generative Adversarial Network, GAN) 架构,通过循环一致性损失 (cycle-consistency loss) 实现无配对图像到图像的转换。GANs 训练过程通常较为不稳定,并且优化困难。trVAE 基于 VAE,训练过程相对简单和稳定,且 MMD 比 GANs 的对抗性训练更容易优化。此外,CycleGAN 主要用于图像转换,trVAE 也在高维表格数据上展示了其有效性。
- scGen (Lotfollahi et al., 2019) 和其他硬编码方法 (Amodio et al., 2018):
- 核心区别: scGen 等方法依赖于硬编码的潜在空间向量算术来进行转换,或者使用直方图匹配。这些方法通常不具备端到端学习能力,难以泛化到多于两个条件,且可能需要领域专家知识来设计硬编码规则。trVAE 是一种完全数据驱动的端到端方法,可以处理多个条件而无需特定设计。
4. 方法论
4.1. 方法原理
trVAE 的核心思想是解决传统 CVAE 在进行样本外生成 (out-of-sample generation) 时,其解码器在不同条件下的中间表示分布不紧凑的问题。作者观察到,在 CVAE 中,潜在空间 往往会学习到与条件 解耦 (disentangled) 的特征,即 几乎独立于 。然而,在解码器的第一层输出 中,由于条件 被显式地输入到解码器, 的表示会强烈地随 变化。这种强烈的条件依赖性导致不同条件下的 分布差异很大,使得模型难以学习到通用的跨条件转换能力,从而限制了其泛化到样本外条件的能力。
为了克服这一限制,trVAE 引入了最大均值差异 (MMD) 正则化。不同于以往在潜在空间 上应用 MMD 来匹配分布,trVAE 将 MMD 应用于解码器的第一层 。其直觉是,通过强制不同条件下的 表示分布(例如, 和 )变得相似,模型被激励去学习那些跨条件通用的、但又能在细微处区分条件的特征。这种正则化促使 空间变得更紧凑,即不同条件下的 仍然存在差异,但这种差异被 MMD 限制在一个可控的范围内。这样,当模型需要将一个样本从源条件转换到目标条件时(例如,从 转换到 ),它能够更好地利用这些紧凑的、共享的特征,从而实现更准确、更具泛化能力的转换。
4.2. 核心方法详解
trVAE 是一个 MMD-正则化 (MMD-regularized) 的 CVAE。它在标准 CVAE 的基础上,通过修改损失函数,在解码器的中间层(具体来说是瓶颈层之后的第一个解码器层)引入了 MMD 损失项。
4.2.1. CVAE 的基本框架与损失函数
首先,我们回顾条件变分自编码器 (CVAE) 的基本概念。CVAE 的目标是建模高维随机变量 在给定条件 下的概率分布 。这通过引入一个潜在变量 来实现: 其中:
-
: 高维观测数据 (high-dimensional observation)。
-
: 条件变量 (categorical variable),可以是标量或低维向量。
-
: 潜在随机变量 (latent random vector)。
-
: 生成模型(解码器)的参数。
-
: 解码器 (decoder) 的生成分布,将 和 解码为 。
-
: 潜在变量 的先验分布 (prior distribution),以 为条件。
为了使上述积分变得可处理,CVAE 引入了一个编码分布 ,它由编码器 (encoder) 参数化。通过最大化对数似然 的下界,即证据下界 (Evidence Lower Bound, ELBO),来训练 CVAE。CVAE 的损失函数 (通常是 ELBO 的负值,或者说是最大化 ELBO) 为: 这里,作者调整了标准 ELBO 的符号和系数,使其表示为需要优化的损失函数:
-
: 编码器的参数。
-
: 解码器的参数。
-
: 这是重建项 (reconstruction term),衡量解码器在给定潜在变量 和条件 的情况下重建原始数据 的能力。通常,这会是负的重建损失(例如负对数似然),因此最大化这个项等价于最小化重建损失。
-
: 这是KL 散度项 (KL divergence term),衡量编码器输出的潜在分布 与先验分布 之间的差异。最小化这个项鼓励潜在空间具有良好的结构。
-
: 重建项的权重超参数。
-
: KL 散度项的权重超参数。
4.2.2. trVAE 的架构与 MMD 正则化位置
trVAE 的架构在编码器 和解码器 之间划分。解码器 又被分解为两部分:第一层 和剩余层 。 模型通过以下步骤进行转换: 其中:
-
: 编码器,将输入数据 和条件 编码为潜在表示 。
-
: 解码器的第一层,接收潜在表示 和条件 ,输出中间表示 。
-
: 解码器的其余层,接收中间表示 ,输出重建数据 。
作者的关键见解是,尽管 趋向于条件无关,但 却与 强相关。在标准 CVAE 中,这种强相关性导致不同条件下的 分布高度分散。为了使 的分布在不同条件间变得更紧凑和相似,trVAE 在解码器第一层 上引入了 MMD 正则化。
MMD 损失函数 用于衡量两个分布的距离。本文中使用的是多尺度 RBF 核 (multi-scale RBF kernel): 其中 , 是核的数量, 是每个核的带宽超参数。
4.2.3. trVAE 的损失函数
trVAE 的损失函数是在 CVAE 损失的基础上,通过复制并添加一个 MMD 项来构建的。假设我们有两组数据样本 (X, S) 和 (X', S'),分别代表不同条件下的数据。
其中:
- : trVAE 的总损失函数。
- : 第一个条件 () 下的 CVAE 损失。
- : 第二个条件 () 下的 CVAE 损失。这里 可以是与 不同的条件。
- : MMD 损失项的权重超参数,控制 MMD 正则化的强度。
- : MMD 损失项。它衡量的是从源条件 得到的中间表示 和从目标条件 得到的中间表示 之间的距离。
- 具体来说, 是将数据 在条件 下编码并由解码器第一层生成的结果。
- 是将数据 在条件 下编码并由解码器第一层生成的结果。
- 在训练时,模型会随机选择批次的样本,并确保批次中包含不同条件的样本(例如 和 ),以便计算 MMD 损失。
4.2.4. 训练与预测过程
-
训练时间 (Training Time):
- 模型接收随机批次的原始数据 及其对应的条件标签 作为输入。
- 这些批次经过分层抽样 (stratified sampling),以确保不同条件 的样本比例大致相等。
- 通过上述 trVAE 损失函数优化编码器 和解码器 的参数 和 。
-
预测时间 (Prediction Time) - 样本外转换:
-
假设要将源条件 下的样本 转换到目标条件 。
-
首先,将源样本 及其条件 传递给编码器 以获得潜在表示 。
-
然后,将这个潜在表示 与目标条件 一起传递给解码器 ,即 。
-
解码器输出的就是转换到目标条件 下的预测样本 。
这种方法允许模型在训练数据中即使没有看到源条件和目标条件的直接配对样本,也能够通过学习到的紧凑中间表示 来实现跨条件转换。
-
下图(原文 Figure 1)展示了 trVAE 的架构:
该图像是论文中图1的示意图,展示了带有最大均值差异(MMD)正则化的条件变分自编码器(CVAE)trVAE架构。网络以输入 和条件 编码,并在解码器第一层通过MMD层正则化条件影响,实现从条件 到 的转换,公式包括 ,,。
Figure 1: The transformer VAE (trVAE) is an MMD-regularized CVAE. It receives randomized batches of data ( x ) and condition (s) as input during training, stratified for approximately equal proportions of . In contrast to a standard CVAE, we regularize the effect of on the representation obtained after the first-layer of the decoder During prediction time, we transform batches of the source condition to the target condition by encoding and decoding .
5. 实验设置
论文通过在一系列高维图像和表格数据上进行实验,验证了 trVAE 的性能。
5.1. 数据集
5.1.1. Morpho-MNIST
-
来源: Castro et al. (2018)
-
特点: 包含 60,000 张“正常”手写数字图像,以及相同数量的“细线 (thin)”和“粗线 (thick)”笔画的数字图像。每个图像都有一个数字类别 (0-9) 和一个笔画粗细条件 (正常、细、粗)。
-
实验设置:
- 训练数据: 使用所有正常笔画的数字数据 (条件 )。在转换条件(细线和粗线笔画,)中,仅保留了数字类别 。
- 目标: 将正常笔画的数字图像转换成训练中未见的细线或粗线笔画的数字(样本外)。
-
数据集样本示例: 下图(原文 Figure 3)展示了 Morpho-MNIST 数据集上的风格迁移结果。
该图像是图表,展示了Morpho-MNIST数据集中数字样本的风格迁移,trVAE模型成功将训练中未见的正常数字转换为细线(a)和粗线(b)风格,有效实现了风格的out-of-sample转化。Figure 3: Out-of-sample style transfer for Morpho-MNIST dataset containing normal, thin and thick digits. trVAE successfully transforms normal digits to thin (a) and thick ( for digits not seen during training (out-of-sample).
5.1.2. CelebA
-
来源: Liu et al. (2015)
-
特点: 包含 202,599 张名人面部图像,每张图像有 40 个二元属性。
-
实验设置:
- 任务: 学习将不笑的脸转换成笑脸。
- 条件: 关注微笑 () 和性别 () 属性。
- 训练数据: 使用了所有男性(笑和不笑)的图像,但仅使用不笑的女性图像。
- 目标: 在训练数据中完全缺乏“微笑女性”样本的情况下,成功将不笑的女性图像转换为笑脸(样本外)。
-
数据集样本示例: 下图(原文 Figure 4)展示了 CelebA 数据集上的风格迁移结果。
该图像是图像生成结果展示,展示了CelebA数据集中不同人物脸部表情的对比。trVAE模型在训练数据缺乏女性微笑样本的情况下,成功为女性脸部添加了微笑,实现了条件下样本外生成。Figure 4: CelebA dataset with images in two conditions: celebrities without a smile and with a smile on their face. trVAE successfully adds a smile on faces of women without a smile despite these samples completely lacking from the training data (out-of-sample). The training data only comprises non-smiling women and smiling and non-smiling men.
5.1.3. 肠道细胞感染响应数据 (Gut Cell Infection Response)
-
来源: Haber et al. (2017)
-
特点: 单细胞基因表达数据集,描述了感染沙门氏菌 (Salmonella) 或多形螺旋线虫 (Heligmosomoides polygyrus, H. poly) 后肠道细胞的响应。
- 包含 8 种不同的细胞类型和 4 种条件:对照/健康细胞 (),H.Poly 感染 3 天后 (),H.Poly 感染 10 天后 (),以及沙门氏菌感染 ()。
- 归一化后的基因表达数据具有 1,000 个维度(对应 1,000 个基因)。
-
实验设置:
- 基准比较: 由于部分基线模型只能处理两种条件,因此仅使用对照和 H.Poly.Day10 条件进行模型比较。
- 样本外: 在训练和验证过程中,排除了 Tuft 感染细胞,因为这些细胞是最难进行样本外泛化的(共享特征最少,训练数据最少)。
- 多条件实验: 另外进行了一个包含所有三种条件(对照、H.Poly.Day3、H.Poly.Day10 和沙门氏菌)的实验,并在所有扰动条件下排除了每一种细胞类型进行训练。
-
数据集样本示例: 下图(原文 Figure 5a)展示了肠道细胞数据的 UMAP 可视化。
该图像是图表,展示了论文中Figure 5各子图的结果:包括(a)胃肠细胞的UMAP条件与细胞类型可视化,(b-c)trVAE预测和真实感染Tuft细胞的基因表达均值与方差的比较,(d)不同模型下Defa24基因表达的分布对比,(e)不同模型预测均值和方差的Pearson相关系数比较,以及(f)trVAE在不同细胞类型和条件下基因表达均值相关系数的统计。各图数据反映模型预测的准确性和泛化能力。Figure 5: (a) UMAP visualization of conditions and cell type for gut cells. (b-c) Mean and variance expression of 1,000 genes comparing trVAE-predicted and real infected Tuft cells together with the top 10 differentiall-expressed genes highlighted in red ( denotes Pearson correlation between ground truth and predicted values). (d) Distribution of Defa24: the top response gene to H.poly.Day10 infection between control, predicted and real stimulated cells for different models. Vertical axis: expression distribution for Defa24. Horizontal axis: control, real and predicted distribution by different models. (e) Comparison of Pearson's values for mean and variance gene expression between real and predicted cells for different models. Center values show the mean of values estimated using random subsamples for the prediction of each model and error bars depict standard deviation. (f) Comparison of values for mean gene expression between real and predicted cells by trVAE for the eight different cell types and three conditions. Center values show the mean of values estimated using random subsamples for each cell type and error bars depict standard deviation.
5.1.4. 外周血单核细胞 (PBMCs) 刺激响应数据
-
来源: Kang et al. (2018)
-
特点: 单细胞基因表达数据集,包含来自 8 名不同狼疮患者的 7,217 个 IFN-β 刺激细胞和 6,359 个对照外周血单核细胞 (PBMCs)。IFN-β 刺激会导致免疫细胞转录谱发生显著变化。
-
实验设置:
- 样本外: 在模型训练期间,排除了自然杀伤 (NK) 细胞进行预测。
-
数据集样本示例: 下图(原文 Figure 6a)展示了 PBMC 数据的 UMAP 可视化。
该图像是论文中图6的多子图图表,展示了外周血单核细胞(PBMCs)的UMAP可视化、trVAE预测和真实自然杀伤细胞(NK)在2000维度上的均值与方差比较、IFN-β刺激后基因ISG15表达分布,以及不同模型在基因表达均值和方差预测中值的比较。Figure 6: (a) UMAP visualization of peripheral blood mononuclear cells (PBMCs). (b-c) Mean and variance per 2,000 dimensions between trVAE-predicted and real natural killer cells (NK) together with the top 10 differentially-expressed genes highlighted in red. (d) Distribution of ISG15: the most strongly changing gene after IFN- perturbation between control, real and predicted stimulated cells for different models. Vertical axis: expression distribution for ISG15. Horizontal axis: control, real and predicted distribution by different models. (e) Comparison of values for mean and variance gene expression between real and predicted cells for different models. Center values show the mean of values estimated using random subsamples for the prediction of each model and error bars depict standard deviation.
5.1.5. 网络架构和超参数
- 图像数据 (Morpho-MNIST 和 CelebA): 使用了卷积层。Morpho-MNIST 使用了基于卷积层的 trVAE,其中笔画宽度通过两个全连接层编码,然后重塑为图像通道。CelebA 使用了深度卷积 trVAE,具有 U-Net 样式的架构,条件标签作为额外的通道输入。
- 基因表达数据 (肠道细胞和 PBMCs): 使用了全连接层。
- 超参数: 每种应用的优化超参数通过网格搜索 (grid-search) 确定。详细超参数列在附录 A 的表格 1-9 中。
5.2. 评估指标
5.2.1. 定性评估 (Qualitative Evaluation)
在图像生成任务中(Morpho-MNIST 和 CelebA),主要通过目视检查 (visual inspection) 来评估生成图像的质量、风格转换的准确性以及原始图像特征的保留程度。
5.2.2. 皮尔逊相关系数 ()
在基因表达数据任务中,使用皮尔逊相关系数 (Pearson Correlation Coefficient) 来量化模型预测的基因表达均值和方差与真实值之间的线性关系强度。论文中用 表示。
-
概念定义: 皮尔逊相关系数衡量了两个连续变量之间线性关系的方向和强度。 值(决定系数)表示因变量中可以由自变量预测的方差比例。在本文中,它用于评估模型预测的均值或方差与真实观测值之间的拟合优度,值越接近 1 表示预测越准确。
-
数学公式: 对于两个变量 和 ,其皮尔逊相关系数 定义为: 决定系数 通常是皮尔逊相关系数的平方: 或者更广义地,对于回归模型,它表示模型解释的方差比例: 其中 是模型对 的预测值。在论文语境中,它似乎指代预测值与真实值之间的皮尔逊相关系数的平方,或者直接是预测值与真实值之间的线性拟合优度。
-
符号解释:
- : 第 个数据点的第一个变量值(例如,模型预测的基因表达值)。
- : 第 个数据点的第二个变量值(例如,真实的基因表达值)。
- : 变量 的样本均值。
- : 变量 的样本均值。
- : 样本数量。
- : 第 个数据点的预测值。
- : 求和符号。
5.3. 对比基线
论文将 trVAE 与以下多种现有方法和替代方案进行了比较:
- Vanilla CVAE (Sohn et al., 2015): 最简单的条件变分自编码器。
- CVAE with MMD on bottleneck (MMD-CVAE): 类似于 VFAE (Louizos et al., 2015),将 MMD 应用于潜在空间 (bottleneck layer) 的 CVAE 变体,用于比较 MMD 放置位置的影响。
- MMD-regularized autoencoder (Dziugaite et al., 2015b; Amodio et al., 2019): 使用 MMD 进行正则化的自编码器。
- CycleGAN (Zhu et al., 2017): 一种用于无配对图像到图像转换的生成对抗网络。
- scGen (Lotfollahi et al., 2019): 结合 VAE 和向量算术 (vector arithmetics) 的单细胞扰动预测模型。
- scVI (Lopez et al., 2018a): 带有负二项式输出分布 (negative binomial output distribution) 的 CVAE,专门用于单细胞转录组数据。
6. 实验结果与分析
6.1. 核心结果分析
6.1.1. 图像风格迁移 (Morpho-MNIST 和 CelebA)
-
Morpho-MNIST 数据集:
-
trVAE 成功地将正常笔画的数字(训练中包含所有数字类别)转换成了训练中仅见过部分数字类别的细线和粗线笔画样式。这表明 trVAE 能够很好地处理样本外 (out-of-sample) 的风格迁移任务。
-
下图(原文 Figure 3)清晰地展示了从正常数字到细线和粗线数字的转换效果,即使对于训练中未见的数字,转换质量也令人满意。
该图像是图表,展示了Morpho-MNIST数据集中数字样本的风格迁移,trVAE模型成功将训练中未见的正常数字转换为细线(a)和粗线(b)风格,有效实现了风格的out-of-sample转化。
Figure 3: Out-of-sample style transfer for Morpho-MNIST dataset containing normal, thin and thick digits. trVAE successfully transforms normal digits to thin (a) and thick ( for digits not seen during training (out-of-sample).
-
-
CelebA 数据集:
-
trVAE 在更复杂的 CelebA 数据集上展示了其能力,成功将不笑的女性面部转换为笑脸,尽管“微笑女性”的样本在训练数据中完全缺失(训练数据只有不笑的女性、笑的男性和不笑的男性)。
-
转换后的图像保留了原始图像的大部分特征,同时准确地添加了微笑表情。这进一步证明了 trVAE 在复杂条件下的样本外生成能力。
-
下图(原文 Figure 4)展示了这一效果,突出显示了模型在处理现实世界复杂数据时的灵活性和有效性,甚至能够适应像 U-Net 这样的知名架构。
该图像是图像生成结果展示,展示了CelebA数据集中不同人物脸部表情的对比。trVAE模型在训练数据缺乏女性微笑样本的情况下,成功为女性脸部添加了微笑,实现了条件下样本外生成。
Figure 4: CelebA dataset with images in two conditions: celebrities without a smile and with a smile on their face. trVAE successfully adds a smile on faces of women without a smile despite these samples completely lacking from the training data (out-of-sample). The training data only comprises non-smiling women and smiling and non-smiling men.
-
6.1.2. 肠道细胞感染响应预测
-
预测准确性: trVAE 在预测受感染 Tuft 细胞的高维基因表达均值和方差方面表现出卓越的准确性。与真实值相比,trVAE 的预测值具有更高的皮尔逊相关系数。
-
关键基因响应: 对于 H.poly 感染后 Tuft 细胞中表达变化最大的基因 Defa24,trVAE 提供了比其他模型更好的均值和方差估计。下图(原文 Figure 5d)展示了 Defa24 基因在不同模型预测下的分布,trVAE 的预测分布与真实受刺激细胞的分布最为接近。
该图像是图表,展示了论文中Figure 5各子图的结果:包括(a)胃肠细胞的UMAP条件与细胞类型可视化,(b-c)trVAE预测和真实感染Tuft细胞的基因表达均值与方差的比较,(d)不同模型下Defa24基因表达的分布对比,(e)不同模型预测均值和方差的Pearson相关系数比较,以及(f)trVAE在不同细胞类型和条件下基因表达均值相关系数的统计。各图数据反映模型预测的准确性和泛化能力。Figure 5: (a) UMAP visualization of conditions and cell type for gut cells. (b-c) Mean and variance expression of 1,000 genes comparing trVAE-predicted and real infected Tuft cells together with the top 10 differentiall-expressed genes highlighted in red ( denotes Pearson correlation between ground truth and predicted values). (d) Distribution of Defa24: the top response gene to H.poly.Day10 infection between control, predicted and real stimulated cells for different models. Vertical axis: expression distribution for Defa24. Horizontal axis: control, real and predicted distribution by different models. (e) Comparison of Pearson's values for mean and variance gene expression between real and predicted cells for different models. Center values show the mean of values estimated using random subsamples for the prediction of each model and error bars depict standard deviation. (f) Comparison of values for mean gene expression between real and predicted cells by trVAE for the eight different cell types and three conditions. Center values show the mean of values estimated using random subsamples for each cell type and error bars depict standard deviation.
-
与其他模型的定量比较: 下图(原文 Figure 5e)显示,trVAE 在预测均值和方差的 值方面显著优于所有对比模型。值得注意的是,CVAE 在瓶颈层施加 MMD 正则化 (MMD-CVAE) 并未带来性能提升,这印证了作者关于 MMD 应置于解码器第一层而非瓶颈层的论点。对于均值预测,trVAE 的 均值接近 0.97,远高于其他模型。
-
多条件处理能力: trVAE 成功地处理了多条件场景,准确预测了所有八种细胞类型在每种扰动条件下的基因表达(下图,原文 Figure 5f)。这对于处理复杂生物学实验数据至关重要,因为许多现有模型仅限于处理两种条件。
6.1.3. PBMCs 刺激响应预测
-
预测均值和方差: trVAE 准确预测了在训练中未见的自然杀伤 (NK) 细胞中所有基因的均值和方差。下图(原文 Figure 6b-c)展示了 trVAE 预测值与真实值之间的高度一致性,特别是对 IFN-β 强烈响应的基因。
该图像是论文中图6的多子图图表,展示了外周血单核细胞(PBMCs)的UMAP可视化、trVAE预测和真实自然杀伤细胞(NK)在2000维度上的均值与方差比较、IFN-β刺激后基因ISG15表达分布,以及不同模型在基因表达均值和方差预测中值的比较。Figure 6: (a) UMAP visualization of peripheral blood mononuclear cells (PBMCs). (b-c) Mean and variance per 2,000 dimensions between trVAE-predicted and real natural killer cells (NK) together with the top 10 differentially-expressed genes highlighted in red. (d) Distribution of ISG15: the most strongly changing gene after IFN- perturbation between control, real and predicted stimulated cells for different models. Vertical axis: expression distribution for ISG15. Horizontal axis: control, real and predicted distribution by different models. (e) Comparison of values for mean and variance gene expression between real and predicted cells for different models. Center values show the mean of values estimated using random subsamples for the prediction of each model and error bars depict standard deviation.
-
特定基因响应: IFN-β 刺激会导致 NK 细胞中 ISG15 基因表达增加。trVAE 即使在训练中从未见过 NK 细胞的刺激响应,也准确预测了 ISG15 表达的这一变化,与真实 NK 细胞中的观察结果一致。下图(原文 Figure 6d)对比了不同模型的 ISG15 表达分布,trVAE 的预测与真实情况最为吻合。
-
与其他模型的定量比较: 下图(原文 Figure 6e)的 值比较再次证实了 trVAE 的最佳性能,在预测基因表达均值和方差方面均优于 CycleGAN、MMD-regularized autoencoder (SAUCIE) 和其他基线模型。
6.1.4. 总结
总的来说,实验结果强有力地证明了 trVAE 的优越性。其关键在于 MMD 正则化在解码器第一层的巧妙应用,使得模型能够学习到紧凑的、跨条件泛化的表示,从而在图像风格迁移和高维生物学数据扰动预测等样本外生成任务中,取得了比现有方法更高的准确性、鲁棒性和泛化能力。尤其是在生物学应用中,trVAE 解决了处理少数类和多条件等复杂场景的难题。
6.2. 数据呈现 (表格)
以下是原文附录 A 中详细的超参数配置和模型架构表格。
以下是原文 Table 1 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
|---|---|---|---|---|---|---|
| input | (28, 28, 1) | × | × | |||
| conditions | 2 | × | × | |||
| FC-1 | FC | 128 | × | × | Leaky ReLU | conditions |
| FC-2 | FC | 784 | 0.2 | √ | Leaky ReLU | FC-1 |
| FC-2_resh | Reshape | (28, 28, 1) | × | × | × | FC-2 |
| Conv2D_1 | Conv2D | (4, 4, 64, 2) | × | × | Leaky ReLU | [FC-2_resh, input] |
| Conv2D_2 | Conv2D | (4, 4, 64, 64) | × | × | Leaky ReLU | Conv2D_1 |
| FC-3 | FC | 128 | × | √ | Leaky ReLU | Flatten(Conv2D_2) |
| mean | FC | 50 | × | × | Linear | FC-3 |
| var | FC | 50 | × | × | Linear | FC-3 |
| Z | FC | 50 | × | × | Linear | [mean, var] |
| FC-4 | FC | 128 | × | × | Leaky ReLU | conditions |
| FC-5 | FC | 784 | 0.2 | √ | Leaky ReLU | FC-4 |
| FC-5_resh | Reshape | (28, 28, 1) | × | × | × | FC-5 |
| MMD | FC | 128 | × | √ | Leaky ReLU | [z, FC-5_resh] |
| FC-6 | FC | 256 | × | × | Leaky ReLU | MMD |
| FC-7_resh | Reshape | (2, 2, 64) | × | × | × | FC-6 |
| Conv_transp_1 | Conv2D Transpose | (4, 4, 128, 64) | × | × | Leaky ReLU | FC-7_resh |
| Conv_transp_2 | Conv2D Transpose | (4, 4, 64, 64) | × | × | Leaky ReLU | UpSampling2D(Conv_tr |
| Conv_transp_3 | Conv2D Transpose | (4, 4, 64, 64) | × | × | Leaky ReLU | Conv_transp_2 |
| Conv_transp_4 | Conv2D Transpose | (4, 4, 2, 64) | × | × | Leaky ReLU | UpSampling2D(Conv_tr |
| output | Conv2D Transpose | (4, 4, 1, 2) | × | × | ReLU | UpSampling2D(Conv_tr |
| Optimizer | Adam | |||||
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| Batch Size | 1024 | |||||
| # of Epochs | 5000 | |||||
| α | 0.001 | |||||
| β | 1000 |
以下是原文 Table 2 的结果:
| Name | Operation | NoF/Kernel Dim. Dropout BN Activation | Input |
| input | (64, 64, 3) × × | ||
| conditions | 2 × × | ||
| FC-1 | FC | 128 × × ReLU | conditions |
| FC-2 | FC | 1024 0.2 √ ReLU | FC-1 |
| FC-2_reshaped | Reshape | (64, 64, 1) × × × | FC-2 |
| Conv_1 | Conv2D | (3, 3, 64, 4) × × ReLU | [FC-2_reshaped, input] |
| Conv_2 | Conv2D | (3, 3, 64, 64) × × ReLU | Conv_1 |
| Pool_1 | Pooling2D | × × × × | Conv_2 |
| Conv_3 | Conv2D | (3, 3, 128, 64) × × ReLU | Pool_1 |
| Conv_4 | Conv2D | (3, 3, 128, 128) × × ReLU | Conv_3 |
| Pool_2 | Pooling2D | × × × × | Conv_4 |
| Conv_5 | Conv2D | (3, 3, 256, 128) × × ReLU | Pool_2 |
| Conv_6 | Conv2D | (3, 3, 256, 256) × × ReLU | Conv_5 |
| Conv_7 | Conv2D | (3, 3, 256, 256) × × ReLU | Conv_6 |
| Pool_3 | Pooling2D | × × × × | Conv_7 |
| Conv_8 | Conv2D | (3, 3, 512, 256) × × ReLU | Pool_3 |
| Conv_9 | Conv2D | (3, 3, 512, 512) × × ReLU | Conv_8 |
| Conv_10 | Conv2D | (3, 3, 512, 512) × × ReLU | Conv_9 |
| Pool_4 | Pooling2D | × × × × | Conv_10 |
| Conv_11 | Conv2D | (3, 3, 512, 256) × × ReLU | Pool_4 |
| Conv_12 | Conv2D | (3, 3, 512, 512) × × ReLU | Conv_11 |
| Conv_13 | Conv2D | (3, 3, 512, 512) × × ReLU | Conv_12 |
| Pool_4 | Pooling2D | × × × × | Conv_13 |
| flat | Flatten | × × × × | Pool_4 |
| FC-3 | FC | 1024 × × ReLU | flat |
| FC-4 | FC | 256 0.2 × ReLU | FC-3 |
| mean | FC | 60 × × Linear | FC-4 |
| var | FC | 60 × × Linear | FC-4 |
| Z-sample | FC | 60 × × Linear | [mean, var] |
| FC-5 | FC | 128 × × ReLU | conditions |
| MMD | FC | 256 × √ ReLU | [z-sample, FC-5] |
| FC-6 | FC | 1024 × × ReLU | MMD |
| FC-7 | FC | 4096 × × ReLU | FC-6 |
| FC-7_reshaped | Reshape | × × × FC-7 | |
| Conv_transp_1 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | FC-7_reshaped |
| Conv_transp_2 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | Conv_transp_1 |
| Conv_transp_3 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | Conv_transp_2 |
| up_sample_1 | UpSampling2D | × × × × | Conv_transp_3 |
| Conv_transp_4 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | up_sample_1 |
| Conv_transp_5 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | Conv_transp_4 |
| Conv_transp_6 | Conv2D Transpose | (3, 3, 512, 512) × × ReLU | Conv_transp_5 |
| up_sample_2 | UpSampling2D | × × × × | Conv_transp_6 |
| Conv_transp_7 | Conv2D Transpose | (3, 3, 128, 256) × × ReLU | up_sample_2 |
| Conv_transp_8 | Conv2D Transpose | (3, 3, 128, 128) × × ReLU | Conv_transp_7 |
| up_sample_3 | UpSampling2D | × × × × | Conv_transp_8 |
| Conv_transp_9 | Conv2D Transpose | (3, 3, 64, 128) × × ReLU | up_sample_3 |
| Conv_transp_10 | Conv2D Transpose | (3, 3, 64, 64) × × ReLU | Conv_transp_9 |
| output Ontim | Conv2D Transpose | (1, 1, 3, 64) × × ReLU | Conv_transp_10 |
以下是原文 Table 3 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
| input | input_dim n_conditions | × | × | |||
| conditions | 800 | × | × | |||
| FC-1 | FC | 0.2 | Leaky ReLU | [input, conditions] | ||
| FC-2 | FC | 800 | 0.2 | >>> | Leaky ReLU | FC-1 |
| FC-3 | FC | 128 | 0.2 | Leaky ReLU | FC-2 | |
| mean | FC | 50 | × | × | Linear | FC-3 |
| var | FC | 50 | × | × | Linear | FC-3 |
| z-sample | FC | 50 | × | × | Linear | [mean, var] |
| MMD | FC | 128 | 0.2 | Leaky ReLU | [z-sample, conditions] | |
| FC-4 | FC | 800 | 0.2 | √ | Leaky ReLU | MMD |
| FC-5 | FC FC | 800 | 0.2 | √ | Leaky ReLU | FC-3 |
| output Optimizer | Adam | input_dim | × | × | ReLU | FC-4 |
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| Batch Size | 512 | |||||
| 5000 | ||||||
| # of Epochs | 0.00001 | |||||
| α β | 100 | |||||
| η | 100 |
以下是原文 Table 4 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
| input | input_dim | × | ||||
| FC-1 | FC | 800 | 0.2 | Leaky ReLU | input | |
| FC-2 | FC | 800 | 0.2 | X>>> | Leaky ReLU | F-1 |
| FC-3 | FC | 128 | 0.2 | Leaky ReLU | FC-2 | |
| mean | FC | 100 | × | × | Linear | FC-3 |
| var | FC | 100 | × | × | Linear | FC-3 |
| Z | FC | 100 | × | × | Linear | [mean, var] |
| MMD | FC | 128 | 0.2 | Leaky ReLU | Z | |
| FC-4 | FC | 800 | 0.2 | >>> | Leaky ReLU | MMD |
| FC-5 output | FC FC | 800 input_dim | 0.2 | Leaky ReLU | FC-3 | |
| Optimizer | Adam | × | × | ReLU | FC-4 | |
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| Batch Size | 32 | |||||
| # of Epochs | 300 | |||||
| α | 0.00050 | |||||
| β | 100 | |||||
| 100 | ||||||
| η |
以下是原文 Table 5 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
| input | input_dim | × | × | |||
| conditions | 1 | × | × | |||
| FC-1 | FC | 800 | 0.2 | √ | Leaky ReLU | [input, conditions] |
| FC-2 | FC | 800 | 0.2 | Leaky ReLU | FC-1 | |
| FC-3 | FC | 128 | 0.2 | √ | Leaky ReLU | FC-2 |
| mean | FC | 50 | × | × | Linear | FC-3 |
| var | FC | 50 | × | × | Linear | FC-3 |
| Z-sample MMD | FC | 50 | × | Linear | [mean, var] | |
| FC-4 | FC FC | 128 | 0.2 | ×> | Leaky ReLU | [z-sample, conditions] |
| FC-5 | FC | 800 | 0.2 | >> | Leaky ReLU | MMD |
| output | FC | 800 input_dim | 0.2 | Leaky ReLU ReLU | FC-3 | |
| Optimizer | Adam | × | × | FC-4 | ||
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| Batch Size | 512 | |||||
| # of Epochs | 300 | |||||
| α | 0.001 |
以下是原文 Table 6 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
| input | input_dim | × | × | |||
| conditions | 1 | × | × | |||
| FC-1 | FC | 800 | 0.2 | √ | Leaky ReLU | [input, conditions] |
| FC-2 | FC | 800 | 0.2 | Leaky ReLU | FC-1 | |
| FC-3 | FC | 128 | 0.2 | >> | Leaky ReLU | FC-2 |
| mean | FC | 50 | × | × | Linear | FC-3 |
| var | FC | 50 | × | × | Linear | FC-3 |
| Z-sample | FC | 50 | × | × | Linear | [mean, var] |
| MMD | FC | 128 | 0.2 | Leaky ReLU | [z-sample, conditions] | |
| FC-4 | FC | 800 | 0.2 | >>> | Leaky ReLU | MMD |
| FC-5 | FC | 800 | 0.2 | Leaky ReLU | FC-3 | |
| output | FC | input_dim | × | × | ReLU | FC-4 |
| Optimizer | Adam | |||||
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| Batch Size | 512 500 | |||||
| # of Epochs | 0.001 | |||||
| α | ||||||
| β | 1 |
以下是原文 Table 7 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
| input F-1 | FC | input_dim 700 | × 0.5 | Leaky ReLU | input | |
| FC-2 | FC | 100 | 0.5 | FC-1 | ||
| FC-3 | FC | 50 | 0.5 | >>>>> | Leaky ReLU | |
| FC-4 | FC | 100 | 0.5 | Leaky ReLU | FC-2 | |
| FC-5 | FC | 700 | 0.5 | Leaky ReLU | FC-3 | |
| generator_out | FC | 6,998 | × | Leaky ReLU | FC-4 | |
| FC-6 | FC | 700 | 0.5 | >>> | ReLU | FC-5 |
| FC-7 | FC | 100 | 0.5 | Leaky ReLU | generator_out | |
| discriminator_out | FC | 1 | Leaky ReLU | FC-6 | ||
| Generator Optimizer | Adam | × | × | Sigmoid | FC-7 | |
| Discriminator Optimizer | Adam | |||||
| Optimizer | Adam | |||||
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 | |||||
| # of Epochs | 1000 |
以下是原文 Table 8 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
|---|---|---|---|---|---|---|
| input conditions | input_dim | × | × | |||
| 1 | × | × | - | |||
| FC-1 | FC | 128 | 0.2 | √ | ReLU | input |
| mean | FC | 10 | × | × | Linear | FC-1 |
| var | FC | 10 | × | × | Linear | FC-1 |
| Z | FC | 10 | × | × | Linear | [mean, var] |
| FC-2 | FC | 128 | 0.2 | √ | ReLU | [z, conditions] |
| output Optimizer | FC Adam | input_dim | × | × | ReLU | FC-2 |
以下是原文 Table 9 的结果:
| Name | Operation | NoF/Kernel Dim. | Dropout | BN | Activation | Input |
|---|---|---|---|---|---|---|
| input | input_dim | × | × | |||
| conditions | 1 | × | × | |||
| FC-1 | FC | 512 | × | √ | Leaky ReLU | [input, conditions] |
| FC-2 | FC | 256 | × | × | Leaky ReLU | FC-1 |
| FC-3 | FC | 128 | × | × | Leaky ReLU | FC-2 |
| FC-4 | 20 | × | × | Leaky ReLU | FC-3 | |
| FC-5 | FC | 128 | × | × | Leaky ReLU | FC-4 |
| FC-6 | FC | 256 | × | × | Leaky ReLU | FC-5 |
| FC-7 | FC | 512 | × | × | Leaky ReLU | FC-6 |
| output | FC | input_dim | × | × | ReLU | FC-4 |
| Optimizer | Adam | |||||
| Learning Rate | 0.001 | |||||
| Leaky ReLU slope | 0.2 256 | |||||
| Batch Size | ||||||
| # of Epochs | 1000 |
7. 总结与思考
7.1. 结论总结
本文提出了 transformer VAE (trVAE),一种通过在条件变分自编码器 (CVAE) 解码器第一层引入最大均值差异 (MMD) 正则化,以实现无配对数据的条件样本外生成的新型模型。核心思想是强制不同条件下的中间表示 具有紧凑和相似的分布,从而激励模型学习跨条件通用的特征转换能力。实验结果表明,trVAE 在 Morpho-MNIST 和 CelebA 等图像风格迁移任务上取得了高质量的样本外生成效果,在单细胞基因表达数据上,对细胞扰动响应的预测准确性显著优于现有方法,特别是在处理少数类和多条件场景时表现出色。定量分析显示,trVAE 将高维估计均值和方差与其真实值之间的皮尔逊相关系数分别从 0.89 提高到 0.97 和从 0.75 提高到 0.87。这证明了 trVAE 在鲁棒性和准确性方面均超越了多种基线模型。
7.2. 局限性与未来工作
作者在讨论部分指出了以下局限性和未来的研究方向:
- 更深层正则化的探索: 尽管 trVAE 在解码器第一层应用 MMD 取得了成功,但作者提到在更深的解码器层进行正则化可能是有益的,然而这在数值上成本更高且可能不稳定,因为表示维度会更高。这方面仍需未来系统性地研究。
- 更大规模数据的应用: 未来工作将关注 trVAE 在更大、更多数据上的应用,特别是涉及条件之间的交互效应 (interaction effects) 的场景。
- 药物相互作用效应: 一个重要的应用领域是药物相互作用效应的研究,这与 Amodio et al. (2018) 的工作相关。
- 与因果推断模型的联系: 进一步的概念性研究将探索 trVAE 与受因果推断启发模型(如 CEVAE, Louizos et al., 2017)之间的联系。作者认为,对干预分布 (interventional distribution) 的忠实建模可以被重新定义为跨域 (across domains) 成功预测扰动效应。
7.3. 个人启发与批判
7.3.1. 个人启发
- MMD 放置位置的精妙: 这篇论文最大的启发在于其对 MMD 正则化位置的独到见解。以往很多工作倾向于在潜在空间 应用 MMD 以实现域不变性 (domain invariance) 或解耦。然而,作者通过实验和理论分析,巧妙地将其放置在解码器的第一层 。这种“下游”的正则化能够更好地平衡条件依赖性和跨条件共享特征的需求,从而更有效地促进样本外泛化。这提示我们,在设计深度学习模型时,对损失函数和正则化项的精确放置,可能比简单堆叠复杂技术更为关键。
- 跨领域泛化能力: trVAE 在图像数据和高维基因表达数据上都取得了显著成功,展示了其方法的通用性和强大的泛化能力。这对于生物学等领域具有重要意义,因为很多生物学问题都涉及对未见条件的预测,而传统方法往往难以处理。
- 问题拆解的思路: 论文通过将 CVAE 泛化能力不足的原因归结为解码器中间表示的非紧凑性,并提出针对性的正则化方案,体现了对复杂问题进行有效拆解和解决的思路。
7.3.2. 批判与潜在改进
- “无配对数据”的更明确阐述: 论文标题提及“unpaired data”,但在方法论中,更侧重于处理训练数据中未见的“条件组合”。虽然这隐含了“无配对”,但如果能更明确地阐释 trVAE 如何处理不同条件数据量不平衡,或条件之间缺乏直接对应样本的情况,将有助于读者更深入理解其“无配对”能力。例如,是否存在一种机制来处理源域和目标域之间数据量极度不平衡的情况?
- MMD 参数 的敏感性分析: 论文使用了多尺度 RBF 核,其中包含超参数 。虽然附录提供了最终使用的超参数,但缺乏对这些 MMD 相关超参数(如 的选择、 的权重)敏感性的深入分析。不同的 值可能导致 MMD 关注不同的尺度,从而影响特征学习。对其进行更系统的研究将增强模型的鲁棒性和可解释性。
- 定性评估的量化: 对于图像生成任务,目前主要是定性评估。虽然视觉效果直观,但引入一些量化指标(如 FID, Inception Score 等)可以更客观地评估生成图像的质量和多样性,从而与更多图像生成基线进行全面比较。
- 计算成本与可扩展性: MMD 损失的计算复杂度较高,尤其是在处理大批量数据时,其计算量随样本数量的平方增长。虽然论文在实验中取得了成功,但对于更大规模的数据集,其计算效率和可扩展性可能是一个挑战。可以探讨一些近似 MMD 或更高效的 MMD 变体。
- “瓶颈层 MMD 无效”的理论深度: 论文提到在瓶颈层 应用 MMD 没有改善性能,并提供了解释。如果能从理论上更深入地分析为什么 趋向于条件无关,以及为何 层更适合 MMD 正则化,将进一步提升论文的理论贡献。
相似论文推荐
基于向量语义检索推荐的相关论文。