Next lecture for JAX conversion #50

Open
jstac opened this issue Jan 18, 2023 · 6 comments

Comments

@jstac
Contributor

jstac commented Jan 18, 2023

EDIT: Transferred from the intermediate lectures. The discussion is still relevant.

Perhaps this lecture is a good candidate for porting to JAX: https://python.quantecon.org/ifp.html

I think it will be challenging because there is both linear interpolation and root finding.

I'm not sure how well these can be done with JAX, and it's possible that we cannot beat the Numba versions. But it would be interesting to find out.

CC @Smit-create @HumphreyYang

@HumphreyYang
Collaborator

Hi @jstac and @Smit-create,

I can make a first attempt and send it to @Smit-create for review, to see whether it needs further improvement :)

@jstac
Contributor Author

jstac commented Jan 18, 2023

Thanks @HumphreyYang. @Smit-create, you might like to try at the same time or coordinate closely with @HumphreyYang. I'm sure there will be plenty of challenges.

@Smit-create
Member

Thanks @jstac, @HumphreyYang. I'll also have a look into it at the same time.

@Smit-create
Member

> I think it will be challenging because there is both linear interpolation and root finding.

Yeah, I see. I tried writing the JAX code, but it fails at the root-finding step, which uses quantecon's brentq. Since we wrap all the functions in jax.jit, Numba fails because it cannot determine their types:

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'jaxlib.xla_extension.CompiledFunction'>
- argument 2: Cannot determine Numba type of <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
- argument 3: Cannot determine Numba type of <class 'tuple'>

@HumphreyYang
Collaborator

HumphreyYang commented Jan 19, 2023

@Smit-create Same here. brentq takes a function as input, and it is at the center of this implementation. However, a JAX-jitted function cannot be passed to it; the function has to stay uncompiled by JAX. Without jax.jit there is little room for improvement. We could do some work to eliminate the for loops, but the trade-off would be losing the speed-up of the compiled function, unless we have a JAX version of the root finder.
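A minimal sketch of what a JAX-native root finder could look like, assuming a bracketing interval on which f changes sign is known. It uses plain bisection written with jax.lax.while_loop, so it can be traced by jax.jit and batched over a grid with jax.vmap. Bisection is swapped in here for simplicity; it is not Brent's method as used by brentq, and the names bisect and sqrt_via_bisection are hypothetical.

```python
import jax
import jax.numpy as jnp


def bisect(f, a, b, tol=1e-6, max_iter=200):
    """Hypothetical helper: find a root of f in [a, b] by bisection.

    Assumes f(a) and f(b) have opposite signs. Written with lax.while_loop
    so that jax.jit can trace it and jax.vmap can batch it.
    """
    a = jnp.asarray(a, dtype=float)
    b = jnp.asarray(b, dtype=float)

    def cond(state):
        a, b, i = state
        # Stop when the bracket is small enough or the iteration cap is hit
        return (b - a > tol) & (i < max_iter)

    def body(state):
        a, b, i = state
        m = 0.5 * (a + b)
        # Keep the half-interval that still contains the sign change
        same_sign = jnp.sign(f(m)) == jnp.sign(f(a))
        a = jnp.where(same_sign, m, a)
        b = jnp.where(same_sign, b, m)
        return a, b, i + 1

    a, b, _ = jax.lax.while_loop(cond, body, (a, b, 0))
    return 0.5 * (a + b)


@jax.jit
def sqrt_via_bisection(c):
    # Solve x**2 - c = 0 on [0, max(c, 1)] separately for each element of c
    solve_one = lambda ci: bisect(lambda x: x**2 - ci, 0.0, jnp.maximum(ci, 1.0))
    return jax.vmap(solve_one)(c)


c_grid = jnp.array([0.5, 2.0, 9.0])
roots = sqrt_via_bisection(c_grid)  # approximately [0.7071, 1.4142, 3.0]
```

Because the loop is a lax.while_loop rather than a Python loop, vmap can solve one root-finding problem per grid point in parallel, which is roughly the pattern a JAX port of the IFP operator would need.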

I also attempted to convert the last exercise, but the IFP class, defined using the extended pytree, blocked the implementation. A namedtuple may get around this up to a point, but the interp function limits how far the computation can be parallelized in JAX.
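On the namedtuple idea: JAX treats Python namedtuples as pytrees, so a typing.NamedTuple holding model parameters can be passed straight into a jitted function without custom pytree registration. This is a sketch under assumptions; the field names, default values, create_ifp and u_prime below are hypothetical and illustrative, not taken from the lecture.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class IFP(NamedTuple):
    # Hypothetical parameter container; field names are illustrative only
    r: float
    beta: float
    gamma: float
    P: jax.Array           # Markov transition matrix for the exogenous state
    z_vals: jax.Array      # values of the exogenous state
    asset_grid: jax.Array  # grid of asset levels


def create_ifp(r=0.01, beta=0.96, gamma=1.5, a_max=16.0, a_size=50):
    # Illustrative default values, not the lecture's calibration
    P = jnp.array([[0.6, 0.4], [0.05, 0.95]])
    z_vals = jnp.array([0.0, 1.0])
    asset_grid = jnp.linspace(0.0, a_max, a_size)
    return IFP(r, beta, gamma, P, z_vals, asset_grid)


@jax.jit
def u_prime(c, ifp):
    # CRRA marginal utility; jit flattens the namedtuple and traces its leaves
    return c ** (-ifp.gamma)


ifp = create_ifp()
mu = u_prime(jnp.array([0.5, 1.0, 2.0]), ifp)
```

Since namedtuples are immutable, a variant can be built with ifp._replace(gamma=2.0); passing it to the same jitted function reuses the compiled code as long as the leaves keep the same shapes and dtypes.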

@jstac
Copy link
Contributor Author

jstac commented Jan 19, 2023

Thanks for making a start, @Smit-create and @HumphreyYang.

Yes, the first challenge is that you need to find JAX equivalents for root finding and interpolation.

Regarding root finding, you could try https://jaxopt.github.io/stable/root_finding.html or the methods in https://python.quantecon.org/newton_method.html
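As a sketch of the second option, Newton's method can be written in pure JAX by taking the derivative with jax.grad and iterating with jax.lax.fori_loop. Assumptions: a fixed iteration count instead of a convergence test, a scalar function whose derivative stays nonzero along the iterates, and a reasonable starting point. The names newton and cube_root are hypothetical; this is not the API of jaxopt or of the linked lecture.

```python
import jax
import jax.numpy as jnp


def newton(f, x0, n_iter=50):
    """Hypothetical helper: Newton's method with a fixed number of steps."""
    f_prime = jax.grad(f)  # derivative computed by automatic differentiation

    def step(_, x):
        # Standard Newton update x <- x - f(x) / f'(x)
        return x - f(x) / f_prime(x)

    x0 = jnp.asarray(x0, dtype=float)
    return jax.lax.fori_loop(0, n_iter, step, x0)


@jax.jit
def cube_root(c):
    # Solve x**3 - c = 0 starting from x = 1
    return newton(lambda x: x**3 - c, 1.0)


r = cube_root(27.0)  # approximately 3.0
```

Newton converges faster than bisection near a root but needs a good starting point and a differentiable function; which one suits the IFP Euler equation is an open question for the port.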

For linear interpolation in one dimension there is https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.interp.html
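A minimal sketch of jnp.interp used inside a jitted function; the grid, the function values and the name f_interp are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

x_grid = jnp.linspace(0.0, 4.0, 5)  # interpolation nodes: 0, 1, 2, 3, 4
f_vals = x_grid**2                  # function values at the nodes


@jax.jit
def f_interp(x):
    # Piecewise-linear interpolant of (x_grid, f_vals), evaluated at x
    return jnp.interp(x, x_grid, f_vals)


out = f_interp(jnp.array([0.5, 2.0, 3.5]))  # [0.5, 4.0, 12.5]
```

jnp.interp takes arguments in the same order as numpy.interp (x, xp, fp), and the nodes xp are expected to be increasing.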

You will need to be careful about what happens outside the grid points used for interpolation.
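By default jnp.interp, like numpy.interp, does not extrapolate: points below the first node get fp[0] and points above the last node get fp[-1]. If a model needs linear extrapolation instead, one option is to patch the ends with jnp.where. The helper interp_extrap below is hypothetical, a sketch assuming at least two grid points.

```python
import jax.numpy as jnp

x_grid = jnp.linspace(0.0, 4.0, 5)
f_vals = x_grid**2


def interp_extrap(x, xp, fp):
    """Hypothetical helper: linear interpolation with linear extrapolation."""
    # Slopes of the first and last grid segments
    slope_lo = (fp[1] - fp[0]) / (xp[1] - xp[0])
    slope_hi = (fp[-1] - fp[-2]) / (xp[-1] - xp[-2])
    y = jnp.interp(x, xp, fp)
    # Replace the clamped values outside [xp[0], xp[-1]] with linear extensions
    y = jnp.where(x < xp[0], fp[0] + slope_lo * (x - xp[0]), y)
    y = jnp.where(x > xp[-1], fp[-1] + slope_hi * (x - xp[-1]), y)
    return y


x_new = jnp.array([-1.0, 5.0])
clamped = jnp.interp(x_new, x_grid, f_vals)     # [0.0, 16.0]
extended = interp_extrap(x_new, x_grid, f_vals)  # [-1.0, 23.0]
```

Whether clamping or extrapolation is appropriate depends on the model, so it is worth checking which behaviour the Numba version relies on before switching.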

For some lectures we will need 2D interpolation. This might be harder. It is mentioned in https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.ndimage.map_coordinates.html
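For 2D data on a regular grid, here is a sketch using jax.scipy.ndimage.map_coordinates with order=1 (bilinear). The function works in array-index coordinates, so physical (x, y) points must first be converted to fractional indices; the conversion below assumes uniformly spaced grids, and the grids, F and interp2d are illustrative. JAX's map_coordinates supports only order 0 and 1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# Illustrative uniformly spaced grids and function values on them
x_grid = jnp.linspace(0.0, 1.0, 11)
y_grid = jnp.linspace(0.0, 2.0, 21)
X, Y = jnp.meshgrid(x_grid, y_grid, indexing="ij")
F = X + 2.0 * Y  # shape (len(x_grid), len(y_grid))


@jax.jit
def interp2d(x, y):
    # Map physical coordinates to fractional indices (uniform spacing assumed)
    ix = (x - x_grid[0]) / (x_grid[1] - x_grid[0])
    iy = (y - y_grid[0]) / (y_grid[1] - y_grid[0])
    # order=1 gives bilinear interpolation; mode="nearest" clamps at the edges
    return map_coordinates(F, [ix, iy], order=1, mode="nearest")


vals = interp2d(jnp.array([0.25, 0.5]), jnp.array([0.35, 1.0]))  # about [0.95, 2.5]
```

Because F is linear in x and y here, bilinear interpolation reproduces it up to floating-point error, which makes it a convenient sanity check.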

For 2D I noticed https://github.com/adam-coogan/jaxinterp2d/

jstac transferred this issue from QuantEcon/lecture-python.myst on May 8, 2023