Writing a tiled CUDA matmul
This year during my time at the Recurse Center, I worked through the various optimizations presented in Simon Boehm’s iconic post How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog. I approached these as a series of puzzles where I would read as little as possible, just the title or the first paragraphs describing the algorithm and then implement it in CUDA/C++. This was an exercise in writing & debugging CUDA along with implementing kernel code from a high level algorithm.
Here by matmul I actually mean a generalized matrix multiplication or gemm. Basically, a matrix multiply with some accoutrements thrown in.
Here's our typical function signature:
__global__ void sgemm(int M, int N, int K, float alpha, const float *A,
const float *B, float beta, float *C)
The shapes are A as (M × K), B as (K × N), and C as (M × N); the kernel computes C = alpha * A * B + beta * C.
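For reference, a plain CPU version of the same operation is handy for checking each kernel's output; a minimal sketch (sgemm_cpu is my name for it, row-major storage assumed throughout):
void sgemm_cpu(int M, int N, int K, float alpha, const float *A,
               const float *B, float beta, float *C) {
  // C[m][n] = alpha * sum_k A[m][k] * B[k][n] + beta * C[m][n]
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}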
Naive Kernel
Here we simply launch a thread for each output element in C, and each thread independently computes the dot product of its respective row from A and column from B.
Our launcher creates a grid of blocks, where each block is a group of 1024 threads referenced by a (32, 32) matrix. I won’t go into the details of this too much but section 5 of NVIDIA’s CUDA docs provides an explanation of the programming model.
dim3 gridDim(CEIL_DIV(M, 32), CEIL_DIV(N, 32));
dim3 blockDim(32, 32);
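For completeness, CEIL_DIV is just round-up integer division, and the kernel launch itself is the usual triple-chevron call; a minimal sketch (dA, dB, dC being device pointers that have already been allocated and filled):
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

// one thread per element of C
sgemm_naive<<<gridDim, blockDim>>>(M, N, K, alpha, dA, dB, beta, dC);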
As you may expect, this approach is not very fast. There are many repeated reads from global memory (where all of A, B, and C live) and we don’t utilize the GPU’s memory bandwidth nor its strengths very well (more about this later).
__global__ void sgemm_naive(int M, int N, int K, float alpha, const float *A,
const float *B, float beta, float *C) {
// blockDim.x and blockDim.y are both 32 (see the launch configuration above)
const uint x = blockIdx.x * blockDim.x + threadIdx.x;
const uint y = blockIdx.y * blockDim.y + threadIdx.y;
if (x < M && y < N) {
float tmp = 0.0;
for (int i = 0; i < K; ++i) {
tmp += A[x * K + i] * B[i * N + y];
}
C[x * N + y] = alpha * tmp + beta * C[x * N + y];
}
}
This kernel’s performance is 553.3 GFLOPS
Global Memory Coalescing
With the naive kernel, we are not making good use of global memory (HBM) bandwidth: each read from global memory by a thread is issued separately. How can we tweak our mapping of threads to memory so that consecutive threads in a warp access consecutive memory locations, letting the GPU coalesce (combine) their global memory accesses into fewer transactions? The compiler doesn't issue any special SASS/PTX for this, but the performance boost is significant. Yes, it's subtle.
// block size is 32
const uint x = blockIdx.y * 32 + threadIdx.y;
const uint y = blockIdx.x * 32 + threadIdx.x;
This kernel’s performance is 4042.3 GFLOPS
Shared Memory Blocking
If we inspect the above figure showing global memory coalescing, two adjacent threads (by threadIdx) will each issue instructions to read from the same memory locations of A. Why fetch the same elements from global GPU memory multiple times?
We have access to faster on-chip memory called Shared Memory (SMEM). Each Streaming Multiprocessor (SM) has shared memory, and it's divided up between the blocks of threads resident on that SM, so threads in the same block can access the same section of shared memory. The amount of shared memory per SM varies with the compute capability; on our RTX 4090 it is 100KB per SM.
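If you want to check these numbers on your own GPU, the CUDA runtime reports them; a quick host-side sketch (not from the original post):
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0); // device 0
  std::printf("shared memory per SM: %zu KB\n", prop.sharedMemPerMultiprocessor / 1024);
  std::printf("max shared memory per block (opt-in): %zu KB\n", prop.sharedMemPerBlockOptin / 1024);
  return 0;
}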
The shared memory blocking technique uses shared memory as a scratch space for the thread block. We map a thread block to a section of indices in matrix C. Computing this result section requires dot products over chunks of A and B, which the entire threadblock works together to load into shared memory. We then slide the cached chunks along the K dimension and repeat, accumulating as we go.

My annotations are in brown, base image from Kernel 3 in siboehm’s matmul post
Map cRow, cCol from the y, x locations of the thread within the launch grid. We still launch one thread for each output element in C. However, within the confines of our threadblock we can directly map into our shared memory section with row, col = threadIdx.y, threadIdx.x.
This requires some thinking about the matrices and the indices of the elements that interest us (denoted with the brown annotated text in the figure). It boils down to keeping track of the shape of the matrix we are indexing into at each step.
const uint cCol = blockIdx.x * blockDim.x + threadIdx.x;
const uint cRow = blockIdx.y * blockDim.y + threadIdx.y;
// The row and column within the shared memory chunk
const uint shmem_col = threadIdx.x;
const uint shmem_row = threadIdx.y; // this is because the shmem size == block size
__shared__ float sA[BLOCK_SIDE * BLOCK_SIDE];
__shared__ float sB[BLOCK_SIDE * BLOCK_SIDE];
for (int offset = 0; offset < K; offset += 32)
{
// This should be a coalesced load because the rightmost value shmem_col & cCol both
// reduce to threadIdx.x
sA[shmem_row * BLOCK_SIDE + shmem_col] = A[cRow * K + (offset + shmem_col)]; // the same row as the result element,
// with the column offset by (offset + shmem_col), i.e. (offset + threadIdx.x)
sB[shmem_row * BLOCK_SIDE + shmem_col] = B[(offset + shmem_row) * N + cCol]; // the same column as the result element,
// with the row offset by (offset + shmem_row), i.e. (offset + threadIdx.y)
...
// all the threads in the block compute the resulting dot products
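The elided compute step follows the same pattern as the naive kernel, just reading from the shared memory tiles instead of global memory. A sketch of how the rest of the loop body and the write-back might look, assuming a per-thread accumulator float tmp = 0.0f declared before the loop:
    __syncthreads(); // wait until the whole block has filled sA and sB

    // accumulate this thread's slice of the dot product from the shared tiles
    for (int i = 0; i < BLOCK_SIDE; i++)
    {
        tmp += sA[shmem_row * BLOCK_SIDE + i] * sB[i * BLOCK_SIDE + shmem_col];
    }

    __syncthreads(); // don't overwrite sA/sB while other threads are still reading
}
C[cRow * N + cCol] = alpha * tmp + beta * C[cRow * N + cCol];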
This kernel’s performance is 7402 GFLOPS
1D Shared Memory Blocking
Now let's make each thread do a bit more work. Right now each thread is responsible for computing just a single result in C. Let's give it some more responsibility: each thread will now compute a number of consecutive elements of the result, a 1D strip of elements so to speak.

My annotations are on the right; the left side is from Kernel 4 of siboehm's matmul post
Here I find it useful to visualize the problem as separate steps, where we have a number of threads at our disposal. Each thread does not need to work solely on computing its own result; i.e., the thread responsible for computing specific elements in C may not be directly responsible for loading the contributing elements of A and B for those elements.

The threads in a block are what we have to work with. Here 512 threads compute 8 results each in a 64x64 section of C (4096 results).
We have shared memory and __syncthreads() synchronization points, so we can share this work among different threads in a way that is convenient to us. For the elements local to a thread (the 8 elements a thread is responsible for in this example) we use registers, not shared memory. This is simply defined as float tmp[8] = {0.0};
Mapping our threads to elements in A:
const uint a_inner_col = threadIdx.x % BK;
const uint a_inner_row = threadIdx.x / BK;
__shared__ float sA[BM * BK];
...
// moving the pointer to A to its position for each threadblock is MUCH easier
// than fiddling with indices trying to keep things consistent within the offset loop.
A += blockIdx.y * BM * K; // blockIdx.y is the row of the CUDA threadblock in the launch grid (BM == BN here)
...
...
// within a loop incrementing our offset BK
for (int offset = 0; offset < K; offset += BK)
{
// row a_inner_row of the block's tile of A, with the column offset by (offset + a_inner_col)
// (the A pointer advances by BK each iteration in the elided code)
sA[a_inner_row * BK + a_inner_col] = A[a_inner_row * K + a_inner_col];
Mapping our threads to elements in B:
const uint b_inner_col = threadIdx.x % BN; // 0..63
const uint b_inner_row = threadIdx.x / BN; // 0..7
__shared__ float sB[BK * BN];
...
B += blockIdx.x * BN; // blockIdx.x is the column of the threadblock in the launch grid
...
...
// within a loop incrementing our offset BK
for (int offset = 0; offset < K; offset += BK)
{
// column b_inner_col of the block's tile of B, with the row offset by (offset + b_inner_row)
// (the B pointer advances by BK * N each iteration in the elided code)
sB[b_inner_row * BN + b_inner_col] = B[b_inner_row * N + b_inner_col];
Mapping our threads to elements in C:
const uint c_row = threadIdx.x / BN; // 0..7, since each thread handles a strip of 8 consecutive rows
const uint c_col = threadIdx.x % BN; // 0..63, the column within the block's tile of C
...
...
// within a loop incrementing our offset BK
for (int offset = 0; offset < K; offset += BK)
{
...
...
for (int t = 0; t < 8; t++)
{
for (int idx = 0; idx < BK; idx++)
{
tmp[t] += sA[((c_row * 8) + t) * BK + idx] * sB[idx * BN + c_col];
}
}
__syncthreads();
}
for (int t = 0; t < 8; t++)
{
C[(c_row * 8 + t) * N + c_col] = alpha * tmp[t] + beta * C[(c_row * 8 + t) * N + c_col];
}
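For the numbers used here (a 64x64 tile of C per block and 8 results per thread, so 512 threads per block), a plausible launch configuration would look something like the following (sgemm_1d_blocktile is a placeholder name for the kernel above):
dim3 gridDim(CEIL_DIV(N, 64), CEIL_DIV(M, 64)); // blockIdx.x walks the columns, blockIdx.y the rows
dim3 blockDim(512);                             // a 1D block; threadIdx.x is decomposed manually
sgemm_1d_blocktile<<<gridDim, blockDim>>>(M, N, K, alpha, dA, dB, beta, dC);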
This kernel’s performance is 21808 GFLOPS
2D Shared Memory Blocking
Now that we have each thread computing a strip of consecutive values, let's increase the parallelism per thread and have each thread compute a small 2D grid of values, giving the threads even more work.
Outer loop (similar to the 1D case). Inner loop: each of the 3 squares in C is the territory of a single thread (drawn as 4x4 for simplicity).
To start, let's reason about how we would handle the loading and computation if we had specific tile sizes BM=64, BN=64, BK=8 and number of elements a thread computes TM=8, TN=8 (thus 8 * 8 = 64 elements per thread). This BM * BN section of C will be computed by a thread block.

Just as before, the 64 threads (launched in a 1D block) are mapped to elements of A like so:
const uint load_A_row = threadIdx.x / BK; // BK = 8
const uint load_A_col = threadIdx.x % BK;
const uint strideA = numThreadsBlocktile / BK;
// The total number of threads must be evenly divisible by BK
// so we can skip strideA complete rows while loading a tile.
assert(numThreadsBlocktile % BK == 0);
This, in a loop that runs 8 times and shifts down strideA = 8 rows (64 elements) each pass, covers all BM * BK elements of our slice of shared memory.
const uint load_B_col = threadIdx.x % BN; // BN = 64
const uint load_B_row = threadIdx.x / BN;
const uint strideB = numThreadsBlocktile / BN;
assert(numThreadsBlocktile % BN == 0);
We increment the row 8 times (strideB = 1 here), and since consecutive threads read consecutive columns of B, each pass should become a coalesced read from memory.
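To make the strides concrete: numThreadsBlocktile isn't defined in the snippets above, but assuming it's derived from the tile sizes in the obvious way, the arithmetic for our example sizes works out to:
// with BM = BN = 64, BK = 8, TM = TN = 8
const uint numThreadsBlocktile = (BM / TM) * (BN / TN); // 8 * 8 = 64 threads per block
const uint strideA = numThreadsBlocktile / BK;          // 64 / 8  = 8 rows of sA loaded per pass
const uint strideB = numThreadsBlocktile / BN;          // 64 / 64 = 1 row of sB loaded per pass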
The interesting sections of the code are below. After loading, we accumulate into the tmp registers.
// The threads position within a grid of (BM/TM) x (BN/TN)
const uint thread_col = threadIdx.x % (BN / TN);
const uint thread_row = threadIdx.x / (BN / TN);
A += blockIdx.y * BM * K;
B += blockIdx.x * BN;
C += (blockIdx.y * BM * N) + (blockIdx.x * BN);
for (int offset = 0; offset < K; offset += BK)
{
for (int shiftA = 0; shiftA < BM; shiftA += strideA)
{
sA[((load_A_row + shiftA) * BK) + load_A_col] = A[((load_A_row + shiftA) * K) + load_A_col];
}
for (int shiftB = 0; shiftB < BK; shiftB += strideB)
{
sB[(load_B_row + shiftB) * BN + load_B_col] = B[(load_B_row + shiftB) * N + load_B_col];
}
__syncthreads();
A += BK;
B += BK * N;
/*
Note the use of register caches (sharedA_cache, sharedB_cache)
to reduce shared memory accesses.
For example, each element of sB is read just once per idx and reused
across all TM rows, instead of being re-read for every (TM, TN) product.
*/
for (int idx = 0; idx < BK; ++idx)
{
for (int r = 0; r < TM; ++r)
{
sharedA_cache[r] = sA[(thread_row * TM + r) * BK + idx];
}
for (int c = 0; c < TN; ++c)
{
sharedB_cache[c] = sB[(idx * BN) + (thread_col * TN + c)];
}
for (int r = 0; r < TM; ++r)
{
for (int c = 0; c < TN; ++c)
{
tmp[r * TN + c] +=
sharedA_cache[r] * sharedB_cache[c];
}
}
}
__syncthreads();
}
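The epilogue that writes tmp back to C isn't shown above; presumably it mirrors the 1D kernel's write-back, just over the thread's TM x TN patch:
for (int r = 0; r < TM; ++r)
{
    for (int c = 0; c < TN; ++c)
    {
        const uint row = thread_row * TM + r;
        const uint col = thread_col * TN + c;
        C[row * N + col] = alpha * tmp[r * TN + c] + beta * C[row * N + col];
    }
}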
This kernel’s performance is 37000 GFLOPS
Vectorized 2D Shared Memory Blocking
The next optimization involves vectorization. Which loads can we vectorize? Loading from B into shared memory should be easy to vectorize, as we're already reading consecutive elements with consecutive threads. However, with A our setup is a bit more complicated.
The “trick” here is that we transpose A upon its load into shared memory.

We point our threads at A very similarly to before (with the exception that each thread now loads a vector of 4 floats), but we send the elements to transposed locations in our shared memory for A, known as sA.

The main differences in the kernel are with the loading from A.
for (int shiftA = 0; shiftA < BM; shiftA += strideArows)
{
float4 tmp = reinterpret_cast<float4*>(&A[((load_A_row + shiftA) * K) + load_A_col * 4])[0];
sA[(load_A_col * 4 + 0) * BM + load_A_row] = tmp.x;
sA[(load_A_col * 4 + 1) * BM + load_A_row] = tmp.y;
sA[(load_A_col * 4 + 2) * BM + load_A_row] = tmp.z;
sA[(load_A_col * 4 + 3) * BM + load_A_row] = tmp.w;
}
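Loading B into shared memory vectorizes even more directly, since consecutive threads already read consecutive floats; a sketch of how that might look (strideBrows, load_B_row, load_B_col are my names for the B-side analogues of the variables above):
for (int shiftB = 0; shiftB < BK; shiftB += strideBrows)
{
    // copy 4 consecutive floats of B straight into sB; no transpose needed
    const float4 tmpB = reinterpret_cast<const float4*>(&B[(load_B_row + shiftB) * N + load_B_col * 4])[0];
    reinterpret_cast<float4*>(&sB[(load_B_row + shiftB) * BN + load_B_col * 4])[0] = tmpB;
}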
We don't make explicit changes to the loading from shared memory when we are computing products, as the compiler will issue lds.128 vectorized loads from shared memory in SASS.
This kernel’s performance is 39500 GFLOPS
Warptiling
Warptiling is our final optimization. While it uses many of the patterns employed in the kernels we've built up until now, the kernel is not much more complicated to write. To be honest, I spent the most time reading into why warptiling improves performance. The reasons are subtle, and this is a good foundational kernel from which we can dip into using the fast Tensor Cores. When we add a level of hierarchy that aligns with the division of threads into warps, we get some performance benefits. In CUDA, threads are executed together in groups of 32 consecutive threads called warps; a threadblock of 128 threads would consist of 4 warps. The benefits come from better register access patterns and from the instruction-level parallelism the tiling exposes.
Back to the gist of this post: how do we implement it? The main difference is that instead of spreading the threads out across the block, we tile the warps across the block. Thus the location of, say, threadIdx.x = 10 is different from where it was in the 2D tiling, where the threads were just distributed across the whole block.

Note the threads (dots). The purple are warp 1 (threadIdx 0 to 31). The pink are warp 2.
Once we have filled our shared memory blocks, we load into registers at the warp level. We load sections of shared memory at a time into thread-local registers we call fragments. These load instructions involve overlap between what different threads access from shared memory, and should execute as SIMT.

These are some of the new variables we define in our kernel to perform the matmul.
/*
TM, TN are the number of result elements each thread computes (a TM x TN patch).
thread_tile_rows * thread_tile_cols = 32 is our selected layout of threads within the warp.
*/
const uint warp_tile_cols = BN / (TN * thread_tile_cols);
const uint warp_tile_rows = BM / (TM * thread_tile_rows);
// number of warp tiles along each dimension of the block tile
assert(BN % (TN * thread_tile_cols) == 0);
assert(BM % (TM * thread_tile_rows) == 0);
// Warp location among the warp tiles
const uint warp_row = warp_id / warp_tile_cols;
const uint warp_col = warp_id % warp_tile_cols;
// thread's x, y relative to warp
const uint thread_tile_row = thread_lane / thread_tile_cols;
const uint thread_tile_col = thread_lane % thread_tile_cols;
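For context, warp_id and thread_lane come from the 1D thread index, and the inner compute loop offsets every shared memory access first by the warp's tile and then by the thread's position within the warp. A rough sketch under the same assumptions as the previous kernels (sA stored transposed; a_frag and b_frag are my names for the per-thread register fragments):
const uint warp_id = threadIdx.x / 32;     // which warp this thread belongs to
const uint thread_lane = threadIdx.x % 32; // lane within the warp

// ... inside the loop over BK, after the shared memory tiles are loaded:
for (int idx = 0; idx < BK; ++idx)
{
    // load this thread's fragments: its TM rows and TN columns for this k-slice
    for (int r = 0; r < TM; ++r)
        a_frag[r] = sA[idx * BM + warp_row * (TM * thread_tile_rows) + thread_tile_row * TM + r];
    for (int c = 0; c < TN; ++c)
        b_frag[c] = sB[idx * BN + warp_col * (TN * thread_tile_cols) + thread_tile_col * TN + c];
    // outer product of the fragments accumulates into the TM x TN register tile
    for (int r = 0; r < TM; ++r)
        for (int c = 0; c < TN; ++c)
            tmp[r * TN + c] += a_frag[r] * b_frag[c];
}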
This article in the CUTLASS documentation is a good resource to read more about warptiling.
This kernel’s performance is 42000 GFLOPS
Reflections
I feel like this was a much slower way of learning to write CUDA matrix multiplications than simply reading the worklog, abounding with frustrations and debugging. However, there was valuable insight in playing with these algorithms, drawing them out (thanks excalidraw.com), and good practice in converting high-level algorithms into working parallel code.
Next up are faster, more modern matrix multiplications. On certain shapes, and by leveraging Tensor Cores, it should be possible to eke out a win over specific cuBLAS implementations.