Faster Reductions using Coalesced Reads #12
Hi @THargreaves, thanks a lot for the time and energy put into this - it is very much a needed discussion. I went through your comments and code (very nice!) - there are a few additional constraints we have in AcceleratedKernels when it comes to shared memory tiling:

Again, thank you for going this deep into the codebase and prototyping this shared memory tiling approach. How would you see this work with the above constraints? Maybe we could start some work in KernelAbstractions for querying the shared memory size and perhaps the warp size.
I wanted to share a quick demo of a way to make reductions over the columns of matrices substantially faster by ensuring reads from global memory are coalesced. This is very much a prototype where I've made some assumptions about the input matrix dimensions so I didn't have to worry about masking certain threads, but the idea would hold for general dimensions and be just as fast.
I plan to come back and formalise this approach when I have time, but I wanted to throw it out early to get feedback and start some discussion (as this sort of approach applies more generally to all kernel designs).
The core idea is that for simple reductions, the fundamental bottleneck is memory bandwidth, not compute.
On my RTX 4090, for example, the memory bandwidth is 1008 GB/s and the peak compute throughput is 82.6 TFLOPS. Focusing on the square D x D matrix case, a reduction must read and write D(D + 1) * 4 bytes of Float32. Assuming `op` requires F FLOPs per application, the total amount of compute is roughly F * D(D - 1) FLOPs under perfect parallelism. For fairly large D, the two implied theoretical bounds only cross when F ~ 328. Most reductions are going to be far from this point, and so the kernels are entirely memory bound. For this reason, it is critical that all reads/writes involving global memory are coalesced so we can obtain the maximum throughput.
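As a sanity check on that crossover figure, here is the back-of-the-envelope arithmetic in Julia, plugging in the RTX 4090 numbers above (the value of D is arbitrary; the result barely depends on it once it is large):

```julia
mem_bw = 1008e9                  # RTX 4090 memory bandwidth, bytes/s
peak   = 82.6e12                 # RTX 4090 peak compute, FLOP/s
D      = 10_000                  # large square matrix

bytes = 4 * D * (D + 1)          # read D^2 Float32s, write D Float32s
ops   = D * (D - 1)              # applications of `op` under perfect parallelism

# FLOPs per `op` at which the compute-time bound meets the memory-time bound
F_cross = (peak / mem_bw) * bytes / ops   # ≈ 328
```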
Going forwards I will focus on the case where `_reduce_nd_by_thread` is used (one thread per output element). When reducing rows, coalescing occurs naturally and so near theoretically-optimal performance is achieved. For column reductions, reads are scattered, leading to a 3–4x slowdown (I'm actually surprised it isn't more!).
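To make the scattered-read pattern concrete, here is a minimal KernelAbstractions sketch of the one-thread-per-output pattern, specialised to a Float32 column sum (my own illustration, not the actual `_reduce_nd_by_thread` code):

```julia
using KernelAbstractions

# One thread per output element (here: one thread per column sum). At each
# step of the loop, the 32 threads of a warp read A[i, j], A[i, j+1], ...,
# which sit a full column apart in column-major storage, so the global
# reads are not coalesced.
@kernel function naive_colsum!(out, @Const(A))
    j = @index(Global, Linear)       # this thread reduces column j
    acc = 0.0f0
    for i in 1:size(A, 1)
        @inbounds acc += A[i, j]
    end
    @inbounds out[j] = acc
end
```

It would be launched as, e.g., `naive_colsum!(backend, 256)(out, A; ndrange = size(A, 2))`.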
It's possible to get around this by using a bit of shared memory to temporarily store tiles of the input matrix that are read in a coalesced fashion. Below I provide a kernel `column_reduction_kernel!` which implements this in a naive way. By having 32 x 32 tiles, each associated with one warp, synchronisation is not required.

Admittedly, this naive approach puts a fair amount of pressure on shared memory, which is why the block size is 128 threads (else it wouldn't fit!). This pressure can be relieved substantially by replacing the 32 x 32 tile with a 32 x K tile and having 32 / K threads work on each column. This is obviously not optimal in terms of compute parallelism, but that doesn't matter since the compute cost is negligible for all but the most expensive ops.
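The actual `column_reduction_kernel!` and MRE aren't reproduced here, but the tiling idea looks roughly like the sketch below (again my own illustration: Float32 sums, one 32 x 32 tile per 32-thread workgroup, dimensions assumed divisible by 32, and portable `@synchronize` calls where a warp-per-tile version could omit them):

```julia
using KernelAbstractions

# Each 32-thread workgroup reduces 32 adjacent columns, walking down the
# matrix in 32 x 32 tiles staged in shared memory so that every global read
# is a contiguous 32-element chunk of a column (i.e. coalesced).
@kernel function tiled_colsum!(out, @Const(A))
    lane  = @index(Local, Linear)        # 1..32 within the workgroup
    group = @index(Group, Linear)        # which block of 32 columns
    tile  = @localmem Float32 (32, 32)

    col0 = (group - 1) * 32
    acc  = 0.0f0
    for row0 in 0:32:(size(A, 1) - 1)
        # Coalesced load: for each k, the 32 lanes read 32 consecutive
        # elements of column col0 + k.
        for k in 1:32
            @inbounds tile[lane, k] = A[row0 + lane, col0 + k]
        end
        @synchronize
        # Each lane reduces its own tile column out of shared memory
        # (bank-conflicted in this naive layout; padding to (33, 32) helps).
        for k in 1:32
            @inbounds acc += tile[k, lane]
        end
        @synchronize
    end
    @inbounds out[col0 + lane] = acc
end
```

Launched with a workgroup size of 32 and `ndrange = size(A, 2)`, e.g. `tiled_colsum!(backend, 32)(out, A; ndrange = size(A, 2))`.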
Even with this restriction of block size and shared memory pressure, the kernel achieves the same optimal speed as row reduction.
MRE and output below.
For multidimensional arrays (say 5 dims for the examples), the same ideas hold:

- `dims = (3, 4, 5)` is naturally coalesced
- `dims = (1, 2)` can use this same approach
- `dims = (3, 4)` can be handled by an adaptation of the current approach

The only thing I'm not sure how to handle is `dims = (2, 4)`. I think this might require two kernel calls, one reducing each dimension (the array-level equivalent is sketched at the end of this post).

Eager to hear any thoughts on this topic!
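For what it's worth, here is the two-pass idea for `dims = (2, 4)` written at the array level with plain Base arrays, purely to pin down the intended semantics (sizes are arbitrary). Each pass reduces a single dimension while leaving dimension 1 intact, so each pass on its own should be able to use coalesced reads:

```julia
A = rand(Float32, 6, 7, 8, 9, 10)   # arbitrary 5-D example

# Pass 1: reduce dim 4 (dim 1 survives, so reads can stay coalesced)
partial = sum(A; dims = 4)          # size (6, 7, 8, 1, 10)

# Pass 2: reduce dim 2 of the intermediate (dim 1 survives again)
result = sum(partial; dims = 2)     # size (6, 1, 8, 1, 10)

result ≈ sum(A; dims = (2, 4))      # matches the single fused reduction
```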