Next lecture for JAX conversion #50
Comments
Hi @jstac and @Smit-create, I can have a first try and send it to @Smit-create for review to see if it needs further improvement :)
Thanks @HumphreyYang. @Smit-create, you might like to try at the same time or closely coordinate with @HumphreyYang. I'm sure there will be plenty of challenges.
Thanks @jstac, @HumphreyYang. I'll also have a look into it at the same time.
Yeah, I see. I was trying to write the JAX code but that fails in quantecon's root-finding using
This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'jaxlib.xla_extension.CompiledFunction'>
- argument 2: Cannot determine Numba type of <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
- argument 3: Cannot determine Numba type of <class 'tuple'>
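The error above arises because a Numba-compiled routine is being handed JAX objects (a compiled JAX function and a tracer), which Numba cannot type. One way around this is to keep the root finder entirely inside JAX. The following is a minimal sketch, not the lecture's actual code: the function name `newton`, the test function `f`, and the tolerance values are all hypothetical, and it assumes a scalar, differentiable problem.

```python
import jax
import jax.numpy as jnp


def newton(f, x0, tol=1e-5, max_iter=50):
    """Scalar Newton iteration written only with JAX primitives.

    `f` must be a JAX-traceable scalar function (hypothetical example API).
    """
    f_prime = jax.grad(f)  # derivative via automatic differentiation

    def cond(state):
        x, i = state
        # Continue while the residual is large and the iteration budget remains
        return (jnp.abs(f(x)) > tol) & (i < max_iter)

    def body(state):
        x, i = state
        return x - f(x) / f_prime(x), i + 1

    init = (jnp.asarray(x0, dtype=jnp.float32), jnp.int32(0))
    x, _ = jax.lax.while_loop(cond, body, init)
    return x


# Illustrative problem: solve x**2 - 2 = 0 starting from 1.5
f = lambda x: x**2 - 2.0
root = newton(f, 1.5)
print(root)
```

Because the loop is expressed with `jax.lax.while_loop`, the whole routine can be traced and compiled by JAX, so no JAX object ever needs to cross into Numba.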
@Smit-create Same here. I also attempted to convert the last exercise, but
Thanks for making a start @Smit-create @HumphreyYang.

Yes, the first challenge is that you need to find JAX equivalents for root finding and interpolation.

Regarding root finding, you could try https://jaxopt.github.io/stable/root_finding.html or the methods in https://python.quantecon.org/newton_method.html

For linear interpolation in one dimension there is https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html You will need to be careful about what happens outside the grid points used for interpolation.

For some lectures we will need 2D interpolation. This might be harder. It is mentioned in https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html For 2D I also noticed https://github.com/adam-coogan/jaxinterp2d/
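To illustrate the point about behaviour outside the grid: by default `jnp.interp` returns the first and last function values for query points below and above the grid, rather than extrapolating. The sketch below shows that default next to a hand-written linear extrapolation. The helper `interp_extrap` and the sample grid are hypothetical and only meant to show the difference; whether clamping or extrapolation is appropriate depends on the lecture's model.

```python
import jax.numpy as jnp

grid = jnp.linspace(0.0, 1.0, 5)         # grid points 0, 0.25, 0.5, 0.75, 1
vals = grid**2                           # function values on the grid
x = jnp.array([-0.5, 0.1, 0.5, 1.5])     # two of these lie outside the grid

# Default behaviour: values outside the grid are clamped to vals[0] / vals[-1]
clamped = jnp.interp(x, grid, vals)


def interp_extrap(x, xp, fp):
    """Linear interpolation that extends the end segments linearly.

    Hypothetical helper, assuming xp is sorted and strictly increasing.
    """
    # Index of the right end of the segment used for each query point
    i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
    slope = (fp[i] - fp[i - 1]) / (xp[i] - xp[i - 1])
    return fp[i - 1] + slope * (x - xp[i - 1])


extrapolated = interp_extrap(x, grid, vals)
print(clamped)
print(extrapolated)
```

Inside the grid the two versions agree; they differ only at the out-of-range points, which is exactly where a ported lecture could silently diverge from its Numba counterpart.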
EDIT: transferred from intermediate lectures. The discussion is still relevant.
Perhaps this lecture is a good candidate for attempting to port to JAX: https://python.quantecon.org/ifp.html
I think it will be challenging because there is both linear interpolation and root finding.
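As a rough illustration of how the two ingredients can be combined inside JAX, the sketch below finds the root of a function that is only known on a grid, using `jnp.interp` for evaluation and a fixed number of bisection steps for the search. It is not taken from the lecture: the tabulated function, the name `bisect`, the bracket, and the iteration count are all assumptions, and it relies on the interpolated function being increasing on the bracket.

```python
import jax
import jax.numpy as jnp

# Hypothetical tabulated function: increasing, with a root near x = 1
grid = jnp.linspace(0.0, 2.0, 21)
vals = grid**3 - 1.0


def g(x):
    # Piecewise-linear approximation of the tabulated function
    return jnp.interp(x, grid, vals)


@jax.jit
def bisect(lo, hi):
    """Bisection on g with a fixed number of steps (assumes g is increasing)."""

    def step(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        go_right = g(mid) < 0.0  # root lies to the right of mid
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, 40, step, (lo, hi))
    return 0.5 * (lo + hi)


root = bisect(jnp.float32(0.0), jnp.float32(2.0))
print(root)
```

Bisection is used here instead of a derivative-based method because the interpolant is only piecewise linear; a fixed iteration count also keeps the loop simple to compile with `jax.jit`.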
I'm not sure how well these can be done with JAX, and it's possible that we cannot beat the Numba versions. But it would be interesting to find out.
CC @Smit-create @HumphreyYang