Slow convolution, many memory warnings #25461
Comments
Thanks for the report! Some comments about this case: Can you comment on where your 20 second number is coming from? I can't exactly reproduce it locally, although compilation of this function is quite slow. If you're not already familiar with it, you might want to check out the JAX docs about microbenchmarking to work out where you're seeing a bottleneck. Running your test code on Colab, I find that compilation takes about 5 seconds, then the runtime after it is compiled is about 300-400 ms. It would be interesting to know if you consider that extremely slow for your use cases. If these compile times are a blocker for you, there may be ways to control XLA's autotuning behavior to trade off compile time vs. runtime performance, but I must admit I don't know how to do that off the top of my head!
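One way to time compilation and execution separately is jit's ahead-of-time lowering API. A rough sketch follows; the filter defined here is only a stand-in for the code in the issue, not the reporter's exact function:

```python
# Sketch (not the reporter's code): time compilation and execution separately
# using jax.jit's ahead-of-time lowering/compilation API.
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.signal import convolve2d

# Stand-in for the jitted, vmapped spatial filter described in the issue.
kernel = jnp.asarray(np.random.rand(30, 30).astype(np.float32))
filter_frames = jax.jit(jax.vmap(lambda frame: convolve2d(frame, kernel, mode="same")))
data = jnp.asarray(np.random.rand(100, 500, 1400).astype(np.float32))

t0 = time.perf_counter()
compiled = filter_frames.lower(data).compile()  # compile only; nothing executes yet
print("compile:", time.perf_counter() - t0, "s")

t0 = time.perf_counter()
compiled(data).block_until_ready()  # run the already-compiled executable
print("run:", time.perf_counter() - t0, "s")
```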
@dfm Thank you for the quick response! I'm inferring the 20 second number by running the following code snippet, following the JAX benchmarking guide you linked:
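(The original snippet was not preserved in this thread; the sketch below shows the kind of timing code the microbenchmarking guide suggests, reusing the `filter_frames` and `data` setup from the stand-in sketch above.)

```python
# Sketch of the timing being described: first call includes tracing and
# compilation, second call reuses the cached executable.
import time

t0 = time.perf_counter()
filter_frames(data).block_until_ready()  # first call: trace + compile + run
print("with compilation:", time.perf_counter() - t0, "s")

t0 = time.perf_counter()
filter_frames(data).block_until_ready()  # second call: cached executable only
print("cached:", time.perf_counter() - t0, "s")
```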
The difference in wall times between the above executions is >20 seconds (24.1 s vs. 92.9 ms). One thing I'll add: if you increase the number of frames (i.e. use data = np.random.rand(1000, 500, 1400) instead of data = np.random.rand(100, 500, 1400), so you are processing 1,000 frames instead of 100), the compilation time goes up drastically, to ~3.5 minutes! Perhaps this suggests that JAX and/or XLA is struggling to find the right convolution algorithm under the hood? A separate question is whether the algorithm it ends up picking is actually the fastest one for execution on the GPU. Happy to do some digging into any or all of the above; just let me know what you think!
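If the autotuner is the culprit, one experiment might be to lower XLA's GPU autotune level before JAX starts. The flag name below is my best guess from XLA's GPU options, not something verified in this thread:

```python
# Guess at one knob to try (flag name taken from XLA's GPU options, unverified
# here): lower the autotuning level so XLA spends less time searching for
# convolution algorithms, at the cost of possibly slower kernels.
# This must be set before JAX initializes its backends.
import os

os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=1"

import jax  # imported after setting XLA_FLAGS so the flag is picked up
```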
Description
Hi, I wrote some code to run a spatial filter on some frames of data. I'm basically convolving a batch of images with a ~30 x 30 image kernel. I'd like to take advantage of jit + vmap here (see code below) so that this code can run as fast as possible on the GPU. When I do this, execution is extremely slow: it takes 20 seconds to run the code below. Several warnings about memory allocation also show up (included below). Based on the warnings, it looks like a lot of time is being spent under the hood trying to figure out the best strategy for running the convolution, but I'm not sure. Any and all help on this is greatly appreciated!
Code:
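(The original listing was not captured here; the sketch below shows the kind of jit + vmap spatial filter described above, with my own function and variable names.)

```python
# Sketch of the setup described in the issue: convolve each 500x1400 frame
# with a ~30x30 kernel, vmapped over the frame axis and jitted.
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.signal import convolve2d


def spatial_filter(frame, kernel):
    # 2D "same" convolution of a single frame with the image kernel.
    return convolve2d(frame, kernel, mode="same")


# vmap over the leading (frame) axis of the data, broadcast the kernel,
# then jit the batched function.
filter_frames = jax.jit(jax.vmap(spatial_filter, in_axes=(0, None)))

data = jnp.asarray(np.random.rand(100, 500, 1400).astype(np.float32))
kernel = jnp.asarray(np.random.rand(30, 30).astype(np.float32))

result = filter_frames(data, kernel).block_until_ready()
print(result.shape)  # (100, 500, 1400) with mode="same"
```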
Various Warnings/Messages:
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='panda', release='5.4.0-200-generic', version='#220-Ubuntu SMP Fri Sep 27 13:19:16 UTC 2024', machine='x86_64')