
Check autodiff and batching support for healpix_fft_cuda primitive and add if needed #237

Open · matt-graham opened this issue Oct 9, 2024 · 2 comments
Labels: enhancement (New feature or request)

matt-graham (Collaborator) commented:

I think the primitive added in #204 may not support the automatic differentiation and batching transformations, as we did not define Jacobian-vector product (JVP) and transpose rules (for autodiff) or a batching rule (for vmap support). We should verify whether this is the case and add implementations if necessary.
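A quick way to verify this is to apply `jax.jvp` and `jax.vmap` to a function built on the primitive and check for `NotImplementedError`, which JAX raises when the corresponding rule is missing. A minimal sketch, with `jnp.fft.fft` standing in for the actual `healpix_fft_cuda`-backed function (the real call in s2fft will differ):

```python
import jax
import jax.numpy as jnp

# Stand-in for the primitive under test; replace with the actual
# healpix_fft_cuda-backed function from s2fft.
def healpix_fft(f):
    return jnp.fft.fft(f)

f = jnp.ones(16, dtype=jnp.complex64)

# Autodiff: jax.jvp raises NotImplementedError if the primitive has no
# JVP rule; jax.vjp / jax.grad additionally require a transpose rule.
try:
    jax.jvp(healpix_fft, (f,), (f,))
    print("jvp: supported")
except NotImplementedError as err:
    print(f"jvp: missing rule ({err})")

# Batching: jax.vmap raises NotImplementedError if no batching rule is
# registered for the primitive.
try:
    jax.vmap(healpix_fft)(jnp.stack([f, f]))
    print("vmap: supported")
except NotImplementedError as err:
    print(f"vmap: missing rule ({err})")
```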

ASKabalan (Collaborator) commented:

Hello @matt-graham,
I think the gradient is straightforward, since the FFT is linear. The adjoint for the forward pass should be a spectral folding followed by the FFTs in reverse order, correct?
vmap support might be more challenging: XLA provides only a single CUDA stream, and in my experience forking a stream carries some overhead. Let me think about it.
Some of the JAX developers suggested I use Pallas instead of CUDA, which would solve the latter issue; I don't know what you think about that.
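One way to get vmap working without forking streams is to register a batching rule that simply loops over the batch axis, launching the kernel once per element on the single stream XLA provides. A minimal eager-mode sketch with an illustrative primitive (all names here are hypothetical, and `jnp.fft.fft` stands in for the CUDA kernel):

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:  # Primitive moved to jax.extend.core in newer JAX releases
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# Illustrative primitive; jnp.fft.fft stands in for the CUDA kernel.
my_fft_p = Primitive("my_fft")
my_fft_p.def_impl(lambda x: jnp.fft.fft(x))

def my_fft(x):
    return my_fft_p.bind(x)

def my_fft_batcher(batched_args, batch_dims):
    (x,), (bdim,) = batched_args, batch_dims
    x = jnp.moveaxis(x, bdim, 0)  # move the batch axis to the front
    # Sequential fallback: one kernel invocation per batch element, so
    # the single CUDA stream provided by XLA suffices.
    out = jnp.stack([my_fft(x[i]) for i in range(x.shape[0])])
    return out, 0

batching.primitive_batchers[my_fft_p] = my_fft_batcher
```

With the batching rule registered, `jax.vmap(my_fft)` works in eager mode; a real implementation would also need abstract evaluation and lowering rules to run under `jit`, and could later swap the Python loop for a genuinely batched kernel.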

jasonmcewen (Contributor) commented Oct 25, 2024:

Yes, as you say @ASKabalan, since the operations are linear, gradients can be computed via inverse transforms. Precisely how to do this is outlined in Section 5 of our paper. That could be a good approach here. This is also how we implemented the VJPs for the C wrappers we added to s2fft, although it trades off accuracy between the forward pass and the VJP once HEALPix iterations are added (which we should do soon).
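For a linear transform, this kind of VJP can be wired up with `jax.custom_vjp`. A sketch under the simplifying assumption of a unitary transform, with `jnp.fft` standing in for the actual forward/inverse HEALPix FFT pair (the real s2fft transforms and their normalisation will differ):

```python
import jax
import jax.numpy as jnp

# Sketch: define the VJP of a linear transform via its inverse.
# jnp.fft.fft/ifft with norm="ortho" stand in for the forward/inverse
# transform pair, which is assumed unitary here.
@jax.custom_vjp
def forward(f):
    return jnp.fft.fft(f, norm="ortho")

def forward_fwd(f):
    return forward(f), None  # linear map: no residuals to save

def forward_bwd(_, cotangent):
    # For a unitary transform the adjoint equals the inverse, so the
    # cotangent is pulled back with the inverse transform (conjugated,
    # following JAX's transpose convention for complex linear maps).
    return (jnp.conj(jnp.fft.ifft(jnp.conj(cotangent), norm="ortho")),)

forward.defvjp(forward_fwd, forward_bwd)
```

For the orthonormal DFT this reproduces the VJP that JAX would derive itself (the DFT matrix is symmetric, so the transpose is again a forward FFT); the point of the pattern is that it needs only the inverse transform, not a separately implemented adjoint kernel.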
