Preface

Last month I was in San Francisco for the PyTorch Conference 2024. The previous one I had attended was in New Orleans in 2022. It was held as a satellite event of NeurIPS, ChatGPT had been announced a few days earlier, and it was the first PyTorch event since PyTorch moved from Meta to the Linux Foundation.

A lot has happened in two years. The conference grew significantly and introduced a DL Compiler Mini-Summit, where we heard about Triton from Philippe Tillet himself. The next day, FlexAttention was released. It is a new API for writing custom attention mechanisms in PyTorch that lower to Triton and incur only a small performance penalty compared to handwritten Triton kernels.

This is exciting because it shortens the path from a research idea to an efficient implementation. But the main question is: what is Triton, and why should we care about it?

Triton

As Tillet puts it, before Triton the choice for GPU programming was between CUDA and graph compilers. CUDA gives you flexibility at the cost of simplicity, while graph compilers give you simplicity at the cost of flexibility. Triton aims to be a compiler with a high level of expressivity while keeping the code simpler than CUDA. It isn't a better CUDA but a simpler one, so the generated code is, at best, as fast as hand-tuned CUDA.

The compiler was introduced at MAPL 2019, and Triton 1.0 was released in 2021. Triton has three components:

  • Triton: originally a C-like language, now a Python-like one, for expressing tensor programs
  • Triton IR: an LLVM-based intermediate representation that represents programs as operations on tiles
  • Triton JIT: a just-in-time compiler and code-generation backend that optimizes operations on tiles and lowers them to LLVM IR/MLIR

In practice, the LLVM IR is then lowered to PTX. PTX is the low-level (virtual) instruction set from which executable binary code for NVIDIA GPUs is generated.

[Figure: from Triton to PTX]

Together, these components lower the barrier to entry for GPU programming. The Python-like syntax is accessible, and you keep good performance while gaining productivity, because you no longer manage the memory hierarchy and thread synchronization by hand. How does that work?

How does an NVIDIA GPU work?

So how does an NVIDIA GPU typically work? You can think of a GPU as a grid of many units, each of which is a streaming multiprocessor (SM). Each SM contains execution cores and manages the execution of threads. SMs operate independently, which enables parallelism across the GPU: many thread blocks can run concurrently on different SMs.

Thread blocks? A thread block is a group of threads that can share data through shared memory. Its threads can also synchronize with each other to coordinate memory accesses. A given SM can host multiple thread blocks concurrently. Within each thread block, threads are grouped into warps of 32. All threads in a warp execute the same instruction on different data in lockstep (Single Instruction, Multiple Threads, or SIMT). Each SM schedules its warps concurrently, switching between them to keep the GPU highly utilized.
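To make the warp grouping concrete, here is a small pure-Python sketch that splits the thread indices of a one-dimensional block into warps of 32 consecutive threads. It is not a CUDA API, and the helper name split_into_warps is hypothetical.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def split_into_warps(num_threads: int) -> list[list[int]]:
    """Group the thread indices of a 1D block into warps of WARP_SIZE consecutive threads."""
    return [
        list(range(start, min(start + WARP_SIZE, num_threads)))
        for start in range(0, num_threads, WARP_SIZE)
    ]


warps = split_into_warps(128)
print(len(warps))                 # 4
print(warps[1][0], warps[1][-1])  # 32 63

# A 100-thread block still occupies 4 warps; the last warp has only 4 active lanes.
partial = split_into_warps(100)
print(len(partial), len(partial[-1]))  # 4 4
```

This is why block sizes are usually chosen as multiples of 32. Otherwise, part of the last warp sits idle.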

[Figure: GPU architecture]

Execution cores? Those are the CUDA cores and the tensor cores. Broadly speaking, CUDA cores are general-purpose and optimized for single-precision floating-point operations, whilst tensor cores are designed to accelerate (mixed-precision) matrix operations.

An H100 (SXM5), for example, has 132 SMs, with a maximum of 64 concurrent warps per SM. In total, it has 16,896 FP32 CUDA cores and 528 tensor cores.
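Those totals follow from per-SM counts. As a quick sanity check, assume 128 FP32 CUDA cores and 4 tensor cores per SM. These are the Hopper per-SM figures and are not stated above. The arithmetic then works out as follows:

```python
NUM_SMS = 132            # H100 SXM5
FP32_CORES_PER_SM = 128  # assumption: Hopper SM configuration
TENSOR_CORES_PER_SM = 4  # assumption: Hopper SM configuration
MAX_WARPS_PER_SM = 64
WARP_SIZE = 32

fp32_cores = NUM_SMS * FP32_CORES_PER_SM
tensor_cores = NUM_SMS * TENSOR_CORES_PER_SM
max_threads_per_sm = MAX_WARPS_PER_SM * WARP_SIZE
max_resident_threads = NUM_SMS * max_threads_per_sm

print(f"{fp32_cores:,} FP32 cores")                  # 16,896 FP32 cores
print(f"{tensor_cores:,} tensor cores")              # 528 tensor cores
print(f"{max_threads_per_sm:,} threads per SM")      # 2,048 threads per SM
print(f"{max_resident_threads:,} resident threads")  # 270,336 resident threads
```

The last figure is an upper bound on threads resident at the same time, not on threads executing in the same cycle.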

[Figure: GH100 architecture]

From that picture, we can understand the memory hierarchy within a GPU.

  • Global memory (HBM3) is the main memory. It has high capacity and bandwidth but comparatively high latency.
  • L2 cache is shared across all SMs and reduces the latency of global memory accesses.
  • L1 cache is specific to each SM and, again, reduces memory access latency.
  • Shared memory is on-chip memory shared amongst the threads of the same thread block (on a given SM). It has high bandwidth and low latency.
  • Registers are private to each thread. They are the fastest memory.

Triton matters because it abstracts away this complexity and manages the memory hierarchy and thread synchronization for you. How?

Tiles

Tiles are the fundamental unit of computation in Triton and they map onto the architecture we just described.

A tile is a multi-dimensional array representing a block of data to be processed. A tile is mapped to a thread block, and each thread of the block handles a portion of the data. In other words, a tile is the data a single kernel instance (a Triton "program") operates on.
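Tiles are not limited to one dimension. The pure-Python sketch below mimics how a 2D Triton program builds a BLOCK_M x BLOCK_N tile of coordinates. It broadcasts a row range against a column range, the equivalent of rows[:, None] and cols[None, :], and masks the positions that fall outside an M x N matrix. The helper tile_coords is hypothetical, not a Triton API.

```python
def tile_coords(pid_m, pid_n, M, N, BLOCK_M, BLOCK_N):
    """Return the (row, col) coordinates of one tile and a mask of the in-bounds ones."""
    rows = [pid_m * BLOCK_M + i for i in range(BLOCK_M)]
    cols = [pid_n * BLOCK_N + j for j in range(BLOCK_N)]
    coords = [[(r, c) for c in cols] for r in rows]         # rows x cols broadcast
    mask = [[r < M and c < N for c in cols] for r in rows]  # guard the matrix edges
    return coords, mask


M, N, BLOCK_M, BLOCK_N = 5, 6, 2, 4
# Ceiling division gives one program per tile along each axis.
grid = ((M + BLOCK_M - 1) // BLOCK_M, (N + BLOCK_N - 1) // BLOCK_N)

covered = []
for pid_m in range(grid[0]):
    for pid_n in range(grid[1]):
        coords, mask = tile_coords(pid_m, pid_n, M, N, BLOCK_M, BLOCK_N)
        covered += [rc for row_c, row_m in zip(coords, mask)
                    for rc, m in zip(row_c, row_m) if m]

print(grid, len(covered))  # (3, 2) 30
```

Edge tiles keep their full size and simply mask out the missing elements, so every program runs the same code.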

Triton manages data movement between the different memories and tries to minimize accesses to global memory, which improves data locality and bandwidth utilization. It also handles the parallelization over tiles.

  • Locality: elements that are processed together are stored close to each other in memory. When data is accessed, nearby elements are more likely to already be in fast on-chip memory rather than in slower global memory.
  • Bandwidth: memory accesses are organized so that consecutive threads read or write consecutive memory locations (coalesced access). Data is loaded into shared memory once and then reused by multiple threads.
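Coalescing can be illustrated with a simplified model in which global memory is served in 32-byte sectors. Real hardware also involves 128-byte cache lines and caching, which this model ignores. The hypothetical helper sectors_touched counts how many sectors one warp touches when each of its 32 threads loads one float32.

```python
SECTOR_BYTES = 32  # simplified model: memory is fetched in 32-byte sectors
FLOAT_BYTES = 4    # size of a float32


def sectors_touched(element_indices, base_addr=0):
    """Count the distinct sectors hit when each thread loads one float32 at the given index."""
    return len({(base_addr + i * FLOAT_BYTES) // SECTOR_BYTES for i in element_indices})


coalesced = sectors_touched(range(32))                        # thread k reads element k
strided = sectors_touched([32 * lane for lane in range(32)])  # thread k reads element 32*k

print(coalesced, strided)  # 4 32
```

In this model, the coalesced warp needs 4 sector fetches for 128 useful bytes. The strided warp needs 32 fetches, and most of the bytes in each fetched sector go unused.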

In CUDA, you explicitly compute which elements each thread works on from blockIdx and threadIdx, and you manage shared memory yourself. In Triton, you only reason at the level of blocks (tl.program_id), and the mapping to individual threads is done automatically. By abstracting away these hardware details, Triton lets you focus on mathematical operations on tiles rather than on memory and thread management. You can then use element-wise operations (arithmetic, logical), linear algebra operations (matrix multiplication), reductions ((arg)max, reduce, etc.), broadcasting, and more.
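The difference in indexing can be sketched in plain Python. This is a model, not real CUDA or Triton code, and the function names are hypothetical. A CUDA thread computes one scalar index from blockIdx.x, blockDim.x and threadIdx.x. A Triton program computes a whole vector of offsets, like pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE), and leaves the split across threads to the compiler.

```python
def cuda_style_index(block_idx: int, block_dim: int, thread_idx: int) -> int:
    # One scalar per thread: blockIdx.x * blockDim.x + threadIdx.x
    return block_idx * block_dim + thread_idx


def triton_style_offsets(pid: int, block_size: int) -> list[int]:
    # One vector per program: pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    return [pid * block_size + k for k in range(block_size)]


BLOCK = 8
NUM_BLOCKS = 3

# CUDA: you write per-thread logic, which runs once for every (block, thread) pair.
cuda = [cuda_style_index(b, BLOCK, t) for b in range(NUM_BLOCKS) for t in range(BLOCK)]
# Triton: you write per-block logic; the compiler distributes each vector over threads.
triton_like = [o for pid in range(NUM_BLOCKS) for o in triton_style_offsets(pid, BLOCK)]

print(cuda == triton_like)  # True
```

Both models cover the same elements. The difference is who decides which thread handles which element.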

Simple example

Let's have a look at the simplest example from the documentation: vector addition.

It's fairly straightforward. The pointers point to GPU memory, and BLOCK_SIZE is a compile-time constant that sets how many elements each program (block) processes. pid identifies the current program in the launch grid.

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid, so axis is 0.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
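The kernel only describes one program, and the host decides how many programs to launch. The Triton tutorial launches it with grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),), followed by add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024). That launch needs a GPU, so here is a sequential, CPU-only Python model of the same logic instead. It shows how the grid size and the mask fit together, and its names (cdiv, add_program, add) are hypothetical stand-ins, not Triton's API.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same computation triton.cdiv performs.
    return (a + b - 1) // b


def add_program(x, y, output, n_elements, block_size, pid):
    """Model of one add_kernel program instance, run sequentially."""
    block_start = pid * block_size
    offsets = [block_start + k for k in range(block_size)]
    mask = [o < n_elements for o in offsets]
    for o, m in zip(offsets, mask):
        if m:  # masked-out lanes neither load nor store
            output[o] = x[o] + y[o]


def add(x, y, block_size=4):
    n_elements = len(x)
    output = [0.0] * n_elements
    grid = (cdiv(n_elements, block_size),)  # 1D grid: one program per block
    for pid in range(grid[0]):  # on a GPU these programs run in parallel
        add_program(x, y, output, n_elements, block_size, pid)
    return output


print(add([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0]))
# [11.0, 22.0, 33.0, 44.0, 55.0]
```

With 5 elements and a block size of 4, two programs are launched, and three lanes of the second one are masked out.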

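In the Triton IR, notice how the memory operations are abstracted: tt.load and tt.store operate on whole tensors of pointers, guarded by a tensor mask, with no notion of individual threads yet. We have constant definitions (BLOCK_SIZE has been specialized to 1024), offset calculation, masking, memory address calculation, data loading, and finally the actual computation and the storing of the results.

Values from the Triton code are represented as tensors (the offsets, for example, are a tensor<1024xi32>). Each Triton operation maps to one or a few IR instructions. For instance, block_start + tl.arange(0, BLOCK_SIZE) becomes a tt.make_range, a tt.splat of the scalar block_start, and an arith.addi.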

module {
  func public @add_kernel_0d1d2d3d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id {axis = 0 : i32} : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32>
    %13 = arith.addf %9, %12 : tensor<1024xf32>
    %14 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>>
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    tt.store %15, %13, %6 : tensor<1024xf32>
    return
  }
}
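To see what the first few IR instructions compute, here is a pure-Python model of %0 to %8 for a single program. It uses a block of 8 instead of 1024 for readability. The list-based helpers are hypothetical stand-ins for the tt and arith ops. Pointers are modeled as plain integers counted in elements, the same unit as tt.addptr offsets.

```python
def make_range(start, end):  # tt.make_range
    return list(range(start, end))


def splat(value, n):  # tt.splat: broadcast a scalar to an n-element tensor
    return [value] * n


def addi(a, b):  # arith.addi (also models tt.addptr with integer "pointers")
    return [p + q for p, q in zip(a, b)]


def cmpi_slt(a, b):  # arith.cmpi slt: element-wise signed less-than
    return [p < q for p, q in zip(a, b)]


BLOCK, pid, n_elements, x_ptr = 8, 1, 12, 1000  # hypothetical small values

r1 = pid * BLOCK                             # %1 = arith.muli %0, %c1024_i32
r2 = make_range(0, BLOCK)                    # %2
r4 = addi(splat(r1, BLOCK), r2)              # %3, %4: the offsets
r6 = cmpi_slt(r4, splat(n_elements, BLOCK))  # %5, %6: the mask
r8 = addi(splat(x_ptr, BLOCK), r4)           # %7, %8: one pointer per element

print(r4)  # [8, 9, 10, 11, 12, 13, 14, 15]
print(r6)  # [True, True, True, True, False, False, False, False]
```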

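In the LLVM IR, thread and block IDs finally appear. LLVM IR is nominally platform-independent, but this module is already NVIDIA-specific. It calls NVVM intrinsics to read the thread and block IDs (llvm.nvvm.read.ptx.sreg.tid.x and llvm.nvvm.read.ptx.sreg.ctaid.x), and it embeds inline PTX for the vectorized loads and stores.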

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"

; Function Attrs: nounwind
define void @add_kernel_0d1d2d3d(float addrspace(1)* %0, float addrspace(1)* %1, float addrspace(1)* %2, i32 %3) local_unnamed_addr #0 {
  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %6 = shl i32 %5, 2
  %7 = and i32 %6, 1020
  %8 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %9 = shl i32 %8, 10
  %10 = or i32 %7, %9
  %11 = or i32 %9, 512
  %12 = add i32 %11, %7
  %13 = icmp slt i32 %10, %3
  %14 = icmp slt i32 %12, %3
  %15 = sext i32 %10 to i64
  %16 = getelementptr float, float addrspace(1)* %0, i64 %15
  %17 = sext i32 %12 to i64
  %18 = getelementptr float, float addrspace(1)* %0, i64 %17
  %19 = tail call { i32, i32, i32, i32 } asm sideeffect "@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(float addrspace(1)* %16, i1 %13) #0
  %20 = extractvalue { i32, i32, i32, i32 } %19, 0
  %21 = bitcast i32 %20 to <1 x float>
  %22 = extractvalue { i32, i32, i32, i32 } %19, 1
  %23 = bitcast i32 %22 to <1 x float>
  %24 = extractvalue { i32, i32, i32, i32 } %19, 2
  %25 = bitcast i32 %24 to <1 x float>
  %26 = extractvalue { i32, i32, i32, i32 } %19, 3
  %27 = bitcast i32 %26 to <1 x float>
  %28 = tail call { i32, i32, i32, i32 } asm sideeffect "@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(float addrspace(1)* %18, i1 %14) #0
  %29 = extractvalue { i32, i32, i32, i32 } %28, 0
  %30 = bitcast i32 %29 to <1 x float>
  %31 = extractvalue { i32, i32, i32, i32 } %28, 1
  %32 = bitcast i32 %31 to <1 x float>
  %33 = extractvalue { i32, i32, i32, i32 } %28, 2
  %34 = bitcast i32 %33 to <1 x float>
  %35 = extractvalue { i32, i32, i32, i32 } %28, 3
  %36 = bitcast i32 %35 to <1 x float>
  %37 = getelementptr float, float addrspace(1)* %1, i64 %15
  %38 = getelementptr float, float addrspace(1)* %1, i64 %17
  %39 = tail call { i32, i32, i32, i32 } asm sideeffect "@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(float addrspace(1)* %37, i1 %13) #0
  %40 = extractvalue { i32, i32, i32, i32 } %39, 0
  %41 = bitcast i32 %40 to <1 x float>
  %42 = extractvalue { i32, i32, i32, i32 } %39, 1
  %43 = bitcast i32 %42 to <1 x float>
  %44 = extractvalue { i32, i32, i32, i32 } %39, 2
  %45 = bitcast i32 %44 to <1 x float>
  %46 = extractvalue { i32, i32, i32, i32 } %39, 3
  %47 = bitcast i32 %46 to <1 x float>
  %48 = tail call { i32, i32, i32, i32 } asm sideeffect "@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(float addrspace(1)* %38, i1 %14) #0
  %49 = extractvalue { i32, i32, i32, i32 } %48, 0
  %50 = bitcast i32 %49 to <1 x float>
  %51 = extractvalue { i32, i32, i32, i32 } %48, 1
  %52 = bitcast i32 %51 to <1 x float>
  %53 = extractvalue { i32, i32, i32, i32 } %48, 2
  %54 = bitcast i32 %53 to <1 x float>
  %55 = extractvalue { i32, i32, i32, i32 } %48, 3
  %56 = bitcast i32 %55 to <1 x float>
  %57 = fadd <1 x float> %21, %41
  %58 = fadd <1 x float> %23, %43
  %59 = fadd <1 x float> %25, %45
  %60 = fadd <1 x float> %27, %47
  %61 = fadd <1 x float> %30, %50
  %62 = fadd <1 x float> %32, %52
  %63 = fadd <1 x float> %34, %54
  %64 = fadd <1 x float> %36, %56
  %65 = getelementptr float, float addrspace(1)* %2, i64 %15
  %66 = getelementptr float, float addrspace(1)* %2, i64 %17
  %bc = bitcast <1 x float> %57 to <1 x i32>
  %67 = extractelement <1 x i32> %bc, i64 0
  %bc1 = bitcast <1 x float> %58 to <1 x i32>
  %68 = extractelement <1 x i32> %bc1, i64 0
  %bc2 = bitcast <1 x float> %59 to <1 x i32>
  %69 = extractelement <1 x i32> %bc2, i64 0
  %bc3 = bitcast <1 x float> %60 to <1 x i32>
  %70 = extractelement <1 x i32> %bc3, i64 0
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %67, i32 %68, i32 %69, i32 %70, float addrspace(1)* %65, i1 %13) #0
  %bc4 = bitcast <1 x float> %61 to <1 x i32>
  %71 = extractelement <1 x i32> %bc4, i64 0
  %bc5 = bitcast <1 x float> %62 to <1 x i32>
  %72 = extractelement <1 x i32> %bc5, i64 0
  %bc6 = bitcast <1 x float> %63 to <1 x i32>
  %73 = extractelement <1 x i32> %bc6, i64 0
  %bc7 = bitcast <1 x float> %64 to <1 x i32>
  %74 = extractelement <1 x i32> %bc7, i64 0
  tail call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %71, i32 %72, i32 %73, i32 %74, float addrspace(1)* %66, i1 %14) #0
  ret void
}

; Function Attrs: nofree nosync nounwind readnone speculatable
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1

; Function Attrs: nofree nosync nounwind readnone speculatable
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1

attributes #0 = { nounwind }
attributes #1 = { nofree nosync nounwind readnone speculatable }

!nvvm.annotations = !{!0, !1, !0}

!0 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel_0d1d2d3d, !"kernel", i32 1}
!1 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel_0d1d2d3d, !"maxntidx", i32 128}
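The index arithmetic at the top of the LLVM IR shows how the compiler split the 1024-element block among 128 threads (the maxntidx annotation). Each thread issues two 4-wide vector loads per input: one at 4 * tid and one 512 elements further on. The Python sketch below replays %6 to %12 and checks that one block covers every offset exactly once. The helper name thread_offsets is hypothetical.

```python
NUM_THREADS = 128  # !"maxntidx", i32 128
BLOCK_SIZE = 1024  # ctaid.x is shifted left by 10, i.e. multiplied by 1024
VEC = 4            # ld.global.v4.b32 moves four 32-bit values


def thread_offsets(ctaid: int, tid: int) -> list[int]:
    """Element offsets loaded by one thread, replaying %6 to %12 of the LLVM IR."""
    r7 = (tid << 2) & 1020  # %6, %7: 4 * tid
    r9 = ctaid << 10        # %9: ctaid * 1024
    r10 = r7 | r9           # %10: first vector (| acts as + because the bits don't overlap)
    r12 = (r9 | 512) + r7   # %11, %12: second vector, 512 elements later
    return [r10 + k for k in range(VEC)] + [r12 + k for k in range(VEC)]


block0 = sorted(o for tid in range(NUM_THREADS) for o in thread_offsets(0, tid))
print(block0 == list(range(BLOCK_SIZE)))  # True
print(thread_offsets(0, 1))               # [4, 5, 6, 7, 516, 517, 518, 519]
```

Consecutive threads therefore read consecutive 16-byte chunks, which is the coalesced access pattern described in the Tiles section.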

And finally, the PTX code. At this point, the code targets specific hardware (here Ampere, sm_86). Registers are declared and the parameters are loaded into them, thread indices are computed, and memory addresses are calculated. You can see the data movement and the computation.

//
// Generated by LLVM NVPTX Back-End
//

.version 7.5
.target sm_86
.address_size 64

	// .globl	add_kernel_0d1d2d3d

.visible .entry add_kernel_0d1d2d3d(
	.param .u64 add_kernel_0d1d2d3d_param_0,
	.param .u64 add_kernel_0d1d2d3d_param_1,
	.param .u64 add_kernel_0d1d2d3d_param_2,
	.param .u32 add_kernel_0d1d2d3d_param_3
)
.maxntid 128, 1, 1
{
	.reg .pred 	%p<7>;
	.reg .b32 	%r<33>;
	.reg .f32 	%f<25>;
	.reg .b64 	%rd<12>;

	ld.param.u64 	%rd7, [add_kernel_0d1d2d3d_param_0];
	ld.param.u64 	%rd8, [add_kernel_0d1d2d3d_param_1];
	mov.u32 	%r25, %tid.x;
	shl.b32 	%r26, %r25, 2;
	ld.param.u64 	%rd9, [add_kernel_0d1d2d3d_param_2];
	and.b32  	%r27, %r26, 1020;
	ld.param.u32 	%r28, [add_kernel_0d1d2d3d_param_3];
	mov.u32 	%r29, %ctaid.x;
	shl.b32 	%r30, %r29, 10;
	or.b32  	%r31, %r27, %r30;
	add.s32 	%r32, %r31, 512;
	setp.lt.s32 	%p1, %r31, %r28;
	setp.lt.s32 	%p2, %r32, %r28;
	mul.wide.s32 	%rd10, %r31, 4;
	add.s64 	%rd1, %rd7, %rd10;
	mul.wide.s32 	%rd11, %r32, 4;
	add.s64 	%rd2, %rd7, %rd11;
	@%p1 ld.global.v4.b32 { %r1, %r2, %r3, %r4 }, [ %rd1 + 0 ];
	mov.b32 	%f1, %r1;
	mov.b32 	%f2, %r2;
	mov.b32 	%f3, %r3;
	mov.b32 	%f4, %r4;
	@%p2 ld.global.v4.b32 { %r5, %r6, %r7, %r8 }, [ %rd2 + 0 ];
	mov.b32 	%f5, %r5;
	mov.b32 	%f6, %r6;
	mov.b32 	%f7, %r7;
	mov.b32 	%f8, %r8;
	add.s64 	%rd3, %rd8, %rd10;
	add.s64 	%rd4, %rd8, %rd11;
	@%p1 ld.global.v4.b32 { %r9, %r10, %r11, %r12 }, [ %rd3 + 0 ];
	mov.b32 	%f9, %r9;
	mov.b32 	%f10, %r10;
	mov.b32 	%f11, %r11;
	mov.b32 	%f12, %r12;
	@%p2 ld.global.v4.b32 { %r13, %r14, %r15, %r16 }, [ %rd4 + 0 ];
	mov.b32 	%f13, %r13;
	mov.b32 	%f14, %r14;
	mov.b32 	%f15, %r15;
	mov.b32 	%f16, %r16;
	add.f32 	%f17, %f1, %f9;
	add.f32 	%f18, %f2, %f10;
	add.f32 	%f19, %f3, %f11;
	add.f32 	%f20, %f4, %f12;
	add.f32 	%f21, %f5, %f13;
	add.f32 	%f22, %f6, %f14;
	add.f32 	%f23, %f7, %f15;
	add.f32 	%f24, %f8, %f16;
	add.s64 	%rd5, %rd9, %rd10;
	add.s64 	%rd6, %rd9, %rd11;
	mov.b32 	%r17, %f17;
	mov.b32 	%r18, %f18;
	mov.b32 	%r19, %f19;
	mov.b32 	%r20, %f20;
	@%p1 st.global.v4.b32 [ %rd5 + 0 ], { %r17, %r18, %r19, %r20 };
	mov.b32 	%r21, %f21;
	mov.b32 	%r22, %f22;
	mov.b32 	%r23, %f23;
	mov.b32 	%r24, %f24;
	@%p2 st.global.v4.b32 [ %rd6 + 0 ], { %r21, %r22, %r23, %r24 };
	ret;

}
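Two details of the PTX are worth spelling out.

First, mul.wide.s32 ..., 4 turns an element index into a byte offset, so every v4 access starts at a multiple of 16 bytes from the base pointer. Combined with the tt.divisibility = 16 hint on the pointer arguments, the 16-byte vector accesses should stay aligned. We read that hint as asserting that the addresses are multiples of 16.

Second, each vector is guarded by a single predicate (%p1 or %p2) that only checks its first element. Our reading is that this is safe because n_elements was also marked as divisible by 16. An index that is a multiple of 4 and lies below a multiple of 16 always has its next three elements in bounds too.

The pure-Python check below tests both points under those assumptions. Its helper names are hypothetical.

```python
FLOAT_BYTES = 4  # mul.wide.s32 %rd10, %r31, 4
VEC = 4          # ld.global.v4.b32 / st.global.v4.b32


def vector_starts(ctaid: int, num_threads: int = 128) -> list[int]:
    """First element index of every v4 access in one block (%r31 and %r32 in the PTX)."""
    starts = []
    for tid in range(num_threads):
        r31 = ((tid << 2) & 1020) | (ctaid << 10)
        r32 = r31 + 512  # add.s32 %r32, %r31, 512
        starts += [r31, r32]
    return starts


def predicate_is_exact(n_elements: int, num_blocks: int = 2) -> bool:
    """True if one 'start < n' predicate per vector agrees with a per-element mask."""
    for c in range(num_blocks):
        for s in vector_starts(c):
            lanes = [s + k < n_elements for k in range(VEC)]
            if any(lanes) and not all(lanes):
                return False  # this vector would be only partially in bounds
    return True


# 1) Byte offsets of all vector accesses are multiples of 16.
aligned = all((s * FLOAT_BYTES) % 16 == 0 for c in range(4) for s in vector_starts(c))
print(aligned)  # True

# 2) With n_elements divisible by 16, the single predicate is exact; otherwise it may not be.
print(all(predicate_is_exact(n) for n in range(16, 2048, 16)))  # True
print(predicate_is_exact(1030))                                 # False
```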

Conclusion

Triton is great because it bridges the gap between high-performance code and (relative) ease of use. You don't need to be a CUDA expert to write good-enough code, since many of the details you are likely to get wrong are abstracted away. The key is that it lets researchers focus on research rather than implementation. Realistically, it won't replace CUDA, but it makes life easier.