Slow convolution, many memory warnings #25461

Open
apasarkar opened this issue Dec 13, 2024 · 2 comments
Labels
bug Something isn't working

@apasarkar

Description

Hi, I wrote some code to run a spatial filter on frames of data: essentially convolving a batch of images with a ~30 x 30 kernel. I'd like to take advantage of jit + vmap (see code below) so this runs as fast as possible on GPU. When I do this, execution is extremely slow (the code below takes about 20 seconds), and several warnings about memory allocation appear (included below). Based on the warnings, it looks like a lot of time is being spent under the hood trying to find the best strategy for running the convolution, though I'm not sure. Any and all help is greatly appreciated!

Code:

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, vmap
import jax.lax as lax

def _convolution_image_filter(img: np.ndarray, kernel: np.ndarray):
    """
    Filter img with kernel
    Args:
        img (np.ndarray): Shape (fov dim 1, fov dim 2). Image to be filtered
        kernel (np.ndarray): Shape (k1, k1). Kernel for filtering
    Returns:
        filtered_img (np.ndarray): Shape (fov dim 1, fov dim 2).
    """
    img_padded = jnp.pad(img,
                         (((kernel.shape[0]) // 2, (kernel.shape[0]) // 2),
                          ((kernel.shape[1]) // 2, (kernel.shape[1]) // 2)),
                         mode='reflect')
    filtered_frame = jax.scipy.signal.convolve(img_padded, kernel, mode="valid")
    return filtered_frame

convolution_image_filter = jit(_convolution_image_filter)
convolution_image_filter_batch = jit(vmap(_convolution_image_filter, in_axes=(0, None)))

kernel = np.ones((30, 30))
data = np.random.rand(100, 500, 1400)

output = convolution_image_filter_batch(data, kernel)

Various Warnings/Messages:

2024-12-13 12:26:57.046981: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=3,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:26:57.597376: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.55049392s
Trying algorithm eng28{k2=3,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:26:58.597512: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng1{k2=4,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:26:59.306847: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.70940755s
Trying algorithm eng1{k2=4,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:26:59.306914: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 235.35GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-12-13 12:27:00.307021: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=0,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:27:03.606746: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4.299796437s
Trying algorithm eng28{k2=0,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:27:04.606890: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=1,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:27:10.442216: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 6.835400164s
Trying algorithm eng28{k2=1,k3=0} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:27:11.442351: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng4{} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
E1213 12:27:12.556075  141765 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E1213 12:27:13.664756  141765 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2024-12-13 12:27:13.675042: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3.232765972s
Trying algorithm eng4{} for conv (f32[100,1,501,1401]{3,2,1,0}, u8[0]{0}) custom-call(f32[100,1,530,1430]{3,2,1,0}, f32[1,1,30,30]{3,2,1,0}), window={size=30x30}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-12-13 12:27:13.685184: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='panda', release='5.4.0-200-generic', version='#220-Ubuntu SMP Fri Sep 27 13:19:16 UTC 2024', machine='x86_64')

@apasarkar apasarkar added the bug Something isn't working label Dec 13, 2024
@apasarkar apasarkar changed the title Convolution issue Slow convolution, many memory warnings Dec 13, 2024
@dfm
Collaborator

dfm commented Dec 13, 2024

Thanks for the report! Some comments about this case: can you say where your 20-second number is coming from? I can't exactly reproduce it locally, although compilation of this function is quite slow. If you're not already familiar with it, you might want to check out the JAX docs on microbenchmarking to work out where the bottleneck is. Running your test code on Colab, I find that compilation takes about 5 seconds, and the runtime after compilation is about 300-400 ms. It would be interesting to know whether you consider that extremely slow for your use case.
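
For reference, here is a rough sketch of how one might time compilation and execution separately using JAX's ahead-of-time lowering API (unverified for your exact setup; it assumes the convolution_image_filter_batch jitted function from your snippet above):

import time
import numpy as np
import jax.numpy as jnp

kernel = jnp.ones((30, 30))
data = jnp.asarray(np.random.rand(100, 500, 1400))

# Lower the jitted function, then time compilation and execution separately.
lowered = convolution_image_filter_batch.lower(data, kernel)

t0 = time.perf_counter()
compiled = lowered.compile()  # XLA compilation, including cuDNN algorithm autotuning
print(f"compile time: {time.perf_counter() - t0:.2f} s")

t0 = time.perf_counter()
compiled(data, kernel).block_until_ready()  # pure device execution
print(f"run time:     {time.perf_counter() - t0:.3f} s")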

If these compile times are a blocker for you, there may be ways to control XLA's autotuning behavior to trade off compile time against runtime performance, but I must admit I don't know how to do that off the top of my head!
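
One possible starting point (unverified for this case) is XLA's GPU autotune level, which can be set through the XLA_FLAGS environment variable before JAX is imported; lower levels should do a less exhaustive cuDNN algorithm search:

import os

# Sketch, not tested on this workload: --xla_gpu_autotune_level ranges from
# 0 (no autotuning) to 4 (most exhaustive search).
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=1"

import jax  # must be imported after XLA_FLAGS is set for the flag to take effect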

@apasarkar
Author

apasarkar commented Dec 15, 2024

@dfm Thank you for the quick response!

I'm inferring the 20 second number by running the following code snippet, following the jax benchmarking guide you linked:

kernel = np.ones((30, 30))
data = np.random.rand(100, 500, 1400)
data_gpu = jax.device_put(data)
%time output = convolution_image_filter_batch(data_gpu, kernel).block_until_ready()
%time output = convolution_image_filter_batch(data_gpu, kernel).block_until_ready()

The difference in wall times between the above executions is >20 seconds (24.1s vs. 92.9ms).

One thing I'll add: if you increase the number of frames (i.e. data = np.random.rand(1000, 500, 1400) instead of data = np.random.rand(100, 500, 1400), so you're processing 1,000 frames instead of 100), the compilation time goes up drastically, to ~3.5 minutes!

Perhaps this suggests that JAX and/or XLA is struggling to find the right convolution algorithm under the hood?

I guess a separate question is whether the algorithm that it ends up picking is the fastest one for execution on the GPU.

Happy to do some digging into any/all of the above, just let me know what you think!
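
For example, one way I could dig into which algorithm ends up being selected (a sketch, assuming the jitted function and the data_gpu/kernel arrays from the snippet above) is to inspect the compiled HLO for the cuDNN custom call:

# Sketch: dump the compiled HLO and look at the cuDNN convolution custom call,
# whose backend config records the algorithm XLA ultimately selected.
compiled = convolution_image_filter_batch.lower(data_gpu, kernel).compile()
hlo_text = compiled.as_text()
for line in hlo_text.splitlines():
    if "__cudnn$convForward" in line:
        print(line)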
