Test all functions with jax.test_util.check_grads
to ensure differentiability
#438
Labels
array-api
Work is related to the Array API
enhancement
New feature or request
testing
Work is related to testing
Is Your Feature Request Related to a Problem? Please Describe
As we port code to the Array API to make it GPU-enabled, we ultimately want it to be differentiable too. Rather than arriving at a solution that works on GPU but then needs a massive rewrite for differentiability, it would be good to ensure functions are en route to differentiability as they are ported.
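As a rough illustration of what such a check could look like, here is a minimal sketch using `jax.test_util.check_grads`, which compares JAX's forward- and reverse-mode derivatives against finite differences and raises an `AssertionError` on mismatch. The functions `smooth_bump` and `log_norm` are hypothetical stand-ins, not functions from this project, and enabling 64-bit precision is an assumption made only to reduce numerical noise in the finite-difference comparison.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Assumption: use float64 so the finite-difference comparison is less noisy.
jax.config.update("jax_enable_x64", True)


# Hypothetical stand-ins for ported library functions.
def smooth_bump(x):
    return jnp.sum(jnp.sin(x) * jnp.exp(-(x**2)))


def log_norm(x):
    return jnp.log(jnp.sum(x**2))


FUNCTIONS_UNDER_TEST = [smooth_bump, log_norm]

# Sample input kept away from zero so both functions are smooth there.
x = jnp.linspace(0.1, 1.0, 5)

for fn in FUNCTIONS_UNDER_TEST:
    # Checks first-order derivatives in both forward (JVP) and reverse (VJP)
    # mode against numerical estimates; raises AssertionError on failure.
    check_grads(fn, (x,), order=1, modes=("fwd", "rev"))
```

In a real test suite the loop would more likely be a parametrized test, with one case per ported function and inputs chosen to avoid points where the function is not smooth.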
Describe the Solution You'd Like
No response
Describe Alternatives You've Considered
No response
Additional Context
No response