Faster Reductions using Coalesced Reads #12

Open
THargreaves opened this issue Dec 18, 2024 · 1 comment

@THargreaves

I wanted to share a quick demo of a way to make reductions over the columns of matrices substantially faster by ensuring reads from global memory are coalesced. This is very much a prototype where I've made some assumptions about the input matrix dimensions so I didn't have to worry about masking certain threads, but the idea would hold for general dimensions and be just as fast.

I plan to come back and formalise this approach when I have time, but I wanted to throw it out early to get feedback and start some discussion (as this sort of approach applies more generally to all kernel designs).

The core idea is that for simple reductions, the fundamental bottleneck is memory bandwidth, not compute.

On my RTX 4090, for example, the memory bandwidth is 1008 GB/s and the peak FP32 compute throughput is 82.6 TFLOPS. Focusing on the square D x D matrix case, a reduction must read D² Float32 elements and write D, moving D(D+1) * 4 bytes in total. If op costs F FLOPs per application, the total compute is roughly F * D(D - 1) FLOPs (D outputs, each needing D - 1 applications), assuming perfect parallelism. For fairly large D, the two time bounds, 4D(D+1) / 1008 GB/s and F * D(D - 1) / 82.6 TFLOPS, only cross when F ≈ 4 * 82.6e12 / 1008e9 ≈ 328. Most reductions are far from this point, so these kernels are entirely memory bound.
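
The arithmetic above can be reproduced in a few lines of plain Julia (no GPU needed). This is only a sketch: the two hardware figures are the RTX 4090 numbers quoted above, and `t_mem`, `t_compute` and `crossover_F` are illustrative names rather than anything in AcceleratedKernels.

```julia
# Back-of-the-envelope roofline check for reducing a D x D Float32 matrix.
# Hardware numbers are the RTX 4090 figures quoted above.
mem_bw = 1008e9       # memory bandwidth, bytes per second
peak_flops = 82.6e12  # peak FP32 throughput, FLOP per second

# Memory time: read D^2 elements and write D, 4 bytes each
t_mem(D) = 4 * D * (D + 1) / mem_bw

# Compute time: D outputs, each needing D - 1 applications of an F-FLOP op
t_compute(D, F) = F * D * (D - 1) / peak_flops

# For large D, (D + 1) / (D - 1) -> 1, so the bounds cross at this per-op cost
crossover_F = 4 * peak_flops / mem_bw

D = 2^14
println("crossover F ≈ ", round(crossover_F, digits=1))
println("t_mem / t_compute for a 1-FLOP op: ", round(t_mem(D) / t_compute(D, 1), digits=1))
```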

For this reason, it is critical that all writes/reads involving global memory are coalesced so we can obtain the maximum throughput.

Going forwards I will focus on the case where _reduce_nd_by_thread is used (one thread per output element).

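When reducing rows (dims=2, one output per row), adjacent threads read adjacent elements of the column-major input, so reads coalesce naturally and near theoretically-optimal performance is achieved. For column reductions (dims=1), adjacent threads read elements a whole column apart, so reads are scattered, leading to a 3–4x slowdown (I'm actually surprised it isn't more!).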
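
A rough way to see the difference is to count the memory sectors one warp touches in a single load. The sketch below assumes a 32-lane warp, 32-byte sectors and a sector-aligned Float32 matrix, and it ignores caching entirely; in practice, consecutive steps of the one-thread-per-column kernel read the next row of the same columns, which sits in sectors already fetched, and that reuse may be part of why the measured slowdown is only 3–4x. `linoffset` and `sectors` are hypothetical helpers.

```julia
# Count the 32-byte sectors one warp touches in a single load step.
# Assumptions: column-major Float32 matrix, 32-lane warp, 32-byte sectors,
# matrix base address aligned to a sector. Caching is ignored.
elsize = sizeof(Float32)
sector_bytes = 32
warp = 32

# 0-based linear offset of element (row, col) in a column-major matrix with H rows
linoffset(row, col, H) = row + col * H

# Number of distinct sectors covered by a collection of element offsets
sectors(offsets) = length(unique(o * elsize ÷ sector_bytes for o in offsets))

H = 2^8  # number of rows in A_wide

# Existing kernel with dims=1: lane t owns column t, so at step 0 it reads (0, t)
naive = [linoffset(0, t, H) for t in 0:warp-1]

# Tiled kernel: at step 0, lane t reads row t of the warp's first column
tiled = [linoffset(t, 0, H) for t in 0:warp-1]

println("one thread per column: ", sectors(naive), " sectors per load")
println("tiled load:            ", sectors(tiled), " sectors per load")
```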

It's possible to get around this by using a bit of shared memory to temporarily store tiles of the input matrix that are read in a coalesced fashion. Below I provide a kernel column_reduction_kernel! which implements this in a naive way. By having 32 x 32 tiles, each associated with one warp, no block-level synchronisation is required (this relies on the 32 lanes of a warp executing in lock-step).
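
To make the load/transpose pattern concrete, here is a CPU emulation of what a single warp does in column_reduction_kernel!, with lanes as loop indices and an ordinary matrix standing in for the shared-memory tile. `warp_column_reduce` is a hypothetical helper, and the `nrows` masking is an assumption added here to show how heights that are not a multiple of 32 could be handled; the prototype kernel itself does not mask rows.

```julia
# CPU emulation of one warp of column_reduction_kernel!. Lanes are plain loop
# indices and `buf` stands in for the warp's 32 x 32 shared-memory tile.
# Row masking via `nrows` is an addition for heights not divisible by 32.
function warp_column_reduce(src::AbstractMatrix{T}, first_col, op, init; tile=32) where {T}
    H = size(src, 1)
    @assert first_col + tile <= size(src, 2) "columns are assumed to be in range"
    buf = Matrix{T}(undef, tile, tile)
    acc = fill(init, tile)  # one accumulator per lane (= per output column)
    for tile_start in 0:tile:H-1
        nrows = min(tile, H - tile_start)
        # Load: at step i, lanes read consecutive rows of column first_col + i (coalesced)
        for i in 0:tile-1, lane in 0:nrows-1
            buf[lane+1, i+1] = src[tile_start+lane+1, first_col+i+1]
        end
        # Reduce: lane l walks down column l of the tile (the transposed access)
        for lane in 0:tile-1, i in 0:nrows-1
            acc[lane+1] = op(acc[lane+1], buf[i+1, lane+1])
        end
    end
    return acc
end

A = rand(Float32, 100, 64)             # 100 rows: deliberately not a multiple of 32
r = warp_column_reduce(A, 32, +, 0f0)  # the warp covering columns 33:64
println(r ≈ vec(sum(A[:, 33:64], dims=1)))
```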

Admittedly, this naive approach puts a fair amount of pressure on shared memory: each 32 x 32 Float32 tile takes 4 KiB. This is why the block size is 128 threads (16 KiB of tiles) rather than 1024 (32 warps * 4 KiB = 128 KiB, which wouldn't fit!). This can be relieved substantially by replacing the 32 x 32 tile with a 32 x K tile and having 32 / K threads work on each column. This is obviously not optimal in terms of compute parallelism, but that doesn't matter since the compute cost is negligible for all but the most expensive ops.
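
The shared memory budget can be tabulated directly. This sketch assumes 32-lane warps and Float32 elements; `tile_smem_bytes` is a hypothetical helper, and `k` is the K in the 32 x K tile described above. For reference, CUDA limits statically allocated shared memory to 48 KiB per block, so the 128 KiB needed by a 1024-thread block with full 32 x 32 tiles is well out of reach, while K = 8 brings it down to 32 KiB.

```julia
# Shared memory one block needs when each warp owns a tile_rows x k tile, i.e.
# k output columns per warp with warp ÷ k lanes cooperating on each column.
# Assumptions: 32-lane warps and Float32 elements unless stated otherwise.
function tile_smem_bytes(::Type{T}, block_size; warp=32, tile_rows=32, k=32) where {T}
    @assert warp % k == 0 "k must divide the warp size"
    nwarps = cld(block_size, warp)
    return nwarps * tile_rows * k * sizeof(T)
end

for (bs, k) in [(128, 32), (1024, 32), (1024, 8), (1024, 4)]
    kib = tile_smem_bytes(Float32, bs; k=k) ÷ 1024
    println("block_size = $bs, k = $k: $kib KiB of tiles")
end
```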

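Even with this restriction on block size and the shared memory pressure, the kernel matches, and in fact slightly beats, the row reduction: a median of 309 μs versus 343 μs, compared with 1.08 ms for the existing column reduction.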
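
Converting the median times from the benchmark output (1.079 ms, 342.958 μs and 308.687 μs) into effective bandwidth shows how close each kernel gets to the 1008 GB/s quoted above. This counts only compulsory traffic (read the input once, write the output once) and ignores caching and launch overhead, so treat the numbers as approximate; `effective_gbps` is an illustrative name.

```julia
# Effective bandwidth implied by the median times in the benchmark output.
# Counts only compulsory traffic: read the 2^8 x 2^18 Float32 input once and
# write the 2^18 outputs once (caches and launch overhead ignored).
bytes_moved = (2^8 * 2^18 + 2^18) * sizeof(Float32)
peak_bw = 1008e9  # RTX 4090 figure quoted above, bytes/s

effective_gbps(t_seconds) = bytes_moved / t_seconds / 1e9

medians = [
    ("existing, dims=1 (wide)", 1.079e-3),
    ("existing, dims=2 (tall)", 342.958e-6),
    ("coalesced, dims=1 (wide)", 308.687e-6),
]

for (name, t) in medians
    gbps = effective_gbps(t)
    pct = 100 * gbps * 1e9 / peak_bw
    println(rpad(name, 26), round(gbps, digits=1), " GB/s (", round(pct, digits=1), "% of peak)")
end
```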

MRE and output below.

using AcceleratedKernels
using KernelAbstractions
using BenchmarkTools
using CUDA

using AcceleratedKernels: i16
using KernelAbstractions: synchronize

T = Float32
A_wide = CUDA.rand(T, 2^8, 2^18);
A_tall = CUDA.rand(T, 2^18, 2^8);
op = Base.add_sum
init = zero(T)
dst_wide = similar(A_wide, 1, size(A_wide, 2));
dst_tall = similar(A_tall, size(A_tall, 1), 1);

io = IOContext(stdout)
backend = get_backend(A_wide)
block_size = 1024

@kernel inbounds = true cpu = false function existing_kernel!(@Const(src), dst, op, init, dims)

    # One thread per output element, when there are more outer elements than in the reduced dim
    # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim
    src_sizes = size(src)
    src_strides = strides(src)
    dst_sizes = size(dst)
    dst_strides = strides(dst)

    output_size = length(dst)
    reduce_size = src_sizes[dims]

    ndims = length(src_sizes)

    N = @groupsize()[1]

    # NOTE: for many index calculations in this library, computation using zero-indexing leads to
    # fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
    # indexing). Internal calculations will be done using zero indexing except when actually
    # accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.

    # Group (block) and local (thread) indices
    iblock = @index(Group, Linear) - 0x1
    ithread = @index(Local, Linear) - 0x1

    # Each thread handles one output element
    tid = ithread + iblock * N
    if tid < output_size

        # # Sometimes slightly faster method using additional memory with
        # # output_idx = @private typeof(iblock) (ndims,)
        # tmp = tid
        # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
        #     output_idx[i] = tmp ÷ dst_strides[i]
        #     tmp = tmp % dst_strides[i]
        # end
        # # Compute the base index in src (excluding the reduced axis)
        # input_base_idx = 0
        # KernelAbstractions.Extras.@unroll for i in 1:ndims
        #     i == dims && continue
        #     input_base_idx += output_idx[i] * src_strides[i]
        # end

        # Compute the base index in src (excluding the reduced axis)
        input_base_idx = typeof(ithread)(0)
        tmp = tid
        KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
            if i != dims
                input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
            end
            tmp = tmp % dst_strides[i]
        end

        # Go over each element in the reduced dimension; this implementation assumes that there
        # are so many outer elements (each processed by an independent thread) that we afford to
        # loop sequentially over the reduced dimension (e.g. reduce(+, rand(3, 1000), dims=1))
        res = init
        for i in 0x0:reduce_size-0x1
            src_idx = input_base_idx + i * src_strides[dims]
            res = op(res, src[src_idx+0x1])
        end
        dst[tid+0x1] = res
    end
end

println("\n#### Existing wide reduction ####\n")
blocks = (size(dst_wide, 2) + block_size - 1) ÷ block_size
kernel! = existing_kernel!(backend, block_size)
# Validate
kernel!(A_wide, dst_wide, op, init, 1, ndrange=(blocks * block_size,))
synchronize(backend)
println("Valid: ", dst_wide ≈ sum(A_wide, dims=1))
res = @benchmark begin
    kernel!(A_wide, dst_wide, op, init, 1, ndrange=(1024 * block_size))
    synchronize(backend)
end
show(io, "text/plain", res)

println("\n#### Existing tall reduction ####\n")
blocks = (size(dst_tall, 1) + block_size - 1) ÷ block_size
kernel!(A_tall, dst_tall, op, init, 2, ndrange=(blocks * block_size,))
synchronize(backend)
println("Valid: ", dst_tall ≈ sum(A_tall, dims=2))
res = @benchmark begin
    kernel!(A_tall, dst_tall, op, init, 2, ndrange=(1024 * block_size))
    synchronize(backend)
end
show(io, "text/plain", res)

@kernel inbounds = true cpu = false function column_reduction_kernel!(@Const(src), dst, op, init)
    # Fixed parameters — chosen to meet shared memory constraints
    # TODO: generalise by having multiple threads compute each column
    BLOCK_SIZE = 128
    TILE_DIM = 32
    NUM_WARPS = 4

    # One 32x32 tile of shared memory per warp
    tiles = @localmem eltype(src) (TILE_DIM, TILE_DIM, NUM_WARPS)

    warp_id = (@index(Local, Linear) - 0x1) ÷ TILE_DIM     # Which warp (0-3)
    lane_id = (@index(Local, Linear) - 0x1) % TILE_DIM     # Position in warp (0-31)
    global_col = @index(Group, Linear) - 0x1               # Which block of columns

    src_height = size(src, 1)
    src_width = size(src, 2)

    # Each thread is responsible for one column in the output
    global_col_idx = global_col * BLOCK_SIZE + (@index(Local, Linear) - 0x1)

    if global_col_idx < src_width
        result = init

        # Process input in TILE_DIM x TILE_DIM tiles
        for tile_start = 0:TILE_DIM:src_height-1
            # Load tile into shared memory with coalesced reads (one column at a time)
            KernelAbstractions.Extras.@unroll for i = 0:TILE_DIM-1
                row_idx = tile_start + lane_id
                col_idx = global_col * BLOCK_SIZE + warp_id * TILE_DIM + i
                if row_idx < src_height
                    tiles[lane_id+1, i+1, warp_id+1] = src[row_idx+1, col_idx+1]
                end
            end
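            # No @synchronize needed: each warp only touches its own tile. Lanes within
            # a warp do read each other's writes, which assumes lock-step warp execution.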

            # Reduce along rows for this thread's column using shared memory
            KernelAbstractions.Extras.@unroll for i = 0:TILE_DIM-1
                result = op(result, tiles[i+1, lane_id+1, warp_id+1])
            end
        end

        # Write result to global memory (naturally coalesced)
        dst[1, global_col_idx+1] = result
    end
end

println("\n#### Coalesced column reduction ####\n")
col_block_size = 128
blocks = (size(A_wide, 2) + col_block_size - 1) ÷ col_block_size
kernel! = column_reduction_kernel!(backend, col_block_size)
# Validate
kernel!(A_wide, dst_wide, op, init, ndrange=(blocks * col_block_size,))
synchronize(backend)
println("Valid: ", dst_wide ≈ sum(A_wide, dims=1))
res = @benchmark begin
    kernel!(A_wide, dst_wide, op, init, ndrange=(blocks * col_block_size))
    synchronize(backend)
end
show(io, "text/plain", res)
#### Existing wide reduction ####

Valid: true
BenchmarkTools.Trial: 4623 samples with 1 evaluation.
 Range (min … max):  1.069 ms … 1.435 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.079 ms             ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.080 ms ± 7.359 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

              ▁▂▄▅▆▇██▇▅▆▅▂▂▁                                
  ▂▂▂▂▃▃▄▄▅▆▆█████████████████▇▆▇▆▅▅▄▅▄▃▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂ ▄
  1.07 ms        Histogram: frequency by time        1.1 ms <

 Memory estimate: 1.62 KiB, allocs estimate: 58.
#### Existing tall reduction ####

Valid: true
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  335.978 μs … 419.454 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     342.958 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   342.409 μs ±   2.824 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▂▁           ▁ ▁   ▁▁▄▄▆▆▅▇█▆▅▃▂                 
  ▁▁▁▂▂▃▄▄▅▄▅▆▇████▇▆▆▅▅▅▅▅▇██████▇█████████████▇▅▃▄▂▂▂▂▂▂▂▁▁▁▁ ▄
  336 μs           Histogram: frequency by time          348 μs <

 Memory estimate: 1.62 KiB, allocs estimate: 58.
#### Coalesced column reduction ####

Valid: true
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  306.203 μs … 369.769 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     308.687 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   309.001 μs ±   1.745 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▁▂▅▇██▇▇▅▅▂▁  ▁▂▁▂▃▂▁             ▁ ▂ ▁                  
  ▁▁▂▄▅▆██████████████████████▇▇▆▅▄▅▄▅▇▆████████▇▆▅▄▃▃▂▂▂▁▂▁▁▁▁ ▅
  306 μs           Histogram: frequency by time          313 μs <

 Memory estimate: 1.52 KiB, allocs estimate: 57.

For multidimensional arrays (say 5 dims for the examples), the same ideas hold:

  • Reducing over "trailing" dimensions, e.g. dims = (3, 4, 5) is naturally coalesced
  • Reducing over "leading" dimensions, e.g. dims = (1, 2) can use this same approach
  • Reducing over "interior" dimensions, e.g. dims = (3, 4) can be handled by an adaptation of the current approach.
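
The leading, trailing and interior cases all come down to 2-D or 3-D reductions, because adjacent dimensions of a contiguous array can be merged with `reshape`, which does not copy data. The sketch below checks the index algebra on the CPU with Base `sum` on a 5-D array; on the GPU the reshaped arrays would instead go to the row- or column-reduction kernels.

```julia
# Merge adjacent dimensions with reshape (no copy) to turn 5-D reductions over
# leading, trailing or interior dims into 2-D or 3-D ones. CPU check with Base sum.
A = rand(Float32, 4, 5, 6, 7, 8)
s = size(A)

# Leading dims (1, 2): column reduction of an (s1*s2) x (s3*s4*s5) matrix
lead = reshape(sum(reshape(A, s[1] * s[2], :), dims=1), 1, 1, s[3], s[4], s[5])

# Trailing dims (3, 4, 5): row reduction of the same matrix
trail = reshape(sum(reshape(A, s[1] * s[2], :), dims=2), s[1], s[2], 1, 1, 1)

# Interior dims (3, 4): reduce the middle axis of an (s1*s2) x (s3*s4) x s5 array
A3 = reshape(A, s[1] * s[2], s[3] * s[4], s[5])
inner = reshape(sum(A3, dims=2), s[1], s[2], 1, 1, s[5])

println(lead ≈ sum(A, dims=(1, 2)))
println(trail ≈ sum(A, dims=(3, 4, 5)))
println(inner ≈ sum(A, dims=(3, 4)))
```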

The only thing I'm not sure how to handle is dims = (2, 4). I think this might require two kernel calls, one reducing each dimension.
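
The two-call idea for dims = (2, 4) can be checked the same way. This sketch reduces dimension 4 first and then dimension 2, using Base `sum` as a stand-in for the two kernel calls. Splitting the reduction changes the order in which elements are combined, so it assumes op is associative and commutative, which holds for `+` up to floating-point rounding. It also needs an intermediate array, here 4 x 5 x 6 x 1 x 8.

```julia
# Reduce the non-adjacent dims (2, 4) in two passes, one dimension per "kernel call".
# Assumes op is associative and commutative (true for + up to rounding).
A = rand(Float32, 4, 5, 6, 7, 8)

tmp = sum(A, dims=4)         # pass 1: interior reduction over dim 4 -> 4 x 5 x 6 x 1 x 8
two_pass = sum(tmp, dims=2)  # pass 2: interior reduction over dim 2 -> 4 x 1 x 6 x 1 x 8

println(size(tmp), " -> ", size(two_pass))
println(two_pass ≈ sum(A, dims=(2, 4)))
```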

Eager to hear any thoughts on this topic!

@anicusan
Member

Hi @THargreaves, thanks a lot for the time and energy put into this; it is very much a needed discussion.

I went through your comments and code (very nice!) - there are a few additional constraints we have in AcceleratedKernels when it comes to shared memory tiling:

  • We cannot assume a warp size of 32 threads: on AMD GPUs the "wavefront" is 64 threads; on Intel GPUs (at least the UHD Graphics I was testing on) the size was not specified and, without synchronization, reduction kernels hung the whole device. On CPUs (probably via the PoCL backend), the warp size, in the sense of a group of threads guaranteed to work in lock-step, would technically be 1. In any case, the warp size is not yet exposed by KernelAbstractions.jl, though I'm hoping that will change in the foreseeable future.
  • I was playing around with tiling when I first started implementing an N-dimensional reduction (comment), but found the input-data size dependency problematic, for both memory use and kernel recompilation; still, I think this could be done with a bit more care:
    • While in your example you used Float32 elements, reductions could be done on arbitrarily-sized structs, so we'd need some way to statically determine the allocated shared memory size.
    • In principle this could be done, as we know the element size and block size at kernel compile-time, but we'd need a way to query the available shared memory from KernelAbstractions.jl to ensure what we allocate fits.
  • On the other hand, you're right that this has broad applicability to kernel-writing, and if we could craft some helpers / general interfaces for tiled shared memory access, it'd be amazing - and they'd be quite popular; I can imagine them being used heavily in dense linear algebra, which is something we'd like to extend AK with. There's some work in this direction in TiledCUDA and ThunderKittens that I know of, but they're C++ warts.
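
As a sketch of the static sizing point: given the element type, block size, warp size and a shared memory budget, one could pick the widest tile that fits when the kernel is constructed. `pick_tile_k` and the `MeanVar` struct are hypothetical, and both `warp_size` and `smem_budget` are plain arguments standing in for queries KernelAbstractions.jl does not currently provide; the 48 KiB budget below is only a placeholder.

```julia
# Pick the widest warp_size x k tile (k a power of two dividing warp_size) that fits
# a shared memory budget. warp_size and smem_budget are caller-supplied assumptions.
function pick_tile_k(::Type{T}, block_size, warp_size, smem_budget) where {T}
    isbitstype(T) || error("element type must be isbits to live in shared memory")
    nwarps = cld(block_size, warp_size)
    k = warp_size
    while k >= 1
        # one warp_size-row x k-column tile per warp
        nwarps * warp_size * k * sizeof(T) <= smem_budget && return k
        k ÷= 2
    end
    return nothing  # not even one column per warp fits
end

# Hypothetical non-Float32 element, e.g. for a streaming mean/variance reduction
struct MeanVar
    n::Int32
    mean::Float32
    m2::Float32
end

budget = 48 * 1024  # placeholder budget in bytes
println(pick_tile_k(Float32, 1024, 32, budget))  # 32-lane warps
println(pick_tile_k(Float32, 1024, 64, budget))  # 64-lane wavefronts
println(pick_tile_k(MeanVar, 1024, 32, budget))  # 12-byte elements
```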

Again, thank you for going this deep into the codebase and prototyping this shared memory tiling approach. How would you see this work with the above constraints? Maybe we could start some work in KernelAbstractions for querying the shared memory size and perhaps the warp size.
