从兆字节到兆瓦:基于CUDA与Triton的高性能LLM与扩散内核全指南

12 minute read

Published:

所有内容均为LLM生成,请注意鉴别。

GPU计算范式:释放并行性能

深度学习的崛起,尤其是大语言模型(LLMs)和扩散模型,与图形处理单元(GPU)的发展密不可分。GPU最初为图形渲染而设计,因其大规模并行架构成为AI事实上的硬件选择。¹ 但真正发挥其威力,不仅仅是把代码跑在GPU上,更需要对硬件架构本身以及针对其设计的编程模型有根本性的理解。本节将从物理硬件到抽象编程模型,阐述GPU计算的基础原理,为理解高性能计算中为何某些编程模式有效提供必备上下文。

现代GPU的解剖:并行计算的能量房

现代GPU绝非仅是“更快的CPU”;它的架构根本不同,专为数据并行而非串行任务优化。² CPU拥有少量强大内核,追求低延迟的单线程性能;而GPU拥有上千个更简单的内核,追求高吞吐的并行计算。³ 要写出高效GPU代码,需理解其关键组成部分。

  • 流式多处理器(SMs): GPU被划分为一组流式多处理器,每个SM是一个独立的处理单元,包含成百上千的计算核心、调度单元和一池内存资源。SM能够并发执行多个线程块(thread block),是GPU并行的核心引擎。⁴
  • CUDA核心: 每个SM内部包含CUDA核心,本质上是执行浮点与整数运算的简单算术逻辑单元(ALU)。⁵ 大量CUDA核心使得一次能并行处理上千个操作。
  • 内存层级结构: 决定性能的关键是GPU的多级内存架构。各级内存空间的速度与容量差异,是性能瓶颈的主要来源。
  • 全局内存(DRAM): GPU最大容量的内存空间,在现代数据中心GPU上通常称作高带宽内存(HBM)。容量可达几十GB(如NVIDIA A100为40-80GB),但相较片上内存延迟最大、带宽最低,是GPU上的“系统内存”。
  • L2缓存: 大型缓存(A100上约40MB),所有SM共享,作为全局内存和SM之间高速缓冲,减少频繁数据访问的延迟。
  • 共享内存(SMEM): 每个SM内一片小型、可编程片上内存(A100为每SM 192KB),带宽极高、延迟极低(接近寄存器级)。所有线程块内线程可共享此内存,需编程者自行管理,是很多高性能CUDA模式的数据共享与通信基础。
  • 寄存器: GPU最快的内存,几乎零延迟。每个线程私有,用于存储局部变量。 片上内存(用于寄存器和共享内存的SRAM)与片外全局内存(DRAM/HBM)的性能差距极大。现代GPU可在一次全局内存读取期间完成上百次浮点运算。⁸ 这形成了著名的“内存墙”,应用瓶颈不再是算力(TFLOPS),而是高效喂数据给计算单元。因此,GPU编程的核心挑战不是单纯并行化计算,更在于协调数据流动,最大化在快速片上内存上计算的时间,最小化昂贵的全局内存往返。这一原则直接导致了自定义、硬件感知内核的需求。

CUDA编程模型:对硬件的抽象

NVIDIA的统一计算设备架构(CUDA)是一个并行计算平台和编程模型,提供了屏蔽硬件细节的软件层。³ 开发者可用C++、Python或Fortran等语言编写利用GPU强大并行能力的程序。¹ CUDA模型构建于一系列直接对应硬件结构的关键抽象之上。¹¹

  • 主机与设备(Host and Device): CUDA编程是异构的:CPU及其内存称为主机(host),GPU及其内存称为设备(device),两者物理分开,依靠PCIe总线连接。程序流程通常为主机统筹全局,在设备上分配内存、转移数据、发起GPU计算,再把结果拷会主机。⁴
  • 内核(Kernels):__global__修饰的函数,即CUDA C++中的GPU内核。² 单个内核能被大批线程并行执行。
  • 线程层级结构: CUDA采用三级线程层级组织大规模并行:
  • 线程(Thread): 最基本的执行单元。每个线程执行相同的内核代码,但操作的数据不同,用唯一索引区分。¹⁰
  • 线程块(Thread Block): 最多可含1024个线程,分配至同一SM,利用共享内存和同步(如__syncthreads())高效协作。紧密协作是诸多优化技术的基石。⁴
  • 网格(Grid): 由多个线程块组成,共同执行同一内核。各线程块彼此独立,不可直接通信,实现透明地跨GPU规模扩展——有更多SM的GPU可同时执行更多线程块。⁴
  • 索引机制: CUDA提供内建多维变量(threadIdx, blockIdx, blockDim, gridDim),线程可基于此计算自身在网格中的唯一身份标识,从而分配到需要处理的具体数据块。这个机制实现了上千线程间的高效工作分割。⁴

编写你的第一个CUDA内核:向量加法

