Google JAX






From Wikipedia, the free encyclopedia
 


JAX
Developer(s): Google
Stable release: 0.4.24[1] / 6 February 2024
Repository: github.com/google/jax
Written in: Python, C++
Operating system: Linux, macOS, Windows
Platform: Python, NumPy
Size: 9.0 MB
Type: Machine learning
License: Apache 2.0
Website: jax.readthedocs.io/en/latest/

Google JAX is a machine learning framework for transforming numerical functions in Python.[2][3][4] It is described as bringing together a modified version of autograd[5] (automatic differentiation of a function to obtain its gradient function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[6][7] The primary functions of JAX are:[2]

  1. grad: automatic differentiation
  2. jit: compilation
  3. vmap: auto-vectorization
  4. pmap: single-program, multiple-data (SPMD) parallel programming
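
These transformations can be composed with one another. As an illustrative sketch (the function and data here are assumptions for demonstration, not taken from the examples below), jit can be applied to the gradient function returned by grad:

# minimal sketch of composing transformations: jit-compile a gradient function
from jax import grad, jit
import jax.numpy as jnp

def tanh_sum(x):
    return jnp.sum(jnp.tanh(x))

# grad produces the gradient function; jit compiles it
fast_grad = jit(grad(tanh_sum))
print(fast_grad(jnp.arange(3.0)))  # expected to print approximately [1. 0.42 0.0707]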

grad

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

The final line should output:

0.19661194
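
Because grad returns an ordinary Python function, it can itself be differentiated. The short, self-contained sketch below is an illustrative addition (not part of the original example) that computes the second derivative of the same logistic function at x = 1.0:

# imports
from jax import grad
import jax.numpy as jnp

# the logistic function, as defined above
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# apply grad twice to obtain the second derivative
second_derivative = grad(grad(logistic))
print(second_derivative(1.0))  # expected to print approximately -0.0908577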

jit

The code below demonstrates the jit function's optimization through fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for the jit_cube call should be noticeably shorter than that for the cube call. Increasing the dimensions passed to jnp.ones will increase the difference.
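
Because JAX dispatches computations asynchronously, a fair timing comparison has to wait for each result to be ready. Continuing from the definitions above, a minimal sketch of such a measurement (the use of timeit here is an illustrative assumption, not part of the original example):

# time both versions; block_until_ready() waits for the asynchronous computation to finish
import timeit
print(timeit.timeit(lambda: cube(x).block_until_ready(), number=1))
print(timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=1))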

vmap

The code below demonstrates the vmap function's vectorization.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
import numpy as np

# define a method that computes per-example gradients over a batch of inputs;
# self._net_params, self._net_grads and self._flatten_batch are assumed to be
# provided by the enclosing class
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
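
Since the example above relies on helpers from its enclosing class that are not shown, the following self-contained sketch (the function and data are illustrative assumptions) shows vmap turning a scalar function into one that operates over a whole batch:

# vectorize a scalar function over a leading batch dimension
from jax import vmap
import jax.numpy as jnp

def square(x):
    return x * x

batched_square = vmap(square)
print(batched_square(jnp.arange(4.0)))  # expected to print [0. 1. 4. 9.]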

[Animation: illustration of vectorized addition]

pmap

The code below demonstrates the pmap function's parallelization for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
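
pmap maps the computation over the leading axis of its input, using one device per element, so the example above assumes at least two XLA devices are visible. A minimal check (an illustrative addition, not from the original example):

# the number of visible devices bounds the size pmap can map over
import jax
print(jax.local_device_count())  # should be at least 2 for the example above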

Libraries using JAX

Several Python libraries use JAX as a backend, including:

  • Flax, a neural network library[8]
  • Equinox, a neural network library[9]
  • Optax, a library for gradient processing and optimization[10]
  • RLax, a library for reinforcement learning[11]
  • Jraph, a library for graph neural networks[12]
  • jaxtyping, which adds type annotations for array shapes and data types[13][14]

Some R libraries use JAX as a backend as well, including:

  • fastrerandomize[15]

See also

External links

References

  1. https://github.com/google/jax/releases/tag/jax-v0.4.24
  2. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
  3. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  4. "Using JAX to accelerate our research". www.deepmind.com. 4 December 2020. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  5. HIPS/autograd, formerly Harvard Intelligent Probabilistic Systems Group, now at Princeton, 2024-03-27, retrieved 2024-03-28
  6. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  7. "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  8. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, retrieved 2022-07-29
  9. Kidger, Patrick (2022-07-29), Equinox, retrieved 2022-07-29
  10. Optax, DeepMind, 2022-07-28, retrieved 2022-07-29
  11. RLax, DeepMind, 2022-07-29, retrieved 2022-07-29
  12. Jraph - A library for graph neural networks in JAX, DeepMind, 2023-08-08, retrieved 2023-08-08
  13. "typing — Support for type hints". Python documentation. Retrieved 2023-08-08.
  14. jaxtyping, Google, 2023-08-08, retrieved 2023-08-08
  15. Jerzak, Connor (2023-10-01), fastrerandomize, retrieved 2023-10-03
