Aritra Roy Gosthipaty edited this page Oct 2, 2024 · 5 revisions

Welcome to the flux-jax wiki!

Today I learnt

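import jax
# jax.devices("gpu") raises RuntimeError when no GPU backend is present, so compare the default backend instead.
device = "gpu" if jax.default_backend() == "gpu" else "cpu"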
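A slightly more defensive variant returns an actual jax.Device object instead of a string. It probes platforms in order of preference and falls back to the CPU. This is a minimal sketch: the helper name pick_device and the preference order are hypothetical choices. It relies on jax.devices(platform) raising RuntimeError when that backend is unavailable.

```python
import jax


def pick_device(preferred=("gpu", "tpu")):
    """Hypothetical helper: return the first device of the first available
    preferred platform, falling back to the first CPU device."""
    for platform in preferred:
        try:
            devices = jax.devices(platform)
        except RuntimeError:
            # jax.devices raises RuntimeError when the requested backend is not present.
            continue
        if devices:
            return devices[0]
    return jax.devices("cpu")[0]


device = pick_device()
print(device.platform)  # e.g. "gpu" on a CUDA machine, "cpu" on a CPU-only one
```

A Device object can be passed directly to jax.device_put, whereas a platform string has to be turned into a device first, for example with jax.devices(device)[0].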

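Tensor layout expected by the attention functions:

| JAX dot product attention (jax.nn.dot_product_attention) | Torch scaled dot product attention (torch.nn.functional.scaled_dot_product_attention) |
|---|---|
| Batch, Target_Length, Num_Heads, Head_Dim | Batch, Num_Heads, Target_Length, Head_Dim |

Moving an array to a device:

| Torch | JAX |
|---|---|
| x.to(device) | jax.device_put(x, device) or x.to_device(device) |

In JAX, device must be a jax.Device such as jax.devices("cpu")[0], not a plain string.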

Interestingly, there are three ways to initialize a project with uv:

  1. An application (uv init): creates a Python application with a single source file.
  2. A library (uv init --lib): creates a library whose sole purpose is to be consumed by another application. uv creates a src directory and adds a demo API in the package's __init__.py file.
  3. A packaged application (uv init --app --package): use this when you want a distributable application. It combines the app and library layouts: uv creates a src directory and adds a package entry point ([project.scripts]) in pyproject.toml.

Offload to a CPU

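import jax
import jax.numpy as jnp

a = jnp.array([1])
print(a.device)  # the default device, e.g. a GPU if one is available
a = a.to_device(jax.devices("cpu")[0])  # equivalently: jax.device_put(a, jax.devices("cpu")[0])
print(a.device)  # now the CPU device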
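jax.device_put also accepts a pytree, so a whole nested structure of arrays (for example a parameter dictionary) can be offloaded in one call. This is a minimal sketch, and the params dict is a made-up example. The call creates CPU copies; the original device buffers are only freed once nothing references them, for example after reassigning params = params_cpu.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Hypothetical parameter pytree; in practice this might be a model's weights.
params = {
    "dense": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))},
    "scale": jnp.array(2.0),
}

# device_put maps over every leaf of the pytree and returns a pytree of the same structure.
params_cpu = jax.device_put(params, cpu)

# Each leaf now lives on the CPU device.
for leaf in jax.tree_util.tree_leaves(params_cpu):
    print(leaf.devices())  # a set containing the CPU device
```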