并行计算中的“Hello, World!”通常是向量加法,即C = A + B,A、B、C为大向量。¹⁶ 在CUDA中实现它可展示基本的开发流程和编程模型。 大致步骤如下:⁵

  1. 主机端准备: 在CPU上通过malloc()分配主机向量h_Ah_Bh_C,并初始化输入。
  2. 设备内存分配: 在GPU上通过cudaMalloc()分配设备向量d_Ad_Bd_C。¹²
  3. 主机到设备数据传输:cudaMemcpy()(方向为cudaMemcpyHostToDevice)将输入数据拷贝到设备。⁵
  4. 内核发射: 用三角括号语法<<<gridDim, blockDim>>>在GPU上发射add_kernel,通常blockDim设为32的倍数(如256或512),gridDim用向上取整公式(N + blockDim.x - 1) / blockDim.x确保所有元素覆盖。¹⁵
  5. 内核执行: 所有线程并行运行内核代码,每个线程据全局索引对齐数据,并做越界检查防止操作非法内存。¹⁷
  6. 主机-设备同步: 主机必须等待GPU计算完成才能获取结果,通过cudaDeviceSynchronize()阻塞主机直到所有GPU命令结束。¹⁸
  7. 设备到主机数据回传:cudaMemcpy()(方向为cudaMemcpyDeviceToHost)回拷结果。⁵
  8. 清理:cudaFree()free()分别释放设备和主机内存,防止泄漏。¹² 以下代码为完整CUDA C++向量加法示例。CUDA源码通常用.cu扩展名,并用NVIDIA的nvcc编译(如nvcc vector_add.cu -o vector_add)。¹²
    #include <iostream>
    #include <cmath>
    // CUDA API错误检查封装
    void cudaCheck(cudaError_t error, const char *file, int line) {
     if (error!= cudaSuccess) {
         printf(" at %s:%d: %s\n", file, line, cudaGetErrorString(error));
         exit(EXIT_FAILURE);
     }
    }
    #define CUDA_CHECK(err) (cudaCheck(err, __FILE__, __LINE__))
    // CUDA向量加法内核
    __global__ void add_kernel(const float* a, const float* b, float* c, int n) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx < n) {
         c[idx] = a[idx] + b[idx];
     }
    }
    int main() {
     int N = 1 << 20; // 1,048,576 元素
     size_t size = N * sizeof(float);
     // 1. 主机内存分配
     float* h_a = (float*)malloc(size);
     float* h_b = (float*)malloc(size);
     float* h_c = (float*)malloc(size);
     // 初始化主机向量
     for (int i = 0; i < N; ++i) {
         h_a[i] = sin(i) * sin(i);
         h_b[i] = cos(i) * cos(i);
     }
     // 2. 设备内存分配
     float *d_a, *d_b, *d_c;
     CUDA_CHECK(cudaMalloc(&d_a, size));
     CUDA_CHECK(cudaMalloc(&d_b, size));
     CUDA_CHECK(cudaMalloc(&d_c, size));
     // 3. 主机->设备数据传输
     CUDA_CHECK(cudaMemcpy(d_a, h_a, size, cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaMemcpy(d_b, h_b, size, cudaMemcpyHostToDevice));
     // 4. 启动内核
     int threadsPerBlock = 256;
     int blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;
     add_kernel<<<blocksPerGrid, threadsPerBlock>>>(d_a, d_b, d_c, N);
     // 5. 同步并检查内核
     CUDA_CHECK(cudaGetLastError());
     CUDA_CHECK(cudaDeviceSynchronize());
     // 6. 设备->主机结果回传
     CUDA_CHECK(cudaMemcpy(h_c, d_c, size, cudaMemcpyDeviceToHost));
     // 在主机端验证结果
     float maxError = 0.0f;
     for (int i = 0; i < N; ++i) {
         maxError = fmax(maxError, fabs(h_c[i] - 1.0f));
     }
     std::cout << "Max error: " << maxError << std::endl;
     // 7. 释放内存
     cudaFree(d_a);
     cudaFree(d_b);
     cudaFree(d_c);
     free(h_a);
     free(h_b);
     free(h_c);
     return 0;
    }
    

Triton:Python中的高生产力GPU编程

CUDA固然能完全控制GPU,但其复杂度极高。要写出高效的CUDA代码需深入的硬件知识,开发过程既耗时又易出错。²¹ AI研发节奏极快,因此迫切需要更高效的方案。OpenAI的Triton正是在这个背景下应运而生,为高性能GPU计算提供了Python化的高层路径。

弥合鸿沟:Triton的缘起

Triton是为AI社区打造的开源语言与编译器²³,目标是让从未接触过CUDA的开发者也能写出性能媲美专家级CUDA、开发成本却仅十分之一的自定义计算内核。²¹ Triton的魔法源于其JIT编译器。开发者以类似Python的高阶语法写内核逻辑,编译器自动处理传统CUDA中极具挑战的细节²¹:

  • 自动内存合并: 编译器分析内存访问模式,并将其整合为单次高效事务,最大化带宽利用率。
  • 共享内存自动管理: 自动将片上高速共享内存用作线程块级缓存,不再需复杂手动管理。
  • 指令调度: 编译器自动重新排序指令以隐藏内存延迟,提升SM执行单元的利用率。 这种高度自动化极大“民主化”了高性能计算。过去优化内核只属于为数不多的HPC专家。如今LLM及各类大模型爆发,对算力及优化需求极大激增。²⁷ Triton让懂Python的数据科学家、ML工程师均可编写高性能内核,打通了PyTorch等高层框架与底层CUDA间的壁垒,大大加速了行业生态开放与敏捷创新²⁴,性能优化不再是巨头专利。²⁵

Triton编程模型:块级抽象

Triton编程模型最大特色是提升SPMD(单程序多数据)抽象层级。与CUDA不同,开发者不再聚焦单一标量线程,而是编写面向整个数据块(tile/block)操作的块级程序。²¹ 块内线程完全由编译器托管,开发者无需关心。 核心概念包括:

  • @triton.jit装饰器: 标记Triton内核,激活JIT编译至GPU。²⁴
  • 指针: 内核操作GPU内存地址指针,而不是tensor对象。PyTorch tensor传入会隐式转化为指针³¹。
  • 程序ID与偏移量: 启动时生成N维网格,内核通过tl.program_id(axis=...)获取自身索引,用于计算要处理数据块的基址。²⁴
  • 块级操作(tl.arange, tl.load, tl.store): 以块(vector)为单位操作数据,tl.arange(start, end)生成索引,配合基础偏移量计算全块内存指针,tl.loadtl.store直接批量装/存到片上SRAM。²⁹
  • 掩码机制: tl.loadtl.store支持bool掩码参数,高效处理不能整除块大小的边界,远优于分支判断防止线程发散、掉速。²⁴
  • tl.constexpr 编译时常量标记,如块大小。编译器得以展开循环,将维度硬编码至设备指令中,生成更高效代码。²⁹

Triton“Hello, World!”:向量加法重温

用Triton重写向量加法,一目了然见证其高层抽象与简洁对比。 全流程分两步:Python主机函数发射内核,Triton内核实现核心逻辑。

  1. Python主机包装器(add函数): 接受PyTorch张量输入,预分配输出张量,定义launch grid(计算需要多少个程序实例),发射内核并传递参数及meta参数。³²
  2. Triton内核(add_kernel函数):@triton.jit装饰,计算自身程序ID、块内偏移量、边界掩码,然后用tl.load加载并用向量加法后tl.store存回数据。²⁴

完整且附注释的Python代码如下:

import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,  # 第一个输入向量指针
    y_ptr,  # 第二个输入向量指针
    output_ptr,  # 输出向量指针
    n_elements,  # 向量总元素数
    BLOCK_SIZE: tl.constexpr,  # 每个程序处理的元素数
):
    """
    Triton向量逐元素加法内核。
    每个实例处理输出向量的一个块。
    """
    # 1. 获取本实例程序ID
    pid = tl.program_id(axis=0)
    # 2. 计算本块处理元素的偏移量
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 3. 构造掩码防止越界
    mask = offsets < n_elements
    # 4. 从全局内存加载数据到寄存器
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # 5. 逐块加法
    output = x + y
    # 6. 写回结果
    tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Triton向量加法内核的主机端发射函数。
    """
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
# --- 验证 ---
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(f"PyTorch output: {output_torch}")
print(f"Triton output:  {output_triton}")
print(f"The maximum difference is {torch.max(torch.abs(output_torch - output_triton))}")

可见,Triton实现不仅代码短,用更友好的Python语法,还屏蔽了手动内存管理、线程索引等繁琐流程。

表1:CUDA vs. Triton - 对比总览

下表总结两者关系,其实不是竞品,而是优化工具箱的互补利器。

特性CUDATriton
编程语言C/C++带扩展类Python DSL
抽象层次底层、硬件导向高层、块导向
线程管理手动(线程、warp、块)自动(由编译器管理)
内存管理手动(cudaMalloc__shared__共享内存)自动(编译器优化SRAM用法)
学习曲线陡峭,需深厚硬件知识适中,Python开发者友好
调试复杂(Nsight、printf等)简易(可用CPU模拟器TRITON_INTERPRET=1
极致性能可达绝对峰值性能通常媲美专家级CUDA,优质性能更易实现
典型场景新硬件特性、复杂/非常规内核普通DL算子融合、快速原型开发、通用内核开发

性能瓶颈与自定义内核的重要性

尽管PyTorch、TensorFlow等提供高度优化的基础算子库,但其标准执行模型常留大量性能空间未被利用。深入理解是区分受计算/内存限制型操作,及框架如何处理二者,正是自定义CUDA与Triton内核的根本动因。

内存瓶颈 vs. 计算瓶颈

所有GPU操作的性能,最终受两大要素约束:计算速率(计算瓶颈)或数据访存带宽(内存瓶颈)。³⁵

  • 计算受限操作: 大部分时间花在算术指令上,瓶颈在SM的“浮点能力”(FLOPS)。典型如大规模稠密矩阵乘(GEMM),即线性层与注意力机制的主力,计算量远超数据量。³⁵
  • 内存受限操作: 等待数据从慢速全局内存传输到片上内存(SRAM)的时间压倒实际计算。网络中大部分非GEMM操作属于此类,如逐元素操作(ReLU、GeLU、dropout、加法)、归一化层(LayerNorm、BatchNorm)、归约(softmax、sum、max)等。⁶ 深度学习框架标准模式,会为每个操作独立发射内核。例如线性层 -> ReLU序列,框架会先发射一次GEMM(将结果写入HBM),再对结果读回来、施加ReLU、再写出。这种“操作粒度细”的调度,极大增加了对HBM的不必要读写,导致性能严重损耗。³⁷

DRAM暴政与内核发射开销

这种低效,根源在硬件现实:

  • HBM瓶颈: HBM带宽(A100约2TB/s)远低于片上SRAM(约19TB/s),每一次不必要的HBM往返都导致算单元空转,必须极力减少这种数据流动。
  • 内核发射开销: 从CPU发起内核需穿过PCIe总线、有一定设置与等待时间。对于单个长内核(如GEMM)影响不大,但若模型由成千上万个小快核构成,这部分累计开销可能大于实际计算本身。³⁸

融合的力量:Kernel Fusion

内核融合是应对上述瓶颈的基本优化法。它将多个顺序操作合为一个更大内核,³⁸ 如将Linear -> ReLU融合,线程块计算矩阵乘结果后立即在寄存器做ReLU激活,只有最终结果才写回HBM。 核心收益有:⁴⁰

  • 大幅减少HBM带宽消耗:不用中间张量反复进出HBM
  • 消除发射开销:原本多次发射成为一次
  • 提升数据局部性与cache效率:数据“热”时即复用,充分利用最快内存层级 融合已成为现代AI性能优化的中流砥柱,Flash Attention等也本质是对注意力模块操作序列高级别内核融合的实践。³⁶ Liger-Kernel、DeepSpeed Transformer Kernel也都是高度融合后的算子内核集合。²⁷

Transformer优化实战:从多头注意力到Flash Attention

Transformer架构及其自注意力机制,是现代LLM的核心,同时也是性能瓶颈。本节以实操代码为中心,从朴素实现到Flash Attention,深入剖析这一关键组件的优化路径。

拆解多头注意力(MHA)

注意力机制能够让模型在生成每个token表示时,权衡序列中不同token的重要性。最常用的是缩放点积注意力(Scaled Dot-Product Attention): \(Attention(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right)V\) 其中Q/K/V均由输入投影而来,d_k为key维度。多头注意力(MHA)则并行运行多组独立的投影与注意力,然后拼接。⁴⁵ 通常PyTorch实现为:

  1. 三个nn.Linear分别用于Q、K、V投影
  2. 对tensor重排列分头
  3. torch.bmm@计算QKᵀ
  4. 应用mask
  5. 对分数softmax
  6. 用权重与V矩阵做第二次bmm
  7. 输出转换回原状 最大低效在于材料化了(S=QKᵀ)这一大矩阵,其维度为(N,N),内存消耗O(N²),序列较长则迅速失控,对HBM压力极大。⁴⁷

Triton融合MHA内核

第一步优化可用Triton融合MHA多个部分,节省HBM与发射开销。大致思路:

  1. QKV融合投影: 用一个宽度为3倍的nn.Linear替代三个,减少GEMM次数。⁵⁰
  2. 融合注意力核心: 写单个Triton内核,顺序包含softmax(QKᵀ)V,按块加载QKV,在片上算分数、softmax并乘以V,所有步骤均在寄存器/SMEM执行,结果最后才写回。如此不用在全局内存中存储KQᵀ矩阵。

Triton简化版示意代码如下:

@triton.jit
def fused_attn_kernel(Q_ptr, K_ptr, V_ptr, O_ptr,...):
    pid_m = tl.program_id(axis=0) # 行块编号
    pid_n = tl.program_id(axis=1) # 列块编号(简版可忽略)
    acc = tl.zeros((BLOCK_SIZE_M, D_HEAD), dtype=tl.float32)
    q_offsets =...
    q = tl.load(Q_ptr + q_offsets)
    for k_block_idx in range(0, N_CTX, BLOCK_SIZE_K):
        k_offsets =...
        k = tl.load(K_ptr + k_offsets)
        s_ij = tl.dot(q, k)
        # Flash Attention在线softmax本应在此实现
        p_ij = tl.softmax(s_ij)
        v_offsets =...
        v = tl.load(V_ptr + v_offsets)
        acc += tl.dot(p_ij, v)
    o_offsets =...
    tl.store(O_ptr + o_offsets, acc)

但注意简单版本还是无法不完整存全行分数而正确算softmax,这正是Flash Attention彻底解决的问题。

Flash Attention革命

Flash Attention是一种IO感知、完全精确的注意力算法,通过计算重排序解决O(N²)内存问题,堪称算法-硬件协同设计典范。⁵¹ 关键技巧:

  • 分块(Tiling): Q/K/V都分成小块,按块加载到SRAM,经过所有K/V块,实现数据始终在片上最快内存。
  • 在线softmax: 核心创新是只用部分分数逐步更新最大值m与归一化sum l—— \(m_{\text{new}} = \max \left( m_{\text{old}},\, \max(S_{ij}) \right)\) \(l_{\text{new}} = e^{(m_{\text{old}} - m_{\text{new}})} \cdot l_{\text{old}} + \sum e^{(S_{ij} - m_{\text{new}})}\) 数学技巧保证任何时刻都能分块、流式地计算本应全局的操作。³⁶
  • 反向重计算: 只需保存最终O与(l, m)统计,无需保存整个注意力矩阵;反向传播通过这些数据片上重算所需内容,极大缓解存储与流量。 这种设计表明,最大性能提升往往来自根本的算法重构——把运算流程对齐物理硬件。

Triton实现Flash Attention

Triton官方教程包含Flash Attention v2内核详解。⁵⁶ 其核心要素如下:

  • 网格与程序ID: 用2D/3D网格并行batch/head/seq
  • 外循环: 遍历Q分块
  • 内循环: 遍历K/V分块
  • 数据加载: 在内循环中将片加载至寄存器
  • 分数计算: tl.dot(q, k)
  • 在线Softmax: 用新最大值/累加和alpha等公式持续修正
  • 因果mask: 如需因果注意力,对未来token分数设负无穷
  • 最终写回: 内循环后统一写出
  • 自动调优:@triton.autotune装饰器,自动测试不同参数并缓存最佳配置,省去手动调优。

实际性能有重大提升,见下表:

表2:不同注意力实现性能对比(概念性)

以A100为例,表述不同实现的性能提升趋势:

实现方式序列长度延迟(ms)峰值显存(GB)
朴素PyTorch MHA512~1.5~0.5
朴素PyTorch MHA2048~20~8
朴素PyTorch MHA8192OOM>80 (OOM)
融合Triton MHA512~1.2~0.5
融合Triton MHA2048~15~8
融合Triton MHA8192OOM>80 (OOM)
Flash Attention(Triton)512~0.5<0.1
Flash Attention(Triton)2048~2.0<0.1
Flash Attention(Triton)8192~8.0<0.1
Flash Attention(Triton)16384~18.0<0.1
备注:OOM = 内存溢出(Out of Memory),数据为示意估值。   
可以看到,基本融合仅带小幅提升,并未解决O(N²)内存障碍;Flash Attention则真正实现数量级加速与突破长序列限制。   

端到端模型优化

单一算子级的优化很重要,但整体性能飞跃源于对模型整体端到端的融合与优化。例如把Transformer一整个block,或Diffusion U-Net的主干,进行更大粒度的融合,超越单算子内核。

融合完整Transformer Block

典型的Transformer block包括多头注意力、残差连接、LayerNorm、前馈网络(MLP/FFN)、再接LayerNorm。⁴⁷ 朴素实现每步各发射一次内核,产生大量内存流动与调度消耗。 极致目标是“Megakernel”,即将尽可能多的顺序操作融合到同一个持久内核里,⁵⁸ 这样几乎无需多次发射,且数据总停留在片上内存。 实践中可按功能拆为更易维护的少量高度融合内核:

  • Flash Attention内核: 作为高效MHA部分
  • 融合MLP/FFN内核: MLP一般为Linear -> GeLU -> Linear,可用单Triton内核实现前后乘、激活全在寄存器完成,极大节省HBM读写。⁵⁹
  • 融合LayerNorm: LayerNorm本质是归约+缩放,常和残差连接及前面操作一并融合,甚至如DeepSpeed的“可逆LayerNorm”允许输入激活丢弃,仅由输出和统计重算,进一步省内存。⁴⁴

下面是常见的Triton融合MLP内核简例:

@triton.jit
def fused_mlp_kernel(
    x_ptr, w1_ptr, w2_ptr, output_ptr,
    M, N, K,
    stride_xm, stride_xn,
    stride_w1n, stride_w1k,
    stride_w2k, stride_w2n,
    stride_om, stride_on,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # --- 融合Linear -> GeLU -> Linear ---
    pid = tl.program_id(axis=0)
    # --- 第一个线性变换 ---
    # 实质为标准的分块矩阵乘,得中间结果
    intermediate_block =... # 第一次GEMM结果
    # --- GeLU激活 ---
    # 在寄存器中对中间结果做GeLU近似
    activated_block = 0.5 _intermediate_block_ (1 + tl.tanh(0.79788456 _intermediate_block_ (1 + 0.044715 _intermediate_block_ intermediate_block)))
    # --- 第二个线性变换 ---
    output_block =... # 第二次GEMM结果
    tl.store(output_ptr + output_offsets, output_block)

优化Diffusion模型

扩散模型(Diffusion),在图像/视频生成领域盛行,普遍用U-Net主干实现迭代去噪。⁶⁰ U-Net结构的重复块尤其适合融合。 常见计算模式为卷积 -> 归一化 -> 激活,如1D卷积 + GroupNorm + Mish/SiLU激活。³⁹ 典型内存瓶颈,适合自定义Triton内核三合一,实现归约、归一化,激活全部寄存器级完成。 进一步,标准扩散推理过程中,timestep嵌入及大量噪声常数重复地每步重算。高阶优化是预先算好所有这些常数存GPU,去噪内核仅按需索引,显著减少冗余并减轻内核负担。⁶²

生态工具:torch.compile及自定义内核使用场合

今日优化工具已极为丰富,不再是高层框架与低级CUDA二选一。

  • torch.compile自动优化: PyTorch 2.0推出的JIT编译器,可自动捕获模型算子图并用TorchInductor后端融合与生成高效Triton内核。³⁷ 多数标准网络只需加 @torch.compile就能获得大幅自动融合与加速,免于手动写核。⁶⁵
  • 手写内核的必然: 自动编译器再强,以下情况必须自定义核:
  • 新颖结构与操作: 如设计新型注意力机制、自定义激活/归一化,自动编译器无法识别融合,只有手写核才能高效。
  • 最前沿算法: Flash Attention这类极度复杂与特殊的数据流和算法结构,编译器目前完全无法自动发现实现,需手工。
  • 极致生产性能: 服务型推理场景对毫秒级延迟和极致能效敏感,顶级手写内核往往能超编译器生成的10-20%。
  • 新硬件特性: 如NVIDIA Hopper的Tensor Memory Accelerator等新硬件,只有手写核能第一时间吃上红利。 这也决定了现代性能工程的分层策略:首选高阶工具,只有发现无法满足性能瓶颈时才逐级下探至Triton/CUDA。

表3:Transformer Block端到端优化战略

下表给出典型Transformer Block关键部位的优化分解与主流实现工具。

组件融合手段主要受益工具
Q、K、V投影融合nn.Linear3次GEMM合一、提升GPU利用率PyTorch (torch.nn.Linear(d_in, 3*d_out))
缩放点积注意力Flash Attention彻底摒弃N²内存,极大压缩HBM流量Triton/CUDA (如flash-attn库)
MLP/FFN融合Linear-GELU-Linear激活全片上执行,防止中间态反复进HBMTriton, torch.compile
LayerNorm融合前级op(如残差加法)精简HBM流量Triton, torch.compile, DeepSpeed
整个Transformer层Megakernel/持久化内核层内发射开销几乎为零手写CUDA/Triton, TensorRT-LLM

总结与未来趋势

从理解GPU硬件到底层Triton端到端融合内核,实现高性能的原则明确:算力巨大但受制于内存带宽,内核定制与尤其关注内存局部性与融合,是解决矛盾的主力工具。

ML实践者行动指南

对数据科学家与ML工程师来说,精通高性能计算的路径,不是逃离高层而拥抱纯CUDA,而是以分层、策略化方式高性价比进阶:

  • 1步:精通高层框架。 首先深入掌握PyTorch等框架结构会用Profile工具(如PyTorch Profiler、NVIDIA Nsight)定位瓶颈。
  • 2步:善用编译器。 第一优选torch.compile,能自动处理多数常见融合,一行代码换大提速。
  • 3步:学习Triton自定义融合。 Profile后发现自动编译遗漏或需自研新层,转向Triton,用其块级模型写特定热点核。
  • 4步:钻研SOTA算法。 对如关注力这类有高成熟SOTA的,不要闭门造车,多研究Flash Attention等现有实现,体会算法-硬件协同的要义。
  • 5步:谨慎用CUDA。 仅当需对接底层库、Triton不适于高度创新算法或极致生产优化时才直接写CUDA。

高性能AI的未来

AI性能工程快速进化中,数大发展趋势愈发显著:

  • 编译器霸主地位上升: 高层框架日益成为多后端强编译器的友好前端。²⁵ 如Triton与PyTorch深度集成的torch.compile趋势已极为明显。
  • 对新硬件特性的持续追逐: NVIDIA等不断迭代新GPU及定制单元(如FP8、TMA),新型内核开发需求不断。Triton和CUDA依然是探索先锋。
  • 抽象促普惠: 底层硬件和内核愈发复杂,但高层API/抽象接口让优化触手可及。例如FlashInfer等将高性能内核封装至易用接口,助力框架开发者无需精通CUDA也能用顶级实现。⁶⁸ 终极来看,掌握CUDA和Triton不仅是模型性能所需,更是通向并行计算、数据局部性、存储体系底层原理的通用能力。这些知识即便未来模型与框架更迭,依然具备持久价值。AI继续逼近算力极限,能在算法与硬件“桥接”者必将需求更旺盛。

参考文献

  1. Understanding CUDA in Computing: A Comprehensive Guide | Lenovo US, accessed July 17, 2025, https://www.lenovo.com/us/en/glossary/what-is-the-cuba-toolkit/

  2. What is the best way for beginners to learn CUDA and parallel computing with GPUs?, accessed July 17, 2025, https://www.quora.com/What-is-the-best-way-for-beginners-to-learn-CUDA-and-parallel-computing-with-GPUs

  3. What Is CUDA? - Supermicro, accessed July 17, 2025, https://www.supermicro.com/en/glossary/cuda

  4. Introduction to CUDA Programming - GeeksforGeeks, accessed July 17, 2025, https://www.geeksforgeeks.org/electronics-engineering/introduction-to-cuda-programming/

  5. CUDA Refresher: The CUDA Programming Model | NVIDIA Technical Blog, accessed July 17, 2025, https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/

  6. Simplifying CUDA kernels with Triton: A Pythonic Approach to GPU Programming, accessed July 17, 2025, https://arunjitha.medium.com/simplifying-cuda-kernels-with-triton-a-pythonic-approach-to-gpu-programming-79bb7121e974

  7. Flash attention(Fast and Memory-Efficient Exact Attention with IO-Awareness): A deep dive, accessed July 17, 2025, https://towardsdatascience.com/flash-attention-fast-and-memory-efficient-exact-attention-with-io-awareness-a-deep-dive-724af489997b/

  8. FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness - OpenReview, accessed July 17, 2025, https://openreview.net/pdf?id=H4DqfPSibmx

  9. 1 CUDA Programming Model, accessed July 17, 2025, https://www.eng.utah.edu/~cs5610/lectures/Programming_Models_for_GPU_Architecture%20CUDA.pdf

  10. en.wikipedia.org, accessed July 17, 2025, https://en.wikipedia.org/wiki/CUDA

  11. kst179/fused-attention: Fast and low-memory attention layer … - GitHub, accessed July 17, 2025, https://github.com/kst179/fused-attention

  12. CUDA Zone - Library of Resources - NVIDIA Developer, accessed July 17, 2025, https://developer.nvidia.com/cuda-zone

  13. What is the CUDA Programming Model? | GPU Glossary - Modal, accessed July 17, 2025, https://modal.com/gpu-glossary/device-software/cuda-programming-model

  14. Tutorial 01: Say Hello to CUDA, accessed July 17, 2025, https://cuda-tutorial.readthedocs.io/en/latest/tutorials/tutorial01/

  15. An Even Easier Introduction to CUDA (Updated) | NVIDIA Technical Blog, accessed July 17, 2025, https://developer.nvidia.com/blog/even-easier-introduction-cuda/

  16. CUDA Programming - Wolfram Language Documentation, accessed July 17, 2025, https://reference.wolfram.com/language/CUDALink/tutorial/Programming.html

  17. CUDA Programming Model — MolSSI GPU Programming Fundamentals documentation, accessed July 17, 2025, https://education.molssi.org/gpu_programming_beginner/03-cuda-program-model.html

  18. CUDA Basic Example - Vector Addition Explanation - eunomia, accessed July 17, 2025, https://eunomia.dev/others/cuda-tutorial/01-vector-addition/

  19. Vector Addition “Hello World!” Example with CUDA on Mac OSX …, accessed July 17, 2025, https://www.quantstart.com/articles/Vector-Addition-Hello-World-Example-with-CUDA-on-Mac-OSX/

  20. olcf-tutorials/vector_addition_cuda: A simple CUDA vector addition program - GitHub, accessed July 17, 2025, https://github.com/olcf-tutorials/vector_addition_cuda

  21. 4.4 Example: Vector Addition — Parallel Computing for Beginners, accessed July 17, 2025, https://www.learnpdc.org/PDCBeginners/4-cuda/4-VectorAdd.html

  22. Tutorial 02: CUDA in Actions, accessed July 17, 2025, https://cuda-tutorial.readthedocs.io/en/latest/tutorials/tutorial02/

  23. Introducing Triton: Open-Source GPU Programming for Neural Networks, accessed July 17, 2025, https://aimersociety.com/introducing-triton-open-source-gpu-programming-for-neural-networks/

  24. How Is OpenAI’s Triton Different From NVIDIA CUDA? - Analytics India Magazine, accessed July 17, 2025, https://analyticsindiamag.com/global-tech/how-is-openais-triton-different-from-nvidia-cuda/

  25. triton-lang.org, accessed July 17, 2025, https://triton-lang.org/#:~:text=Triton%20is%20a%20language%20and,throughput%20on%20modern%20GPU%20hardware.

  26. Democratizing AI Accelerators and GPU Kernel Programming using Triton, accessed July 17, 2025, https://next.redhat.com/2024/11/07/democratizing-ai-accelerators-and-gpu-kernel-programming-using-triton/

  27. Exploring Triton GPU programming for neural networks in Java - OpenJDK, accessed July 17, 2025, https://openjdk.org/projects/babylon/articles/triton

  28. Getting Started with Triton: A Step-by-Step Tutorial - Medium, accessed July 17, 2025, https://medium.com/ai-insights-cobet/getting-started-with-triton-a-step-by-step-tutorial-ddc18a186295

  29. Liger Kernel: Efficient Triton Kernels for LLM Training - arXiv, accessed July 17, 2025, https://arxiv.org/html/2410.10989v3

  30. [D] usefulness of learning CUDA/triton : r/MachineLearning - Reddit, accessed July 17, 2025, https://www.reddit.com/r/MachineLearning/comments/1kewrqc/d_usefulness_of_learning_cudatriton/

  31. Triton — GPU Programming for Neural Networks | by Dhananjay Kumar - Medium, accessed July 17, 2025, https://dhnanjay.medium.com/triton-gpu-programming-for-neural-networks-16271d729f78

  32. Introduction - Triton documentation, accessed July 17, 2025, https://triton-lang.org/main/programming-guide/chapter-1/introduction.html

  33. GPU MODE Lecture 14: Practitioners Guide to Triton - Christian Mills, accessed July 17, 2025, https://christianjmills.com/posts/cuda-mode-notes/lecture-014/

  34. triton_tutorial/02_vector_addition.ipynb at master - GitHub, accessed July 17, 2025, https://github.com/VikParuchuri/triton_tutorial/blob/master/02_vector_addition.ipynb

  35. Vector Addition — Triton documentation, accessed July 17, 2025, https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html

  36. Triton Vector Addition Kernel, part 1: Making the Shift to Parallel Programming - YouTube, accessed July 17, 2025, https://www.youtube.com/watch?v=MEZ7XhzTLEg

  37. lectures/lecture_014/A_Practitioners_Guide_to_Triton.ipynb at main - GitHub, accessed July 17, 2025, https://github.com/gpu-mode/lectures/blob/main/lecture_014/A_Practitioners_Guide_to_Triton.ipynb

  38. How Nvidia’s CUDA Monopoly In Machine Learning Is Breaking – OpenAI Triton And PyTorch 2.0 – SemiAnalysis, accessed July 17, 2025, https://semianalysis.com/2023/01/16/nvidiaopenaitritonpytorch/

  39. Flash Attention - Insu Jang, accessed July 17, 2025, https://insujang.github.io/2024-01-21/flash-attention/

  40. OpenAI’s Triton: An end to end example | by Michael Diggin | Medium, accessed July 17, 2025, https://medium.com/@michael.diggin/openais-triton-an-end-to-end-example-c6577d81e3d0

  41. Kernel Fusion - Steven Gong, accessed July 17, 2025, https://stevengong.co/notes/Kernel-Fusion

  42. Part VI - Kernel Fusion in CUDA - Vrushank Desai, accessed July 17, 2025, https://www.vrushankdes.ai/diffusion-policy-inference-optimization/part-vi—kernel-fusion-in-cuda

  43. 31. Kernel Fusion - Aussie AI, accessed July 17, 2025, https://www.aussieai.com/book/ch31-kernel-fusion

  44. Kernel Operator Fusion - Aussie AI, accessed July 17, 2025, https://www.aussieai.com/research/kernel-fusion

  45. Kernel Fusion: A Smart Way to Enhance Neural Networks Performance - abhik.xyz, accessed July 17, 2025, https://www.abhik.xyz/articles/kernel-fusion

  46. DeepSpeed Transformer Kernel - DeepSpeed, accessed July 17, 2025, https://www.deepspeed.ai/tutorials/transformer_kernel/

  47. FlashAttention: Implementing High-Performance Attention with CUDA and Triton - Medium, accessed July 17, 2025, https://medium.com/@kimdoil1211/flashattention-implementing-high-performance-attention-with-cuda-and-triton-9ee635ab1200

  48. Tutorial 6: Transformers and Multi-Head Attention — UvA DL Notebooks v1.2 documentation, accessed July 17, 2025, https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

  49. Parallelizing Multi-Head Attention on GPUs - Hemil Desai, accessed July 17, 2025, https://hd10.dev/posts/my-interests-2/cs259.pdf

  50. Multi-Head Attention From Scratch - Sanjaya’s Blog, accessed July 17, 2025, https://sanjayasubedi.com.np/deeplearning/multihead-attention-from-scratch/

  51. Introduction to Flash Attention: A Breakthrough in Efficient Attention Mechanism | by Sthanikam Santhosh | Medium, accessed July 17, 2025, https://medium.com/@sthanikamsanthosh1994/introduction-to-flash-attention-a-breakthrough-in-efficient-attention-mechanism-3eb47e8962c3

  52. Flash Attention: Revolutionizing Transformer Efficiency - Unite.AI, accessed July 17, 2025, https://www.unite.ai/flash-attention-revolutionizing-transformer-efficiency/

  53. LLMs-from-scratch/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb at main - GitHub, accessed July 17, 2025, https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb

  54. FlashAttention: Fast and Memory-Efficient Exact … - deepsense.ai, accessed July 17, 2025, https://deepsense.ai/wp-content/uploads/2023/04/2205.14135.pdf

  55. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. - MIT MLSys Discussion Group, accessed July 17, 2025, https://www.mlsys.ai/papers/flash_attention.html

  56. Understanding Flash Attention: Writing the Algorithm from Scratch in Triton, accessed July 17, 2025, https://towardsdatascience.com/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton-5609f0b143ea/

  57. FLASH ATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness: Paper Review | by Sulbha Jain | May, 2025 | Medium, accessed July 17, 2025, https://medium.com/@sulbha.jindal/flash-attention-fast-and-memory-efficient-exact-attention-with-io-awareness-paper-review-79639127c5de

  58. Fused Attention - Triton documentation, accessed July 17, 2025, https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html

  59. Tutorial #17: Transformers III Training - Research Blog | RBC Borealis, accessed July 17, 2025, https://rbcborealis.com/research-blogs/tutorial-17-transformers-iii-training/

  60. Compiling LLMs into a MegaKernel: A Path to Low-Latency …, accessed July 17, 2025, https://zhihaojia.medium.com/compiling-llms-into-a-megakernel-a-path-to-low-latency-inference-cf7840913c17

  61. flash-attention/training/README.md at main · Dao-AILab/flash …, accessed July 17, 2025, https://github.com/HazyResearch/flash-attention/blob/main/training/README.md

  62. Custom diffusion model with PyTorch — Tutorials for AI developers 4.0, accessed July 17, 2025, https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/notebooks/pretrain/ddim_pretrain.html

  63. Ultimate guide to optimizing Stable Diffusion XL - Félix Sanz, accessed July 17, 2025, https://www.felixsanz.dev/articles/ultimate-guide-to-optimizing-stable-diffusion-xl

  64. Part VII - A Dive Into DDPMs & CUDA kernel for Denoising - Vrushank Desai, accessed July 17, 2025, https://www.vrushankdes.ai/diffusion-policy-inference-optimization/part-vii—a-dive-into-ddpms-cuda-kernel-for-denoising

  65. Liger Kernel: Efficient Triton Kernels for LLM Training - arXiv, accessed July 17, 2025, https://arxiv.org/html/2410.10989v2

  66. Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile() — PyTorch Tutorials 2.7.0+cu126 documentation, accessed July 17, 2025, https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html

  67. Accelerating Generative AI Part III: Diffusion, Fast – PyTorch, accessed July 17, 2025, https://pytorch.org/blog/accelerating-generative-ai-3/

  68. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision - PyTorch, accessed July 17, 2025, https://pytorch.org/blog/flashattention-3/

  69. How do you optimize GPU utilization during diffusion model training? - Milvus, accessed July 17, 2025, https://milvus.io/ai-quick-reference/how-do-you-optimize-gpu-utilization-during-diffusion-model-training

  70. What is Flash Attention? | Modal Blog, accessed July 17, 2025, https://modal.com/blog/flash-attention-article

  71. Optimizing Transformer-Based Diffusion Models for Video Generation with NVIDIA TensorRT, accessed July 17, 2025, https://developer.nvidia.com/blog/optimizing-transformer-based-diffusion-models-for-video-generation-with-nvidia-tensorrt/

  72. Run High-Performance LLM Inference Kernels from NVIDIA Using …, accessed July 17, 2025, https://developer.nvidia.com/blog/run-high-performance-llm-inference-kernels-from-nvidia-using-flashinfer/