Home
Aritra Roy Gosthipaty edited this page Oct 2, 2024
Welcome to the flux-jax wiki!
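Pick a device, preferring the GPU when one is available. `jax.devices("gpu")` raises a `RuntimeError` when no GPU backend is present, so check `jax.default_backend()` first:

```python
import jax
device = jax.devices("gpu")[0] if jax.default_backend() == "gpu" else jax.devices("cpu")[0]
```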
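A slightly more general sketch of choosing a device, assuming you want to try several platforms in order of preference. The helper `pick_device` and its `preferred` parameter are hypothetical (not part of JAX or flux-jax); the sketch relies only on `jax.devices(platform)` raising `RuntimeError` for a backend that is not available.

```python
import jax


def pick_device(preferred=("gpu", "tpu", "cpu")):
    """Hypothetical helper: return the first device of the first available platform.

    jax.devices(platform) raises RuntimeError when that backend is not present,
    so each platform is tried in order until one succeeds.
    """
    for platform in preferred:
        try:
            return jax.devices(platform)[0]
        except RuntimeError:
            continue  # backend missing on this machine; try the next one
    raise RuntimeError(f"none of the platforms {preferred} are available")


device = pick_device()
print(device.platform)
```

Listing `"cpu"` last means the helper always succeeds on a standard JAX install, since the CPU backend is always present.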
| JAX `jax.nn.dot_product_attention` | Torch `torch.nn.functional.scaled_dot_product_attention` |
|---|---|
| Batch, Target_Length, Num_Heads, Hidden_Dims | Batch, Num_Heads, Target_Length, Hidden_Dims |
| Torch | JAX |
|---|---|
| `x.to(device)` | `jax.device_put(x, device)` or `x.to_device(device)` |
Interestingly, there are three ways to initialize a project with `uv`:

- An application, `uv init`: this creates a Python application with a single source file.
- A library, `uv init --lib`: this creates a library whose sole purpose is to be consumed by another application. Here `uv` creates a `src` directory and adds a demo API in the `__init__.py` file.
- A packaged application, `uv init --app --package`: use this when you want to create a distributable application. It mixes the app and library layouts and additionally adds a package entry point to `pyproject.toml`.
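Checking which device an array lives on and moving it to the CPU:

```python
import jax
import jax.numpy as jnp

a = jnp.array([1])
a.device  # device the array currently lives on (first device of the default backend)
a = a.to_device(jax.devices("cpu")[0])  # move the array to the first CPU device
a.device  # now a CPU device
```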
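`jax.device_put` also accepts a whole pytree, which is convenient for moving a nested dictionary of model parameters in a single call. A minimal sketch; the parameter names and shapes are made up for illustration and are not taken from flux-jax.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree; the names and shapes are illustrative only
params = {
    "dense": {"kernel": jnp.ones((3, 4)), "bias": jnp.zeros((4,))},
    "scale": jnp.array(1.0),
}

cpu = jax.devices("cpu")[0]
params_on_cpu = jax.device_put(params, cpu)  # every leaf is placed on `cpu`

# Collect the platform of every device that holds a leaf
platforms = {
    d.platform
    for leaf in jax.tree_util.tree_leaves(params_on_cpu)
    for d in leaf.devices()
}
print(platforms)  # {'cpu'}
```

The tree structure is preserved, so the result can be passed anywhere the original `params` was used.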