Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search
TL;DR 精炼摘要
Jet-Nemotron通过后神经架构搜索冻结预训练全注意力模型中的MLP权重,优化注意力块设计,实现混合架构语言模型。该方法提升生成吞吐量达53.6倍,准确率匹配或超越主流模型,且具备硬件感知调参能力,显著提升大规模语言模型效率。
摘要
We present Jet-Nemotron, a new family of hybrid-architecture language models, which matches or exceeds the accuracy of leading full-attention models while significantly improving generation throughput. Jet-Nemotron is developed using Post Neural Architecture Search (PostNAS), a novel neural architecture exploration pipeline that enables efficient model design. Unlike prior approaches, PostNAS begins with a pre-trained full-attention model and freezes its MLP weights, allowing efficient exploration of attention block designs. The pipeline includes four key components: (1) learning optimal full-attention layer placement and elimination, (2) linear attention block selection, (3) designing new attention blocks, and (4) performing hardware-aware hyperparameter search. Our Jet-Nemotron-2B model achieves comparable or superior accuracy to Qwen3, Qwen2.5, Gemma3, and Llama3.2 across a comprehensive suite of benchmarks while delivering up to 53.6x generation throughput speedup and 6.1x prefilling speedup. It also achieves higher accuracy on MMLU and MMLU-Pro than recent advanced MoE full-attention models, such as DeepSeek-V3-Small and Moonlight, despite their larger scale with 15B total and 2.2B activated parameters.
思维导图
论文精读
中文精读
1. 论文基本信息
1.1. 标题
Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search
1.2. 作者
Yuxian Gu, Qinghao Hu, Shang Yang, Haocheng Xi, Junyu Chen, Song Han, Han Cai。 所有作者均隶属于 NVIDIA。
1.3. 发表期刊/会议
本文作为预印本发表于 arXiv 平台。其在相关领域具有重要影响力,被广泛关注。
1.4. 发表年份
2025年8月21日(UTC时间)。
1.5. 摘要
我们提出了 Jet-Nemotron,这是一个新型的混合架构语言模型 (hybrid-architecture language models) 家族。该模型在准确率上能够匹配或超越领先的全注意力模型 (full-attention models),同时显著提升了生成吞吐量 (generation throughput)。Jet-Nemotron 的开发得益于后神经架构搜索 (Post Neural Architecture Search, PostNAS),这是一个新颖的神经架构探索流水线 (neural architecture exploration pipeline),能够实现高效的模型设计。与以往方法不同,PostNAS 从一个预训练的全注意力模型 (full-attention model) 开始,并冻结 (freezes) 其多层感知器 (MLP) 权重 (MLP weights),从而实现对注意力块设计 (attention block designs) 的高效探索。该流水线包含四个关键组成部分:(1) 学习最优的全注意力层放置和消除 (full-attention layer placement and elimination),(2) 线性注意力块选择 (linear attention block selection),(3) 设计新的注意力块 (attention blocks),以及 (4) 执行硬件感知超参数搜索 (hardware-aware hyperparameter search)。我们的 Jet-Nemotron-2B 模型在全面的基准测试套件中,实现了与 Qwen3、Qwen2.5、Gemma3 和 Llama3.2 相当或更优的准确率,同时提供了高达 53.6 倍的生成吞吐量加速和 6.1 倍的预填充加速。尽管 DeepSeek-V3-Small 和 Moonlight 等近期先进的混合专家 (MoE) 全注意力模型 (MoE full-attention models) 规模更大,总参数量达 15B 且激活参数 (activated parameters) 为 2.2B,Jet-Nemotron-2B 在 MMLU 和 MMLU-Pro 上的准确率仍高于它们。
1.6. 原文链接
- 论文原文链接: https://arxiv.org/abs/2508.15884
- PDF 链接: https://arxiv.org/pdf/2508.15884v3.pdf
- 发布状态: 预印本 (arXiv)
2. 整体概括
2.1. 研究背景与动机
大型语言模型 (Large Language Models, LLMs) 在各种任务中展现出卓越的准确性,引领了人工智能的转型。然而,它们固有的计算 (computational) 和内存 (memory) 需求巨大,尤其是在长文本生成 (long-context generation) 和推理 (reasoning) 场景中,这已成为一个显著的效率问题。传统的 自注意力机制 (self-attention mechanism) 带来了 的计算复杂度 (computational complexity),并且生成了庞大的键值缓存 (Key-Value (KV) cache),使得模型的部署和运行成本高昂。
为了解决这一挑战,研究界在设计更高效的 LLM 架构方面付出了巨大努力,包括开发具有 复杂度的注意力机制 (attention mechanisms)(如线性注意力 (linear attention)),以及构建结合全注意力 (full-attention) 和线性注意力 (linear attention) 的混合模型 (hybrid models),以在准确性和效率之间取得平衡。
然而,现有高效或混合架构模型的一个核心痛点是,它们的准确率在许多挑战性基准测试(如 MMLU(-Pro)、数学推理、检索、编码和长文本任务)上显著落后于最先进的 (state-of-the-art, SOTA) 全注意力模型。这意味着在追求效率的同时,往往不得不牺牲模型能力。
这篇论文的切入点正是针对这一挑战:如何在不牺牲甚至超越 SOTA 全注意力模型准确率的前提下,大幅提升 LLM 的效率。作者认为,现有的架构探索方法成本高昂且风险大,阻碍了 LLM 架构的创新。
2.2. 核心贡献/主要发现
本文提出了 Jet-Nemotron 模型家族,并通过其背后创新的后神经架构搜索 (Post Neural Architecture Search, PostNAS) 流水线,实现了以下核心贡献:
- 创新架构探索范式 PostNAS:
- 引入了一种新颖的
LLM架构探索范式 (architecture exploration paradigm),通过重用预训练的 LLM (reusing pre-trained LLMs) 并冻结其 MLP 权重 (freezing their MLP weights),显著降低了架构探索的成本和风险。这使得LLM架构创新能够更快、更高效地进行。 PostNAS提供了一种系统化的方法,包括全注意力层放置和消除 (full-attention layer placement and elimination)、线性注意力块选择 (linear attention block selection)、新注意力块设计 (new attention block design) 和硬件感知架构搜索 (hardware-aware architecture search)。
- 引入了一种新颖的
- 新型线性注意力块 JetBlock:
- 提出了
JetBlock,这是一种新颖的线性注意力块 (linear attention block),它将线性注意力 (linear attention) 与动态卷积 (dynamic convolution) 相结合。 JetBlock在保持可比生成吞吐量 (generation throughput) 的同时,在数学推理和检索等任务上,其准确率显著优于Mamba2、GLA和Gated DeltaNet等现有设计。
- 提出了
- 高效且高精度模型家族 Jet-Nemotron:
- 推出了
Jet-Nemotron混合架构语言模型 (hybrid-architecture LM) 家族。 Jet-Nemotron在广泛的基准测试中,实现了与SOTA全注意力模型(如Qwen2.5、Qwen3、Gemma3和Llama3.2)相当或更优的准确率。- 同时,它提供了显著更高的生成吞吐量 (generation throughput),例如,在
H100 GPU上、256K 上下文长度下,Jet-Nemotron-2B实现了高达53.6倍的生成吞吐量加速和6.1倍的预填充加速。
- 推出了
- 对高效 LLM 架构设计的洞察:
-
揭示了注意力层在不同任务中的任务特定重要性 (task-specific importance)。
-
发现键值缓存 (KV cache) 大小是影响生成吞吐量 (generation throughput) 的更关键因素,而不是参数量 (parameter count)。
这些贡献共同使得
Jet-Nemotron成为一种在准确性和推理效率方面都非常强大的LLM,具有广泛的实际应用价值。
-
3. 预备知识与相关工作
3.1. 基础概念
为了更好地理解 Jet-Nemotron 及其提出的 PostNAS 方法,我们需要首先了解一些核心概念。
3.1.1. 语言模型 (LMs)
语言模型 (Language Models, LMs) 是一种能够理解、生成和处理人类语言的计算模型。它们通过学习大量文本数据中的模式和结构,预测下一个词元(token)或完成各种语言任务,如文本摘要、问答、翻译等。
3.1.2. Transformer 模型
Transformer 模型 (Transformer model) [8] 是一种开创性的神经网络架构,由 Vaswani 等人于2017年提出,彻底改变了自然语言处理领域。它摒弃了传统的循环神经网络 (RNN) 和卷积神经网络 (CNN) 结构,完全依赖自注意力机制 (self-attention mechanism) 来处理序列数据。Transformer 模型通常由编码器和解码器组成,每个部分都包含多个相同的层,每个层又包含多头自注意力 (multi-head self-attention) 子层和前馈神经网络 (feed-forward network)(或称为多层感知器 (MLP))子层。
3.1.3. 自注意力机制 (Self-Attention Mechanism)
自注意力机制 (Self-Attention Mechanism) 是 Transformer 模型的核心。它允许模型在处理序列中的某个词元时,能够同时考虑到序列中的所有其他词元,并根据它们的重要性分配不同的权重。这使得模型能够捕捉长距离依赖关系。
自注意力机制的计算通常涉及三个关键向量:查询 (Query, )、键 (Key, ) 和值 (Value, )。对于输入序列中的每个词元,都会生成一个 、一个 和一个 向量。
注意力分数的计算公式如下: 其中:
-
:查询矩阵,维度为 ,其中 是序列长度, 是查询向量的维度。
-
:键矩阵,维度为 。
-
:值矩阵,维度为 ,其中 是值向量的维度。
-
:通过计算查询和键的点积来衡量它们之间的相似度或相关性。
-
:缩放因子,用于防止点积结果过大,导致
softmax函数梯度过小。 -
:将注意力分数归一化为概率分布,确保所有权重之和为 1。
-
:加权求和的值矩阵,表示了不同词元对当前词元的重要性贡献。
自注意力机制的计算复杂度为 ,其中 是序列长度。这意味着随着输入序列长度的增加,计算成本呈二次方增长,这在大规模
LLM和长文本处理中成为一个巨大的瓶颈。
3.1.4. 多层感知器 (MLP)
多层感知器 (Multi-Layer Perceptron, MLP),也称为前馈网络 (Feed-Forward Network),是神经网络中最基本的形式之一。在 Transformer 模型中,每个注意力子层之后通常都会有一个 MLP 子层,用于对注意力层的输出进行非线性变换和特征提取。MLP 通常由两个线性变换层和一个非线性激活函数组成。
3.1.5. 键值缓存 (KV Cache)
在自回归的 LLM 推理过程中(例如,逐词生成文本),为了避免重复计算已经生成词元的 键 (Key) 和 值 (Value) 向量,这些向量会被存储起来,形成键值缓存 (KV Cache)。这样,在生成下一个词元时,只需要计算新词元的 向量,并与缓存中所有先前词元的 和 向量进行注意力计算。虽然 KV Cache 提高了推理速度,但随着序列长度的增加,它会消耗大量的内存,特别是在处理长文本时,这成为了另一个主要的效率瓶颈。
3.1.6. 线性注意力 (Linear Attention)
线性注意力 (Linear Attention) 是对传统自注意力机制的一种改进,旨在将其计算复杂度从 降低到 ,即与序列长度呈线性关系。这通常通过对 和 的交互方式进行修改来实现,例如使用核函数 (kernel functions) 将 和 映射到更高维空间进行点积,或者通过状态空间模型 (State Space Model, SSM) 的方式来建模长距离依赖。常见的线性注意力模型包括 RWKV、RetNet、Mamba、GLA 等。虽然这些方法提高了效率,但有时可能难以完全捕捉传统自注意力机制的复杂表达能力,导致在某些任务上准确率下降。
3.1.7. 神经架构搜索 (NAS)
神经架构搜索 (Neural Architecture Search, NAS) 是一种自动化设计神经网络架构的技术。它的目标是自动发现针对特定任务或硬件最优的神经网络结构,而不是通过人工设计。NAS 通常涉及定义一个搜索空间、一个搜索策略和一个性能评估策略。传统 NAS 在图像识别等领域取得了显著成功,但在 LLM 领域,由于巨大的预训练成本,其应用受到了限制。
3.2. 前人工作
- 高效
LM架构: 大量工作致力于开发具有降低 复杂度的注意力机制,包括RWKV7[10]、RetNet[12]、Mamba2[50]、GLA[11]、Deltanet[51] 和Gated DeltaNet[32] 等。这些模型旨在通过不同的机制(如循环神经网络 (RNN) 范式、状态空间模型 (SSM) 或特定核函数 (kernel functions))来克服 的效率瓶颈。 - 混合模型 (Hybrid Models): 另一些研究则专注于构建结合了全注意力 (full attention) 和线性注意力 (linear attention) 的混合架构,以平衡准确性和效率。这类模型包括
Hymba[44]、Zamba2[16] 等。它们通常尝试在模型的不同层或不同部分使用不同的注意力机制。 NAS在LLM领域的应用:NAS[45, 97, 98, 99, 100] 是一种强大的架构探索技术。然而,由于LLM预训练成本极高,NAS很少被直接应用于LLM架构设计。近期的一些努力主要集中在构建灵活的LLM架构 (flexible LLM architectures) [103, 104],它们可以生成不同深度和宽度的子网络以适应不同硬件平台,但这些子网络的基础架构仍然完全依赖于全注意力层。- 线性化
LLM: 还有一些工作致力于用线性注意力替代全注意力来线性化LLM(linearizing LLMs) [89, 90, 91, 92, 93, 94, 95, 96]。然而,这些模型的架构通常优化不足,因为评估特定配置的开销巨大,导致其结果仍逊于SOTA全注意力模型。
3.3. 技术演进
LLM 领域的技术演进可以概括为从注重模型能力(主要由 Transformer 全注意力驱动)到兼顾能力与效率的转变:
- Transformer 模型的崛起 (O(n^2) 全注意力):
Transformer及其核心的自注意力机制证明了其在捕捉长距离依赖方面的强大能力,成为LLM的基石。 - 效率瓶颈的出现: 随着模型规模和上下文长度的增长, 的计算复杂度导致了严重的效率和内存问题。
- 线性注意力模型的探索 (O(n) 复杂度): 研究者开始探索各种替代方案,如
RNN变体、SSM和基于核函数的线性注意力,旨在将复杂度降低到 ,以提高效率。然而,这些模型通常在准确率上无法完全匹敌全注意力模型。 - 混合架构的出现: 为了弥补线性注意力在准确率上的不足,同时利用其效率优势,混合架构模型应运而生,尝试结合两种机制的优点。
NAS的引入:NAS最初在其他领域取得成功,但因LLM预训练成本高昂,其在LLM架构设计中的应用受限。PostNAS的创新: 本文的PostNAS正是在这一背景下提出的,它通过利用预训练模型并冻结MLP权重,大幅降低了架构探索的成本,从而为LLM架构的创新提供了一条新的、更可行的路径,使得高效且高精度的混合架构模型(如Jet-Nemotron)成为可能。
3.4. 差异化分析
Jet-Nemotron 及其 PostNAS 方法与现有工作的主要区别和创新点体现在以下几个方面:
-
与传统
NAS的区别:- 起点不同: 传统
NAS通常从头开始搜索和训练模型,这对于LLM而言预训练成本极高且风险大。PostNAS则从一个预训练的全注意力模型开始。 - 探索效率:
PostNAS冻结了预训练模型的多层感知器 (MLP) 权重,只对注意力块 (attention block) 进行搜索和优化。这一策略显著减少了训练成本和数据需求,使得架构探索变得高效可行。 - 风险降低: 由于利用了已验证的预训练模型能力,
PostNAS降低了探索全新架构可能失败的风险。
- 起点不同: 传统
-
与现有高效 模型(如
RWKV7,Mamba2等)的区别:- 准确率: 现有纯粹的 线性注意力模型通常在复杂任务(如
MMLU(-Pro)、数学推理)上,准确率显著低于SOTA全注意力模型。Jet-Nemotron通过混合架构 (hybrid-architecture) 和PostNAS的精细优化,能够匹配甚至超越这些全注意力模型的准确率。 - 架构优化深度:
Jet-Nemotron的架构不是简单地替换注意力块,而是通过PostNAS流程,系统地进行全注意力层放置优化 (full-attention layer placement optimization)、线性注意力块选择 (linear attention block selection) 和硬件感知搜索 (hardware-aware search),确保了整体架构的性能和效率最优。
- 准确率: 现有纯粹的 线性注意力模型通常在复杂任务(如
-
与现有混合模型(如
Hymba,Zamba2等)的区别:-
优化方法: 尽管同为混合架构,但
Jet-Nemotron的混合策略是通过PostNAS自动化搜索而非人工经验决定的。这使得其混合架构更具针对性和优化性。 -
性能提升:
Jet-Nemotron在准确率和效率上的提升幅度更大,尤其是在长上下文和吞吐量方面,显著优于现有的混合模型。例如,Jet-Nemotron在层级上混合注意力机制,而同期工作Falcon-H1则采用头级混合,本文指出层级混合在并行度上更有优势,从而带来更高的吞吐量。 -
新注意力块:
Jet-Nemotron引入了创新的JetBlock,结合了动态卷积 (dynamic convolution),进一步提升了线性注意力部分的表达能力,这是现有混合模型所不具备的。总结来说,
Jet-Nemotron的核心创新在于提供了一种高效、系统且硬件感知的架构探索流水线PostNAS,使得研究者能够在继承预训练LLM强大能力的基础上,设计出既高效又高准确率的混合架构模型,有效填补了现有高效LLM在准确率上的空白。
-
4. 方法论
4.1. 方法原理
PostNAS 的核心思想是,鉴于从头开始预训练 LLM 的高昂成本,我们可以通过利用现有的预训练全注意力模型 (leveraging existing pre-trained full-attention models) 来大幅降低架构探索的门槛。其直觉在于,这些预训练模型中的多层感知器 (MLP) 已经学习到了丰富的语义和知识表示,可以被冻结 (frozen) 并重用。因此,架构探索的重点可以放在注意力块的设计 (attention block design) 上,即如何高效地替换或调整 Transformer 中的注意力层,以提高效率而不牺牲准确性。
PostNAS 的方法原理基于以下几个关键观察:
-
MLP 知识复用:
MLP承载了大量的模型知识和能力。冻结MLP权重使得我们无需重新训练整个模型,只需关注注意力机制的调整,从而极大地加速了探索过程并降低了计算资源需求。 -
注意力层稀疏性: 并非所有注意力层都对所有任务同等重要。通过智能地放置全注意力层 (full-attention layers) 或高效的线性注意力层 (linear attention layers),可以在保持关键性能的同时,最大化效率。
-
硬件效率考量: 模型效率不仅取决于理论计算复杂度,还严重依赖于实际硬件上的运行表现。因此,将硬件感知 (hardware-awareness) 融入架构搜索至关重要,尤其是关注
KV Cache大小对吞吐量的影响。PostNAS的路线图(如原文Figure 2所示)清晰地展示了这一原理的实施:从一个预训练的全注意力模型 (full-attention model) 开始,通过粗粒度到细粒度 (coarse-to-fine) 的搜索策略,逐步优化注意力块 (attention block) 的设计。
以下是原文 的图像:
该图像是图2示意图,展示了PostNAS的整体流程,包括基于预训练全注意力模型冻结MLP权重,进行全注意力层位置优化、线性注意力块选择、新注意力块设计以及硬件感知超参数搜索四个阶段,体现了Jet-Nemotron的高效设计路径。
4.2. 核心方法详解
PostNAS 流水线包含四个关键步骤,系统地引导我们进行高效的注意力块设计。
4.2.1. 全注意力层放置和消除 (Full Attention Placement and Elimination)
尽管线性注意力能够提高效率,但保留少量的全注意力层 (full-attention layers) 对于在检索等挑战性任务上维持高准确率至关重要 [30]。然而,这些层的最佳放置位置并不明确。传统的做法是均匀放置全注意力层,但这被认为是次优的。
PostNAS 提出了一种自动确定全注意力层位置的方法。其核心思想是构建一个一次性超网络 (once-for-all supernetwork)。
以下是原文 的图像:
该图像是图4示意图,展示了利用PostNAS训练一次性超网络并通过Beam Search搜索最优全注意力层位置的过程。
具体步骤:
- 构建超网络 (Supernetwork Construction): 通过在预训练的全注意力模型 (full-attention model) 中添加替代的线性注意力路径 (linear attention paths) 来构建一次性超网络 (once-for-all supernetwork) [45, 31]。这意味着模型中的每个注意力层都可以是全注意力或线性注意力。
- 超网络训练 (Supernetwork Training): 在训练过程中,模型会随机采样 (randomly sample) 激活路径(即,每个层选择全注意力或线性注意力),形成一个子网络。这个子网络使用蒸馏损失 (distillation loss) [46, 47, 48] 进行训练。
- 蒸馏损失 (Distillation Loss):这是一种训练技巧,用于将一个大型、高性能的教师模型(在这里,可以是原始的、完全的全注意力模型,或者是在训练过程中随机激活的全注意力子网络)的知识转移到一个较小、更高效的学生模型(在这里,是包含线性注意力路径的子网络)中。其基本思想是让学生模型的输出(例如,概率分布或特征表示)尽可能地接近教师模型的输出。
- 一次性超网络 (Once-for-All Supernetwork):指一个包含所有可能子网络的大型网络。通过一次性训练这个超网络,所有子网络的权重都得到了部分训练,从而避免了为每个子网络单独训练的巨大成本。
- 波束搜索 (Beam Search): 在超网络训练完成后,使用波束搜索 (beam search) [49] 算法来确定在给定约束(例如,只允许使用两个全注意力层)下,全注意力层的最优放置。
- 波束搜索 (Beam Search):这是一种启发式搜索算法,用于在大型搜索空间中找到近似最优解。它通过在每一步保留 个最佳候选( 为波束宽度),并从这些候选扩展出下一组候选,从而避免指数级的搜索空间。在这里,它用于探索不同全注意力层组合。
- 搜索目标 (Search Objective): 搜索目标是任务依赖 (task-dependent) 的。
-
对于
MMLU任务,目标是选择在正确答案上具有最低损失(即最大化_loss)的配置。 -
对于数学和检索任务,目标是选择准确率最高的配置。
关键发现:
-
-
非均匀贡献: 预训练的全注意力模型中,并非所有注意力层都同等重要。例如,对于
MMLU,只有两层至关重要;对于检索任务,两到三层是关键。 -
任务特定性: 不同的注意力层对不同的能力有贡献。对
MMLU重要的层不一定对检索任务重要。 -
复杂任务模式: 对于数学推理等复杂任务,注意力重要性模式更复杂。但幸运的是,
MMLU和检索任务识别出的关键层组合已经包含了数学任务所需的大部分关键层。 -
优于均匀放置:
PostNAS显著优于均匀放置策略,如Figure 5(b)所示。以下是原文 的图像:
该图像是图表,展示了Figure 5中Jet-Nemotron模型在Qwen2.5-1.5B上的层次放置搜索结果和PostNAS与均匀放置策略的性能比较。(a)部分展示了不同任务中注意力层的重要性热力图,(b)部分对比了不同全注意力层数量下的MMLU准确率,显示PostNAS明显优于Uniform方法。
4.2.2. 线性注意力块选择 (Linear Attention Block Selection)
在确定了全注意力层 (full-attention layer) 的放置后,下一步是选择最适合的线性注意力块 (linear attention block) 来填充剩余的层。得益于 PostNAS 框架的低训练成本,可以系统地评估现有的线性注意力块。
评估过程:
-
候选块: 评估了六种
SOTA线性注意力块,包括RWKV7[10]、RetNet[12]、Mamba2[50]、GLA[11]、Deltanet[51] 和Gated DeltaNet[32]。 -
效率分析: 首先进行效率分析,发现
RWKV7的训练吞吐量显著低于其他块,可能由于次优的核实现 (suboptimal kernel implementation),因此在训练实验中被排除。 -
性能评估: 在多项任务上评估剩余线性注意力块的准确率、训练效率和推理速度。
结果与选择:
Gated DeltaNet在评估的线性注意力块中实现了最佳的整体准确率。- 这归因于其结合了两个关键机制:
- 数据依赖门控机制 (Data-Dependent Gating Mechanism) [52]:动态控制模型是更关注当前词元还是历史状态。
- Delta 规则 (Delta Rule) [53]:通过当前词元的信息增量更新历史状态,以有效利用有限的状态内存。
- 因此,
Gated DeltaNet被选为后续实验的基础。
4.2.3. 新注意力块设计 (New Attention Block Design - JetBlock)
虽然 Gated DeltaNet 表现良好,但 PostNAS 框架进一步支持新注意力块 (new attention block) 的快速设计。本文提出了 JetBlock,旨在通过将动态卷积 (dynamic convolution) [54, 55] 融入线性注意力 (linear attention) 来增强模型的表达能力。
问题背景: 卷积已被证明对许多线性注意力块的准确率至关重要 [32, 56]。然而,以前的方法通常依赖静态卷积核 (static convolution kernels),这些核无法根据输入动态调整其特征提取模式 (feature extraction patterns)。
JetBlock 设计:
-
核生成器模块 (Kernel Generator Module):
JetBlock引入了一个核生成器模块 (kernel generator module),它根据输入特征动态生成卷积核 (convolution kernels)。- 输入: 该模块与 投影层共享相同的输入。
- 结构: 从一个线性降维层开始,使用
8倍的降维比以提高效率。之后应用GeLU[57] 激活函数 (activation function),最后是一个线性层,输出卷积核的权重。
-
动态卷积应用: 动态生成的卷积核被应用于值 (Value, V) 词元。研究发现,将它们应用于查询 (Query, Q) 或键 (Key, K) 词元几乎没有益处。
-
简化计算: 此外,研究发现,一旦动态卷积应用于 , 和 上的静态卷积 (static convolutions) 就可以被移除,对最终模型准确率的影响可以忽略不计。这一设计进一步简化了计算,略微提高了效率。
-
时间混合 (Time-Mixing):
JetBlock沿用了Gated DeltaNet的时间混合 (time-mixing) 机制,因为Gated DeltaNet在线性注意力块选择中表现最佳。JetBlock相比于之前的线性注意力块,在数学推理和检索任务上提供了更好的准确率,同时保持了相似的效率(如原文Table 1所示)。
4.2.4. 硬件感知架构搜索 (Hardware-Aware Architecture Search)
在确定了宏观架构 (macro architecture)(即全注意力层 (full-attention layers) 的放置)并选择了线性注意力块 (linear attention block) 之后,PostNAS 的最后一步是进行硬件感知架构搜索 (hardware-aware architecture search),以优化核心架构超参数,包括键/值维度 (key/value dimension) 和注意力头数 (number of attention heads)。
传统方法的问题: 传统上,参数量 (parameter size) 被用作指导模型架构设计的主要效率指标。然而,这种方法是次优的,因为参数量与实际硬件上的生成效率 (generation efficiency) 并不直接相关。
硬件感知搜索的改进: 本文通过直接使用生成吞吐量 (generation throughput) 作为选择架构超参数的目标来解决这一局限性。
关键发现 (Key Finding 4):
KV 缓存大小 (KV cache size) 是影响长上下文和长生成吞吐量的最关键因素。当 KV 缓存大小恒定时,具有不同参数量的模型表现出相似的生成吞吐量 (generation throughput)(如原文 Table 2 所示)。
原因: 解码阶段通常是内存带宽受限 (memory-bandwidth-bound),而不是计算受限 (compute-bound)。在长上下文场景中,KV 缓存消耗的内存往往超过模型权重。减小 KV 缓存大小可以减少每个解码步骤的内存传输时间,并允许更大的批次大小 (batch size),从而提高生成吞吐量 (generation throughput)。
优化策略:
-
固定
KV缓存大小: 根据Key Finding 4,将KV缓存大小固定为与原始设计匹配。 -
网格搜索 (Grid Search): 在固定的
KV缓存大小约束下,对键维度 (key dimension)、值维度 (value dimension) 和注意力头数 (number of attention heads) 进行小规模的网格搜索。结果:
- 通过这种硬件感知搜索,发现了能够提供与原始设计相似的生成吞吐量 (generation throughput) 的架构超参数组合。
- 这些新的配置通常能够容纳更多的参数,从而在不牺牲效率的前提下实现更好的准确率。
- 如原文
Table 1所示,这一步骤进一步提升了JetBlock的准确率,同时保持了训练和推理吞吐量。
5. 实验设置
5.1. 数据集
Jet-Nemotron 模型的训练使用了多阶段的数据混合策略,并在全面的基准测试套件上进行评估。
5.1.1. 训练语料
训练分为两个阶段:
- 第一阶段 (Stage 1):
- 目的: 用于
PostNAS过程,主要进行蒸馏训练 (distillation training)。 - 数据: 结合了
Nemotron-CC[63] 和Redstone-QA[64] 作为预训练语料。 - 规模: 模型在此阶段训练了 50B 词元。这也是
PostNAS过程(第2节讨论)进行时的设置。
- 目的: 用于
- 第二阶段 (Stage 2):
- 目的: 进行全模型训练 (full-model training)。
- 数据: 在第一阶段数据的基础上,加入了更多高质量的数学 [65] 和编码 [66, 67] 领域数据,以增强模型在这些专业领域的表现。
- 规模: 模型在此阶段训练了 350B 词元。
5.1.2. 评估基准
Jet-Nemotron 在以下六个主流基准类别中进行了全面评估:
- 大规模多任务语言理解 (MMLU(-Pro)) [18, 19]:
- 特点: 衡量模型在广泛领域(科学、人文、社会科学等)的知识和推理能力。
- MMLU-Pro 是
MMLU的更鲁棒和具挑战性的版本。
- 数学推理 (Mathematical Reasoning) [18, 20, 21, 22, 39]:
- 任务:
GSM8K(小学数学问题),MATH(高中数学竞赛题),MathQA(数学问答),GPQA(研究生水平的Google-Proof问答)。 - 特点: 评估模型理解和解决数学问题的能力。
- 任务:
- 常识推理 (Commonsense Reasoning) [33, 34, 35, 36, 37, 38]:
- 任务:
ARC-c(挑战性问答),ARC-e(简单问答),PIQA(物理常识),Wino.(指代消解),OBQA(开放域问答),BoolQ(布尔问答),TruthQA(事实性问答)。 - 特点: 评估模型在日常世界知识和常识方面的理解。
- 任务:
- 检索 (Retrieval) [23, 24, 25]:
- 任务:
FDA,SWDE,Squad。 - 特点: 评估模型从给定文本中提取相关信息的能力。
- 任务:
- 编码 (Coding) [26, 27, 28, 40]:
- 任务:
EvalPlus(代码生成评估),CRUXEval(代码推理、理解和执行)。 - 特点: 评估模型生成、理解和执行代码的能力。
- 任务:
- 长文本任务 (Long-Context Tasks) [29]:
- 任务:
LongBench(一个包含多种长文本理解任务的基准,上下文长度可达 64K)。 - 特点: 评估模型处理和理解长篇幅文本信息的能力,这是传统 Transformer 模型的痛点。
- 任务:
5.1.3. 评估策略
- 少样本评估 (Few-shot evaluation):
GSM8K和MATH采用 4-shot 评估。GPQA和MMLU-Pro采用 5-shot 评估。
- 零样本评估 (Zero-shot evaluation): 除上述任务外,所有其他任务均采用零样本设置。
- 编码任务: 使用
EvalPlus[40] 和CRUXEval[28] 的官方实现进行评估。 - 所有其他任务: 评估基于
LM-Evaluation-Harness[68] 进行。
5.2. 评估指标
论文使用了多种评估指标来衡量模型的性能和效率。
5.2.1. 准确率 (Accuracy)
- 概念定义: 准确率是衡量模型在分类、问答、推理等任务中预测正确程度的常用指标。它表示模型正确预测的样本数量占总样本数量的比例。较高的准确率意味着模型在该任务上表现越好。
- 数学公式:
- 符号解释:
Number of Correct Predictions: 模型做出的正确预测的数量。Total Number of Predictions: 模型做出的总预测数量(即评估数据集中的样本总数)。
5.2.2. 生成吞吐量 (Generation Throughput)
- 概念定义: 生成吞吐量衡量模型在推理阶段的效率,具体指模型每秒能够生成的词元(tokens)数量。在
LLM部署中,这是一个关键指标,直接影响用户体验和运营成本。 - 数学公式:
- 符号解释:
Number of Generated Tokens: 模型在一定时间内生成的词元总数。Time Taken for Generation (seconds): 生成这些词元所花费的时间(以秒为单位)。
5.2.3. 预填充加速 (Prefilling Speedup)
- 概念定义: 预填充加速是指在处理用户输入的提示(prompt)阶段,目标模型相对于基线模型的速度提升倍数。在生成式
LLM中,预填充是模型接收输入并准备生成输出的第一个阶段。 - 数学公式: 或等价地
- 符号解释:
Baseline Model Prefilling Throughput: 基线模型在预填充阶段的吞吐量。Jet-Nemotron Prefilling Throughput:Jet-Nemotron模型在预填充阶段的吞吐量。Jet-Nemotron Prefilling Time per Token:Jet-Nemotron模型预填充每个词元所需的时间。Baseline Model Prefilling Time per Token: 基线模型预填充每个词元所需的时间。
5.2.4. 解码加速 (Decoding Speedup)
- 概念定义: 解码加速是指在模型逐词生成输出(即解码)阶段,目标模型相对于基线模型的速度提升倍数。这是自回归生成中的核心效率指标。
- 数学公式: 或等价地
- 符号解释:
Baseline Model Decoding Throughput: 基线模型在解码阶段的吞吐量。Jet-Nemotron Decoding Throughput:Jet-Nemotron模型在解码阶段的吞吐量。Jet-Nemotron Decoding Time per Token:Jet-Nemotron模型解码每个词元所需的时间。Baseline Model Decoding Time per Token: 基线模型解码每个词元所需的时间。
5.3. 对比基线
论文将 Jet-Nemotron 与当前领域中代表性的全注意力模型 (full-attention models)、线性注意力模型 (linear attention models) 和混合模型 (hybrid models) 进行了比较,以全面评估其性能和效率。
-
全注意力模型 ():
Qwen2.5-1.5B[4]Qwen3-1.7B-Base[5]Llama3.2-3B[2]MiniCPM-2B-128K[58]MobileLLM-1.5B[59]Smollm2-1.7B[60]DeepSeek-V3-Small@1.3T[6] (混合专家模型,激活参数 2.2B / 总参数 15B)Moonlight@1.2T[61] (混合专家模型,激活参数 2.2B / 总参数 15B)Mamba2-2.7B[50] (虽然Mamba2是一种状态空间模型,在某些比较表格中被归类为 ,但在Table 3中被列为 类别下的基线,这可能是为了将其与纯粹的线性注意力模型区分开来,强调其在某些方面的能力可与全注意力模型匹敌,但并非本文所指的混合架构)。
-
线性注意力模型 ():
RWKV7-1.5B[10]Rec.Gemma-2B[62]Gemma3n-E2B[42] (此模型在某些场合也被视为混合模型,但在此处被归类到 类别下进行比较)
-
混合模型 (Hybrid):
Hymba-1.5B[44]Zamba2-1.2B[16]
5.4. 硬件配置与软件环境
为了确保公平和一致的比较,论文详细说明了实验所用的硬件和软件环境。
-
硬件:
- 服务器:
DGX H100服务器。 - GPU: 8块
NVIDIA H100 GPU。 - CPU: 2个
Intel Xeon Platinum 8480C(112 核)。 - 内存: 2TB
RAM。 - 单个 GPU 测试: 除特别说明外,所有模型都在单个
H100 GPU上进行测试。
- 服务器:
-
软件:
- 深度学习框架:
Pytorch 2.7.0。 - 高性能计算库:
Triton 3.3.0。 - 注意力机制实现:
- 全注意力块 (full-attention block): 使用
FlashAttention 2.7.4[69] 进行实现,这是一种优化后的自注意力实现。 - 线性注意力块 (linear-attention blocks): 使用
Flash-Linear-Attention 0.2.1[70] 进行实现,这是一种优化后的线性注意力实现。
- 全注意力块 (full-attention block): 使用
- 模型推理: 基于
Transformers 4.52.0[71] 的实现。
- 深度学习框架:
-
吞吐量测试细节:
- 上下文长度 (Context length): 默认 64K 词元,除非明确指出。
- 缓存大小 (Cache sizes):
Table 3中报告了 64K 输入上下文下的缓存大小。 - 分块预填充 (Chunk-prefilling): 在测试吞吐量时,采用了
分块预填充 (chunk-prefilling)[72] 技术。 - 批次大小优化: 通过搜索分块大小 (chunk size),在
GPU内存约束下最大化解码批次大小 (decoding batch size),以测量设备上可实现的最高解码吞吐量。Table 13列出了每个模型使用的优化批次大小和对应的分块大小。
5.5. Jet-Nemotron 模型家族配置
Jet-Nemotron 家族包含两个主要版本:Jet-Nemotron-2B 和 Jet-Nemotron-4B,它们具有不同的参数规模和架构配置。模型的每个块都包含一个多层感知器 (MLP) 层和一个注意力层 (attention layer),注意力层可以是全注意力 (full attention)、滑动窗口注意力 (sliding window attention, SWA) 或 JetBlock。
以下是原文 的表格:
| Jet-Nemotron-2B | Jet-Nemotron-4B | |
| Total blocks | 28 | 36 |
| Full Attention Layers | No. 15, 20 | No. 18, 21, 22, 28, 33 |
| Sliding Window Attention Layers | No. 21, 22 | No. 17, 20, 23, 24, 26 |
| Vocabulary Size | 151,643 | 151,643 |
| Hidden Size | 1,536 | 2,048 |
| MLP Intermediate Size | 8,960 | 11,008 |
5.5.1. Jet-Nemotron-2B
- 总块数: 28 个。
- 基座模型: 基于
Qwen2.5-1.5B。 - 全注意力层: No. 15 和 20 层,共 2 层。这些层主要用于指导检索任务 (retrieval tasks) 的性能。
- 滑动窗口注意力 (SWA) 层: No. 21 和 22 层,共 2 层。
SWA主要用于多项选择任务 (multiple-choice tasks),如MMLU,因为它能有效保留这些任务的准确率。 - 其他注意力层: 替换为
JetBlock。 - 词汇表大小 (Vocabulary Size): 151,643。
- 隐藏层大小 (Hidden Size): 1,536。
- MLP 中间层大小 (MLP Intermediate Size): 8,960。
5.5.2. Jet-Nemotron-4B
- 总块数: 36 个。
- 基座模型: 基于
Qwen2.5-3B。 - 全注意力层: No. 18, 21, 22, 28, 33 层,共 5 层。
- 滑动窗口注意力 (SWA) 层: No. 6, 17, 20, 22, 23, 26, 28 层,共 7 层。
- 其他注意力层: 替换为
JetBlock。 - 词汇表大小 (Vocabulary Size): 151,643。
- 隐藏层大小 (Hidden Size): 2,048。
- MLP 中间层大小 (MLP Intermediate Size): 11,008。
5.5.3. 全注意力/SWA 层配置
-
全注意力层和
SWA层均采用分组查询注意力 (grouped-query attention) [105] 机制。 -
位置编码 (Position Embedding): 均使用旋转位置嵌入 (Rotary Position Embedding, RoPE)。
以下是原文 的表格:
Full Attention / SWA Jet-Nemotron-2B Jet-Nemotron-4B Attention Head Number 12 16 Dimensions of Q/K/V 128 128 K/V Head Number 2 2 Position Embedding RoPE RoPE
5.5.4. JetBlock 配置
-
JetBlock的具体配置,包括Q/K维度、 维度、头数、卷积核大小和动态卷积生成器隐藏层大小。以下是原文 的表格:
JetBlock Jet-Nemotron-2B Jet-Nemotron-4B Q/K Dimension 96 128 V Dimension 256 256 Head Number 12 16 Convolution Kernel Size 4 4 DConv Generator Hidden Size 32 32
5.6. 实验成本
以下是原文 的表格,总结了 PostNAS 和训练 Jet-Nemotron-2B 模型所花费的成本。
| Tokens (B) | ZFLOPs | Time (H100 GPU Hours) | ||
| PostNAS | Full Attention Placement and Elimination | 50 | 0.8 | 808 |
| Linear Attention Block Selection | 50 | 4.0 | 3120 | |
| New Attention Block Design | 50 | 0.8 | 624 | |
| Harware-Aware Arch Search | 50 | 7.2 | 5616 | |
| Training | Stage1 | 50 | 0.8 | 624 |
| Stage2 | 350 | 5.6 | 7536 |
实验使用了 32 块 H100 GPU 并行计算,报告的 GPU 小时数已计入设备总数。可以看出,PostNAS 各阶段的探索成本相对较低,总计约为 10168 H100 GPU 小时。后续的训练阶段(Stage 1 和 Stage 2)则占用了更多的资源,总计 8160 H100 GPU 小时。
6. 实验结果与分析
本节详细介绍了 Jet-Nemotron 模型家族在各种基准测试中的实验结果,并对其性能和效率进行了深入分析。
6.1. 核心结果分析
6.1.1. MMLU(-Pro) 和 BBH 结果
以下是原文 的表格,展示了 Jet-Nemotron 与最先进的 (state-of-the-art) 高效语言模型在 MMLU、MMLU-Pro 和 BBH 上的表现。
| Type | Model | Params (B) | Cache Size (MB) | Throughput (token/s) ↑ | MMLU Acc. ↑ | MMLU-Pro Acc. ↑ | BBH Acc. ↑ |
| O(n2) | Qwen2.5-1.5B [4] | 1.5 | 1,792 | 241 | 59.5 | 28.9 | 44.1 |
| Qwen3-1.7B-Base [5] | 1.7 | 7,168 | 61 | 60.3 | 37.8 | 54.2 | |
| Llama3.2-3B [2] | 3.0 | 7,168 | 60 | 54.9 | 25.0 | 47.1 | |
| MiniCPM-2B-128K [58] | 2.8 | 23,040 | 18 | 46.0 | 18.0 | 36.5 | |
| MobileLLM-1.5B [59] | 1.5 | 4,320 | 101 | 26.0 | 9.4 | 27.2 | |
| Smollm2-1.7B [60] | 1.7 | 12,288 | 32 | 48.5 | 18.3 | 35.1 | |
| DeepSeek-V3-Small@1.3T [6] | 2.2/15 | - | , | 53.3 | - | - | |
| Moonlight@1.2T [61] | 2.2/15 | - | - | 60.4 | 28.1 | 43.2 | |
| Mamba2-2.7B [50] | 2.7 | 80 | 2,507 | 25.1 | 8.6 | 25.7 | |
| O(n) | RWKV7-1.5B [10] | 1.5 | 24 | 3,050 | 41.0 | 13.4 | 15.9 |
| Rec.Gemma-2B [62] | 2.0 | 16 | 2,355 | 28.6 | 12.8 | 33.3 | |
| Gemma3n-E2B [42] | 2.0 | 768 | 701 | 53.9 | 24.3 | ||
| Hymba-1.5B [44] | 1.5 | 240 | 180 | 49.7 | 17.4 | 45.1 29.8 | |
| Zamba2-1.2B [16] | 1.2 | 6,114 | 71 | 43.1 | 14.2 | 19.6 | |
| Hybrid | Jet-Nemotron-2B | 2.0 | 154 | 2,885 | 60.8 | 39.0 | 58.3 |
| Jet-Nemotron-4B | 4.0 | 258 | 1,271 | 65.2 | 44.2 | 65.0 |
Jet-Nemotron-2B在MMLU(60.8)、MMLU-Pro(39.0) 和BBH(58.3) 上的准确率,均高于Qwen3-1.7B-Base(MMLU: 60.3,MMLU-Pro: 37.8,BBH: 54.2)。- 同时,
Jet-Nemotron-2B的生成吞吐量 (generation throughput) 达到 2,885token/s,是Qwen3-1.7B-Base(61token/s) 的约 47 倍。其KV缓存大小也显著更小 (154MB vs 7,168MB)。 Jet-Nemotron-2B甚至优于一些MoE模型,如DeepSeek-V3-Small和Moonlight,尽管这些模型有更大的激活参数 (activated parameters) (2.2B) 和总参数 (total parameters) (15B)。- 当模型规模扩大到
Jet-Nemotron-4B时,其准确率进一步提升,同时仍保持了相对于Qwen3-1.7B-Base约21倍的吞吐量优势 (1,271token/svs 61token/s)。 - 与其他的线性注意力模型和混合模型相比,
Jet-Nemotron也取得了显著更高的准确率。
6.1.2. 数学任务结果
以下是原文 的表格,展示了模型在数学任务上的性能。
| Type Model | | Throughput| | (token/s) ↑ | | Accuracy ↑ | ||||||
| Avg. | |GSM8K MATH MathQA MMLU-Stem | GPQA | ||||||
| O(n2) | Qwen2.5-1.5B [4] | 241 | 38.4 | 62.4 | 13.1 | 34.4 | 52.7 | 29.4 |
| Qwen3-1.7B-Base [5] | 61 | 42.3 | 62.8 | 16.7 | 46.0 | 50.8 | 27.9 | |
| Llama3.2-3B [2] | 60 | 28.8 | 25.8 | 8.6 | 34.2 | 45.3 | 30.1 | |
| MiniCPM-2B-128K [58] | 18 | 27.6 | 39.2 | 5.9 | 28.5 | 36.3 | 28.1 | |
| Smollm2-1.7B [60] | 32 | 28.9 | 30.3 | 9.2 | 33.7 | 41.3 | 30.1 | |
| O(n) | Mamba2-2.7B [50] | 2,507 | 16.6 | 3.0 | 3.9 | 24.3 | 26.6 | 25.3 |
| RWKV7-1.5B [10] | 2,669 | 18.3 | 5.6 | 0.8 | 27.2 | 34.9 | 23.0 | |
| Rec.Gemma-2B [62] | 2,355 | 20.8 | 13.9 | 7.6 | 25.3 | 28.5 | 28.6 | |
| Gemma3n-E2B [42] | 701 | 28.3 | 24.9 | 10.1 | 31.1 | 45.7 | 31.8 | |
| Hymba-1.5B [44] | 180 | 23.1 | 17.9 | 0.8 | 28.0 | 40.9 | 27.9 | |
| Hybrid Zamba2-1.2B [16] | 71 | 24.8 | 28.1 | 5.9 | 26.0 | 36.5 | 27.7 | |
| Jet-Nemotron-2B | 2,885 | 49.6 | 76.2 | 23.3 | 53.8 | 62.7 | 32.1 | |
| Jet-Nemotron-4B | 1,271 | 51.3 | 78.7 | 25.2 | 52.5 | 65.6 | 34.6 | |
Jet-Nemotron-2B在数学任务上取得了 49.6 的平均准确率,显著高于Qwen3-1.7B-Base(42.3),同时速度快47倍。- 在具体子任务上,
Jet-Nemotron-2B在GSM8K和MathQA上表现尤其出色,远超所有基线。 - 这表明
Jet-Nemotron成功地解决了现有线性注意力模型和混合模型在数学任务上准确率不足的问题。
6.1.3. 常识推理任务结果
以下是原文 的表格,总结了模型在常识推理任务上的表现。
| Model | Throughput | Accuracy ↑ | |||||||
| |(token/s) ↑ | | | Avg. | | ARC-c ARC-e PIQA Wino. | OBQA BoolQ TruthQA | ||||||
| Qwen2.5-1.5B [4] | 241 | 59.4 | 45.4 | 71.2 | 75.8 | 63.8 | 40.2 | 72.8 | 46.6 |
| Qwen3-1.7B-Base [5] | 61 | 60.0 | 44.9 | 68.6 | 75.5 | 63.8 | 39.0 | 79.0 | 48.8 |
| Llama3.2-3B [2] | 60 | 59.9 | 46.6 | 72.0 | 78.0 | 69.3 | 40.4 | 73.9 | 39.3 |
| MiniCPM-2B-128K [58] | 18 | 57.6 | 41.0 | 69.4 | 75.5 | 63.8 | 40.6 | 74.7 | 38.3 |
| Smollm2-1.7B [60] | 32 | 59.7 | 47.0 | 73.3 | 77.7 | 66.2 | 44.6 | 72.5 | 36.7 |
| Mamba2-2.7B [50] | 2,507 | 57.2 | 42.1 | 70.5 | 76.1 | 62.7 | 41.4 | 71.5 | 36.1 |
| RWKV7-1.5B [10] | 3,050 | 59.7 | 46.3 | 75.7 | 77.4 | 67.6 | 45.4 | 70.5 | 34.7 |
| Rec.Gemma-2B [62] | 2,355 | 46.5 | 29.4 | 41.5 | 66.6 | 54.1 | 27.0 | 72.0 | 34.7 |
| Gemma3n-E2B [42] | 701 | 58.6 | 43.2 | 73.1 | 77.0 | 60.8 | 40.8 | 76.0 | 39.1 |
| Hymba-1.5B [44] | 180 | 61.2 | 46.9 | 76.9 | 77.7 | 66.2 | 41.0 | 80.8 | 39.0 |
| Zamba2-1.2B [16] | 71 | 58.0 | 44.4 | 66.8 | 77.4 | 65.6 | 42.8 | 70.8 | 38.5 |
| Jet-Nemotron-2B | 2,885 | 62.0 | 48.6 | 74.8 | 75.4 | 65.8 | 40.6 | 81.2 | 47.8 |
| Jet-Nemotron-4B | 1,271 | 64.7 | 51.7 | 79.2 | 78.1 | 70.5 | 43.6 | 83.0 | 46.6 |
Jet-Nemotron-2B取得了 62.0 的平均准确率,超越了所有基线模型,包括Qwen2.5和Qwen3,尽管Qwen系列在该领域相对较弱。Jet-Nemotron-4B进一步提升至 64.7 的平均准确率,在ARC-c、ARC-e、PIQA和BoolQ等任务上表现突出。
6.1.4. 检索任务结果
以下是原文 的表格,展示了模型在检索任务上的性能。
| Type | Model | Throughput (token/s) ↑ | Accuracy ↑ | |||
| Avg. | FDA | SWDE | Squad | |||
| O(n2) | Qwen2.5-1.5B [4] | 241 | 72.4 | 82.8 | 86.3 | 48.1 |
| Qwen3-1.7B-Base [5] | 61 | 76.1 | 81.8 | 89.2 | 57.2 | |
| Llama3.2-3B [2] | 60 | 71.3 | 82.3 | 89.6 | 56.4 | |
| MiniCPM-2B-128K [58] | 18 | 72.6 | 72.3 | 86.4 | 59.1 | |
| Smollm2-1.7B [60] | 32 | 68.9 | 78.1 | 82.4 | 46.3 | |
| O(n) | Mamba2-2.7B [50] | 2,507 | 57.0 | 51.7 | 74.3 | 45.1 |
| RWKV7-1.5B [10] | 3,050 | 58.6 | 54.5 | 73.3 | 48.0 | |
| Rec.Gemma-2.6B [62] | 2,355 | 68.8 | 62.3 | 86.4 | 57.8 | |
| Hybrid | Gemma3n-E2B [73] | 701 | 74.0 | 77.3 | 86.4 | 58.2 |
| Hymba-1.5B [44] | 180 | 57.1 | 46.6 | 74.4 | 50.2 | |
| Zamba2-1.2B [16] | 71 | 66.4 | 73.8 | 80.7 | 44.8 | |
| Jet-Nemotron-2B | 2,885 | 74.2 | 80.4 | 85.7 | 56.6 | |
| Jet-Nemotron-4B | 1,271 | 76.2 | 82.5 | 89.7 | 56.4 | |
Jet-Nemotron-2B取得了 74.2 的平均准确率,优于所有基线模型,除了Qwen3-1.7B-Base(76.1)。- 当扩展到
Jet-Nemotron-4B时,其平均准确率达到 76.2,超越了Qwen3-1.7B-Base,实现了最佳性能,同时仍保持了约21倍的加速。 - 这表明
PostNAS成功地在Jet-Nemotron中保留了全注意力层 (full-attention layers) 对于检索任务的强大能力。
6.1.5. 编码任务结果
以下是原文 的表格,展示了模型在编码任务上的性能。
| Type | Model | | Throughput (token/s) ↑ | Accuracy ↑ | |||
| Avg. | EvalPlus | CRUXEval-I-cot | CRUXEval-O-cot | |||
| O(n2) | Qwen2.5-1.5B [4] | 241 | 52.0 | 54.3 | 56.0 | 45.8 |
| Qwen3-1.7B-Base [5] | 61 | 58.9 | 62.8 | 60.4 | 53.4 | |
| Llama3.2-3B [2] | 60 | 44.0 | 35.5 | 54.7 | 41.7 | |
| MiniCPM-2B-128K [58] | 18 | 34.2 | 40.7 | 29.9 | 31.9 | |
| Smollm2-1.7B[ [60] | 32 | 36.2 | 20.6 | 49.5 | 38.6 | |
| O(n) | Mamba2-2.7B [50] | 2,507 | 14.0 | 12.0 | 9.3 | 20.7 |
| RWKV7-1.5B [10] | 3,050 | 13.2 | 16.8 | 8.0 | 14.7 | |
| Rec.Gemma-2.6B [62] | 2,355 | 36.8 | 29.5 | 46.7 | 34.2 | |
| Hybrid | Gemma3n-E2B [73] | 701 | 40.4 | 29.6 | 49.9 | 41.6 |
| Hymba-1.5B [44] | 180 | 30.3 | 31.3 | 32.2 | 27.5 | |
| Zamba2-1.2B [16] | 71 | 20.1 | 12.7 | 21.1 | 26.4 | |
| Jet-Nemotron-2B | 2,885 | 59.5 | 60.8 | 61.1 | 56.7 | |
| Jet-Nemotron-4B | 1,271 | 63.5 | 65.6 | 65.9 | 59.0 | |
Jet-Nemotron-2B取得了 59.5 的平均准确率,与Qwen3-1.7B-Base(58.9) 相当。Jet-Nemotron-4B进一步提升至 63.5 的平均准确率,在所有编码任务上均实现更高准确率,同时仍保持了相对于Qwen3-1.7B-Base的吞吐量优势。
6.1.6. 长文本任务结果
以下是原文 的表格,展示了模型在 LongBench 长文本任务(上下文长度达 64K)上的性能。
| Type | Model | Throughput (token/s) ↑ | Accuracy ↑ | |||||
| Avg. | Few-Shot | Code | Sum. | Single-Doc | Multi-Doc | |||
| O(n2) | Qwen2.5-1.5B [4] | 241 | 39.1 | 63.9 | 57.2 | 26.3 | 28.3 | 19.9 |
| Qwen3-1.7B-Base [5] | 61 | 42.2 | 68.8 | 48.1 | 26.8 | 36.6 | 30.6 | |
| Llama3.2-3B [2] | 60 | 39.9 | 65.2 | 58.0 | 24.3 | 27.6 | 24.6 | |
| MiniCPM-2B-128K [58] | 18 | 41.1 | 57.3 | 59.6 | 25.7 | 33.4 | 29.6 | |
| Smollm2-1.7B [60] | 32 | 21.3 | 38.9 | 28.6 | 16.0 | 13.2 | 9.8 | |
| O(n) | Mamba2-2.7B [50] | 2,507 | 10.3 | 6.4 | 30.2 | 9.1 | 3.5 | 2.5 |
| RWKV7-1.5B [10] | 3,050 | 14.2 | 10.6 | 21.1 | 18.1 | 12.8 | 8.7 | |
| Rec.Gemma-2.6B [62] | 2,355 | 24.1 | 31.8 | 56.7 | 12.9 | 9.2 | 9.6 | |
| Hybrid | Gemma2-2.6B [73] | 388 | 22.9 | 28.7 | 52.0 | 12.6 | 13.9 | 7.3 |
| Gemma3n-E2B [73] | 701 | 40.4 | 56.4 | 67.2 | 25.6 | 29.3 | 28.6 | |
| Hymba-1.5B [44]. | 180 | 28.0 | 36.1 | 53.5 | 51.8 | 14.0 | 19.8 | |
| Zamba2-1.2B [16] | 71 | 9.2 | 10.0 | 20.1 | 10.2 | 3.8 | 1.7 | |
| Jet-Nemotron-2B | 2,885 | 41.1 | 68.7 | 58.1 | 26.0 | 30.8 | 21.9 | |
| Jet-Nemotron-4B | 1,271 | 43.9 | 69.7 | 63.2 | 26.4 | 32.5 | 27.5 | |
Jet-Nemotron-2B取得了 41.1 的平均准确率,与Qwen2.5-1.5B(39.1) 和Gemma3n-E2B(40.4) 等领先模型相当,尽管其全注意力层 (full-attention layers) 数量显著更少。Jet-Nemotron-4B表现更优,达到 43.9 的平均准确率,超越了Qwen3-1.7B-Base(42.2),同时实现了21倍的生成吞吐量加速。- 这些结果表明,
Jet-Nemotron在长文本任务上的效率-准确率权衡方面取得了显著进展。
6.1.7. 总结
Jet-Nemotron-2B 和 Jet-Nemotron-4B 在所有六个评估领域中,均能与先进的全注意力模型(如 Qwen3-1.7B-Base)相媲美甚至超越。同时,由于显著减少了全注意力层和键值缓存 (KV cache) 大小,Jet-Nemotron-2B 和 Jet-Nemotron-4B 分别实现了 47 倍和 21 倍于 Qwen3-1.7B-Base 的生成吞吐量 (generation throughput)。
6.2. 效率基准结果
本节详细比较了 Jet-Nemotron-2B 和 Qwen3-1.7B-Base 在不同上下文长度下的效率表现。
以下是原文 的图像:
该图像是图表,展示了不同上下文长度下Jet-Nemotron-2B与Qwen3-1.7B在预填充(prefilling)和解码(decoding)速度上的相对提升。Jet-Nemotron-2B的预填充速度最高达到,解码速度最高达到,显著优于Qwen3-1.7B。
- 预填充阶段 (Prefilling Stage):
- 在较短的上下文长度(4K 和 8K)下,
Jet-Nemotron-2B的预填充速度分别比Qwen3-1.7B-Base快1.14倍和1.15倍。作者指出,这一差距可以通过优化JetBlock的核实现 (kernel implementation) 进一步改善。 - 随着上下文长度的增加,线性注意力 (linear attention) 的优势变得更加突出。在 256K 的上下文长度下,
Jet-Nemotron-2B实现了6.14倍的预填充加速。
- 在较短的上下文长度(4K 和 8K)下,
- 解码阶段 (Decoding Stage):
-
Jet-Nemotron-2B在解码阶段始终大幅优于Qwen3-1.7B-Base。 -
Jet-Nemotron-2B包含 2 个全注意力层 (full-attention layers),每个层有 2 组键值状态 (key-value states)。相比之下,Qwen3-1.7B-Base有 28 个全注意力层,每个层有 8 组键值状态。这意味着理论上,Jet-Nemotron-2B的最大加速可以达到 倍。 -
在实际吞吐量测试中,
Jet-Nemotron-2B在 4K 上下文长度下实现了15.6倍的加速,在 256K 上下文长度下甚至达到了53.6倍的加速,几乎达到了理论上限。这些结果强有力地证明了
Jet-Nemotron在长上下文场景下的卓越效率。
-
6.3. 消融实验/参数分析
6.3.1. PostNAS 准确率提升分解
以下是原文 的图像:
该图像是图表,展示了图3中PostNAS对基线模型的准确度提升细分。通过逐步应用不同优化策略,在四个指标上均取得显著提升,最高达到58.1、34.9、70.4和59.3的准确率。
Figure 3 展示了通过逐步应用 PostNAS 的不同组件,Jet-Nemotron 相对于基线模型所实现的准确率提升分解。
- 基础模型: 初始的基线模型。
- PostNAS 步骤:
- 全注意力层放置和消除 (Full Attention Placement and Elimination): 带来了显著的准确率提升,尤其是在
MMLU和数学任务上。 - 线性注意力块选择 (Linear Attention Block Selection): 进一步提升了所有基准的准确率。
- 新注意力块设计 (New Attention Block Design) (
JetBlock): 再次带来了全面的准确率提升。 - 硬件感知架构搜索 (Hardware-Aware Architecture Search): 在保持效率的同时,通过优化超参数,进一步微调和提升了最终的准确率。
- 全注意力层放置和消除 (Full Attention Placement and Elimination): 带来了显著的准确率提升,尤其是在
- 结果:
PostNAS的每个阶段都对最终模型的准确率贡献了显著的提升,最终在MMLU上提高了 5.3,数学任务提高了 8.4,检索任务提高了 7.8,常识推理提高了 3.2。
6.3.2. JetBlock 准确率和效率
以下是原文 的表格,展示了 JetBlock 与其他线性注意力块的准确率和效率比较。
| Attention Block | Data-Depend Delta Gating | Rule | Throughput ↑ Training Inference | Accuracy ↑ MMLU Math Retreival | ||||
| Common. | ||||||||
| RWKV7 [10] | √ | ✓ | 123 | 2,542 | − | − | − | − |
| RetNet [12] | 269 | 2,535 | 53.6 | 29.9 | 63.7 | 58.1 | ||
| Mamba2 [50] | 273 | 3,220 | 51.5 | 26.0 | 68.9 | 57.5 | ||
| GLA [11] | ✓ | 265 | 3,079 | 55.8 | 31.2 | 66.6 | 58.5 | |
| Deltanet [51] | ✓ | 254 | 2,955 | 48.9 | 27.4 | 67.9 | 56.6 | |
| Gated DeltaNet [32] | ✓ | ✓ | 247 | 2,980 | 55.6 | 32.3 | 69.3 | 58.7 |
| JetBlock | ; | : | 233 | 2,885 | 56.3 | 32.8 | 69.9 | 58.5 |
| + Hardware-Aware Search | 227 | 2,883 | 58.1 | 34.9 | 70.4 | 59.5 | ||
JetBlock在数学推理 (Math) (32.8) 和检索 (Retrieval) (69.9) 任务上的准确率优于Gated DeltaNet(Math: 32.3,Retrieval: 69.3),同时保持了相似的训练和推理吞吐量。- 通过结合硬件感知架构搜索 (Hardware-Aware Architecture Search),
JetBlock的准确率在所有任务上都得到了进一步提升,例如MMLU达到 58.1,Math达到 34.9,Retrieval达到 70.4,Commonsense达到 59.5。这表明JetBlock的设计既能提升性能,又能维持高效率。
6.3.3. 硬件感知架构搜索细节
以下是原文 的表格,展示了硬件感知架构搜索的详细结果。
| dK | dV | nhead | Params (B) | Cache Size (MB) | Throughput ↑ (token/s) | Retrieval ↑ Accuracy | Math Accuracy ↑ |
| 256 | 288 | 4 | 1.62 | 154 | 2,969 | 67.6 | 31.3 |
| 192 | 384 | 4 | 1.64 | 154 | 2,961 | 69.3 | 32.3 |
| 128 | 576 | 4 | 1.70 | 154 | 2,979 | 69.5 | 32.5 |
| 256 | 144 | 8 | 1.66 | 154 | 2,986 | 68.3 | 32.1 |
| 192 | 192 | 8 | 1.70 | 154 | 2,970 | 70.6 | 32.8 |
| 128 | 288 | 8 | 1.74 | 154 | 2,971 | 69.6 | 33.2 |
| 128 | 192 | 12 | 1.78 | 154 | 2,959 | 68.8 | 32.9 |
| 96 | 256 | 12 | 1.84 | 154 | 2,955 | 69.6 | 34.8 |
| 64 | 384 | 12 | 1.98 | 154 | 2,952 | 70.1 | 34.2 |
- 表格展示了在固定
KV缓存大小 (154MB) 的前提下,不同键维度 (dK)、值维度 (dV) 和注意力头数 (nhead) 组合对模型参数量、吞吐量和准确率的影响。 - 灰色的行代表原始设计,而蓝色的行代表通过硬件感知搜索得到的新设计。
- 关键发现 4 (Key Finding 4) 得到验证:当
KV缓存大小恒定(都为 154MB)时,尽管模型参数量从 1.62B 增加到 1.98B,其生成吞吐量 (generation throughput) 仍然保持在相似的水平(约 2,950-2,980token/s)。 - 新的设计(例如 , , )允许在保持高吞吐量 (2,955
token/s) 的同时,使用更多的参数 (1.84B) 并显著提升了检索 (Retrieval) (69.6) 和数学 (Math) (34.8) 任务的准确率,这进一步验证了硬件感知架构搜索的有效性。
6.3.4. 训练数据控制研究
以下是原文 的表格,展示了在相同训练数据下,Jet-Nemotron-2B 与基线模型的对比。
| Model | MMLU | Math | Commonsense | Retrieval |
| Qwem2.5-1.5B-continual | 56.7 | 37.6 | 59.8 | 71.5 |
| Mamba2-2.7B-continual | 41.0 | 22.5 | 56.9 | 55.9 |
| RWKV7-1.5B-continual | 49.8 | 25.2 | 59.3 | 57.2 |
| Jet-Nemotron-2B | 59.6 | 40.2 | 61.7 | 73.6 |
为了排除训练数据差异对结果的影响,作者对 Qwen2.5、Mamba2 和 RWKV7 等基线模型进行了持续预训练 (continual pre-train),使其在与 Jet-Nemotron 相同的 Stage-2 训练语料上进行训练。
- 结果显示,即使在相同的训练数据下,
Jet-Nemotron-2B在MMLU(59.6)、Math(40.2)、Commonsense(61.7) 和Retrieval(73.6) 四个基准测试上,仍以显著优势超越了所有这些微调的基线模型 (finetuned baseline models)。 - 这强有力地证明了
Jet-Nemotron架构本身的优越性,而非仅仅是训练数据或训练策略的优势。
6.3.5. 低端硬件吞吐量结果
以下是原文 Table 15 | Throughput Results on Jetson Orin (32GB) and NVIDIA RTX 3090 GPUs 的表格,展示了 Jet-Nemotron-2B 在低端硬件上的吞吐量表现。
| Hardware | Qwen2.5-1.5B (Tokens/s) Jet-Nemotron-2B (Tokens/s) SpeedUp | ||
| Orin | 6.22 | 55.00 | 8.84 |
| 3090 | 105.18 | 684.01 | 6.50 |
为了验证 Jet-Nemotron 的泛化效率,作者在 NVIDIA Jetson Orin (32GB) 和 NVIDIA RTX 3090 GPU 上测试了 Jet-Nemotron-2B 和 Qwen2.5-1.5B 的吞吐量(上下文长度 64K)。
- 在
Jetson Orin上,Jet-Nemotron-2B实现了8.84倍的加速。 - 在
RTX 3090上,Jet-Nemotron-2B实现了6.50倍的加速。 - 这些结果表明,
Jet-Nemotron的效率优势不仅限于高端H100 GPU,在低端或消费级硬件 (lower-end or consumer-grade hardware) 上也能带来显著的性能提升。
6.3.6. 与 Falcon-H1 的比较
以下是原文 的表格,展示了 Jet-Nemotron 与同期工作 Falcon-H1 的比较。
| Model | |Throughput (token/s)↑ | Accuracy ↑ | |||||
| MMLU | MATH | Common. | Retrieval | Code | Long-Context | ||
| Falcon-H1-1.5B [106] | 223 | 60.5 | 40.1 | 59.9 | 73.5 | 56.0 | 40.7 |
| Falcon-H1-1.5B-deep [106] | 66 | 63.5 | 46.8 | 60.6 | 74.6 | 60.3 | 33.4 |
| Jet-Nemotron-2B | 2,885 | 60.8 | 49.6 | 62.0 | 74.2 | 59.5 | 41.1 |
| Jet-Nemotron-4B | 1,271 | 65.2 | 51.3 | 64.7 | 76.2 | 63.5 | 43.9 |
Falcon-H1 [106] 是一个同期提出的混合模型,它结合了 Mamba2 [50] 和全注意力机制。但与 Jet-Nemotron 在层级 (layer level) 切换组件不同,Falcon-H1 采用头级混合策略 (head-wise hybrid strategy)。
- 性能对比:
Jet-Nemotron-2B在准确率上优于Falcon-H1-1.5B,并与Falcon-H1-1.5B-deep相当,同时生成吞吐量 (generation throughput) 显著更高 (2,885 token/svs223 token/s或66 token/s)。Jet-Nemotron-4B在准确率和吞吐量上均优于两个Falcon-H1模型。
- 效率差距原因:
- 这种效率差距主要源于头级策略 (head-wise strategy) 的固有缺点。它要求在单个层内顺序计算 (sequential computation)
Mamba2和全注意力操作,从而限制了并行性 (parallelism)。 Falcon-H1的 "deep" 变体通过减小模型宽度来增加深度,进一步加剧了这一问题。这表明Jet-Nemotron采用的层级混合策略 (layer-wise hybrid strategy) 在实现高并行度和吞吐量方面更具优势。
- 这种效率差距主要源于头级策略 (head-wise strategy) 的固有缺点。它要求在单个层内顺序计算 (sequential computation)
7. 总结与思考
7.1. 结论总结
本文成功地提出了 Jet-Nemotron,一个新型的混合架构语言模型 (hybrid-architecture language models) 家族,解决了在不牺牲准确率的前提下提升 LLM 效率的长期挑战。其核心在于引入了后神经架构搜索 (Post Neural Architecture Search, PostNAS),一个创新且高效的架构探索流水线 (architecture exploration pipeline)。
Jet-Nemotron 的关键成就包括:
-
性能与效率双突破: 在
MMLU、数学推理、常识推理、检索、编码和长文本任务等广泛基准测试中,Jet-Nemotron能够匹配或超越Qwen3、Qwen2.5、Gemma3和Llama3.2等最先进的 (state-of-the-art) 全注意力模型的准确率。同时,它在H100 GPU上(256K 上下文长度,最大批次大小)实现了高达53.6倍的生成吞吐量 (generation throughput) 提升和6.1倍的预填充加速 (prefilling speedup)。 -
PostNAS 的关键作用:
PostNAS通过从预训练Transformer模型开始并冻结其 MLP 权重 (freezing its MLP weights),大幅降低了LLM架构探索的成本和风险,使其成为一种高效的架构适应流水线 (architecture adaptation pipeline)。 -
JetBlock 的创新贡献: 提出了
JetBlock,这是一种新颖的线性注意力块 (linear attention block),通过集成动态卷积 (dynamic convolution),在保持可比效率的同时,显著优于Mamba2、GLA和Gated DeltaNet等现有设计。 -
对架构设计的洞察: 实验结果揭示了注意力层任务特定重要性 (task-specific importance),并确认了
KV缓存大小 (KV cache size) 是影响长文本生成吞吐量 (long-context generation throughput) 的关键因素,而非简单的参数量。总而言之,
Jet-Nemotron以其卓越的准确率和高效的推理能力,为下一代高效LLM的开发和部署树立了新的标杆。
7.2. 局限性与未来工作
原文未明确指出 Jet-Nemotron 自身的局限性或提出未来的具体研究方向。然而,作为严谨的学术研究助理,我们可以从论文内容和领域发展趋势中进行推断和批判性思考:
- 对预训练基座模型的依赖:
PostNAS的核心优势在于利用了强大的预训练全注意力模型,并通过冻结 MLP 权重 (freezing MLP weights) 来降低成本。这意味着Jet-Nemotron的性能上限在一定程度上受限于其所选用的基座模型。如果基座模型本身存在偏见、知识不足或架构缺陷,PostNAS可能无法完全克服这些问题。未来工作可以探索如何将PostNAS扩展到从头开始训练的场景,或者在PostNAS过程中对MLP权重进行选择性微调,以适应更广泛的场景。 - JetBlock 的计算开销: 尽管
JetBlock声称引入动态卷积 (dynamic convolution) 带来了准确率提升且开销较小,但核生成器 (kernel generator) 模块的引入确实增加了额外的计算。在极端低功耗设备或对延迟有极高要求的场景下,这种额外的开销是否仍可忽略,以及动态卷积在处理超长序列时的实际效率瓶颈,可能需要更深入的分析和优化。 PostNAS的泛化性:PostNAS流程在一个预训练模型(如Qwen2.5)上进行了验证。该流程及其搜索出的最佳架构配置是否能普遍适用于其他不同的预训练LLM家族(如Llama系列、Mistral系列),或在领域外的数据集上保持其高效性,仍需进一步的实验验证。- 长上下文能力的进一步探索: 尽管
Jet-Nemotron在LongBench上表现良好,但长上下文任务仍然是LLM领域的开放挑战。目前模型仅在少数几层保留了全注意力 (full-attention),并利用滑动窗口注意力 (SWA) 和JetBlock来处理长上下文。未来可以探索更先进的混合策略,或者针对超长上下文(例如百万级词元)专门优化JetBlock的设计。 - 架构搜索的粒度:
PostNAS主要在层级和注意力块选择上进行搜索。未来可以探索更细粒度的架构搜索,例如在层内部的子组件级别进行混合,或动态调整不同注意力机制的比例。
7.3. 个人启发与批判
7.3.1. 个人启发
这篇论文提供了多方面的启发:
- “后训练”架构优化的潜力:
PostNAS范式是一个重大的思想转变。在LLM领域,预训练成本是最大的壁垒。PostNAS提出了一种在预训练完成后,以低成本、高效率进行架构创新的可行路径。这对于学术界和资源有限的研究团队来说,提供了一个极具价值的工具,使得他们也能参与到LLM架构的迭代和优化中。 - 效率与能力不再是简单的权衡:
Jet-Nemotron的成功证明了通过精巧的架构设计,我们可以在大幅提升效率的同时,不牺牲甚至超越最先进的 (state-of-the-art) 模型的准确率。这打破了传统上认为效率和能力之间是严格权衡的观念,激励我们去寻找更智能的解决方案。 - KV Cache 的关键性洞察: 论文明确指出
KV Cache大小是影响长上下文和长生成吞吐量的最关键因素,而不是参数量。这一洞察对于未来LLM的效率优化具有指导性意义,提醒我们优化方向不仅仅是计算量,更重要的是内存访问模式。 - 动态机制的价值:
JetBlock中引入的动态卷积 (dynamic convolution) 机制,使得模型能够根据输入动态调整其特征提取模式 (feature extraction patterns),这比静态卷积更加灵活和强大。这种动态性思维可以扩展到LLM架构的其他部分,例如动态调整层数、头数或不同机制的组合。 - 硬件感知设计的重要性: 将硬件感知 (hardware-awareness) 融入架构设计是至关重要的。模型性能的真正体现是在实际硬件上,而不是纯粹的理论计算量。
PostNAS的硬件感知搜索体现了这种实用主义,指导模型设计走向实际部署的优化。
7.3.2. 批判
尽管 Jet-Nemotron 和 PostNAS 取得了显著的进步,但仍有一些潜在问题和可改进之处值得批判性思考:
-
“冻结 MLP 权重”的合理性边界: 尽管冻结 MLP 权重 (freezing MLP weights) 显著降低了成本,但它也限制了模型在适应新架构时对多层感知器 (MLP) 内部知识的修改能力。如果新的注意力机制导致了显著不同的特征流,
MLP权重完全冻结可能无法实现全局最优。未来的研究可以探索选择性微调 (selective fine-tuning) 或部分解冻 (partial unfreezing)MLP权重的策略,以在成本和性能之间取得更好的平衡。 -
一次性超网络 (once-for-all supernetwork) 训练的复杂性: 虽然
一次性超网络降低了整体搜索成本,但其自身的训练过程可能仍具有一定的复杂性和超参数调优难度。例如,蒸馏损失 (distillation loss) 的选择、不同子网络采样策略对超网络泛化性的影响、以及如何确保超网络中的所有子网络都能得到有效训练,这些细节在论文中描述得相对简略。 -
JetBlock 动态卷积的泛用性:
JetBlock的动态卷积 (dynamic convolution) 仅应用于 词元,并移除了 和 上的静态卷积。虽然在当前实验中效果良好,但这种特定设计是否在所有任务、所有数据领域中都保持最优,或者是否存在更通用的动态机制应用于 和 甚至整个注意力流,值得进一步探讨。 -
可解释性与透明度:
PostNAS作为一个自动化的搜索过程,其最终产生的混合架构 (hybrid architecture)(哪些层是全注意力,哪些是SWA,哪些是JetBlock)可能缺乏直观的人工设计可解释性。了解为什么某些层对特定任务更重要,以及不同注意力机制在模型不同位置的协同作用原理,将有助于未来的模型设计。 -
长上下文基准的全面性:
LongBench是一个很好的长上下文基准,但长上下文能力是一个多维度的问题,包括对超长依赖的捕捉、噪声鲁棒性、事实性召回等。Jet-Nemotron在LongBench上的优秀表现是否能完全反映其在所有长上下文挑战下的能力,还需要更多样化的评估。总而言之,
Jet-Nemotron及其PostNAS为LLM领域注入了新的活力,尤其是在效率优化方面取得了突破。它为未来LLM的架构设计提供了宝贵的经验和方向,同时也为后续研究提出了新的挑战和机遇。
相似论文推荐
基于向量语义检索推荐的相关论文。