Introduction to JAX
In the last article, we introduced JAX, a relatively new machine learning library from Google. At its core it is an autograd library that can differentiate native Python and NumPy functions. The project describes itself as "composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more".
The library uses the grad function transformation to convert a function into one that returns the gradient of the original function. JAX also provides the jit transformation for just-in-time compilation of existing functions, as well as vmap and pmap for vectorization and parallelization, respectively.
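As a minimal sketch of how these transformations compose, the snippet below applies grad, jit, and vmap to one small function. The function square_sum and the sample values are illustrative assumptions, not taken from the previous article, and pmap is left out because it needs multiple devices.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: f(x) = sum(x_i ** 2)
def square_sum(x):
    return jnp.sum(x ** 2)

# grad turns f into a function that returns df/dx = 2 * x
grad_fn = jax.grad(square_sum)

# jit compiles the gradient function with XLA
fast_grad_fn = jax.jit(grad_fn)

# vmap maps the gradient function over a leading batch axis
batched_grad_fn = jax.vmap(grad_fn)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])  # shape (2, 3)

g = grad_fn(x)                    # gradient for a single input
g_jit = fast_grad_fn(x)           # same gradient, compiled
g_batch = batched_grad_fn(batch)  # one gradient per row of the batch
```

Each transformation takes a function and returns a new function, which is why they can be stacked in any order.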
JAX combines Autograd and XLA. It is not itself a deep learning framework; it is a high-performance numerical computing library and a library of composable function transformations aimed at high-performance machine learning research. Deep learning is only part of the story, but you can port your own deep learning code to JAX.
Since Google released JAX in late 2018, its popularity has grown steadily. DeepMind announced in 2020 that it was using JAX to accelerate its own research, and a growing number of projects from Google Brain and elsewhere use it as well. As its popularity grows, JAX looks like it could become the next major deep learning framework. Although JAX is not a neural network framework, a great deal of deep-learning research can now be implemented with it.
In the previous article, we also compared the speed of JAX and NumPy: NumPy without JAX acceleration was far slower. In this issue we will use JAX to train our first machine learning model.
Use JAX to train your first machine learning model
Before using JAX, we need to install it. Fortunately, JAX can be installed with pip. At the time of writing, however, JAX is not available on Windows, so Windows users can try it in a Linux virtual machine.
pip install jax
pip install autograd
pip install numpy
pip install jaxlib
With the libraries installed, we import JAX and the other third-party libraries we need.
import numpy as np
import jax.random as random
import jax
from jax import numpy as jnp
from jax import make_jaxpr
from jax import grad, jit, vmap, pmap
import matplotlib.pyplot as plt
Next we build a linear function y=ax+b, where the parameter a is the slope of the line and b is its offset along the y-axis. We use jax.random to generate random x values and compute the corresponding y values, which gives us a complete data set for y=ax+b that we can plot with matplotlib.
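# Seed the random number generator and draw 128 random x values
key = random.PRNGKey(56)
xs = random.normal(key, shape=(128, 1))
# True slope and intercept of the line
a = 3.0
b = 5.0
ys = (a * xs) + b
plt.scatter(xs, ys)
plt.xlabel("xs")
plt.ylabel("ys")
plt.title("linear f(x)")
plt.show()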
After running the above code, we get a scatter plot of the linear function y=ax+b.
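The data generation relies on JAX's explicit random keys. As a small sketch of how those keys behave (the variable names xs_first, xs_again, subkey, and noise are illustrative assumptions):

```python
import jax.random as random

key = random.PRNGKey(56)

# The same key always yields the same samples
xs_first = random.normal(key, shape=(128, 1))
xs_again = random.normal(key, shape=(128, 1))

# Splitting the key gives new keys for further, different draws
key, subkey = random.split(key)
noise = random.normal(subkey, shape=(128, 1))
```

Because the key is passed in explicitly, rerunning the script with the same seed reproduces the same data set.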
With this linear function in hand, we build a linear model and use machine learning to predict this straight line.
def linear(theta, x):
    weight, bias = theta
    pred = x * weight + bias
    return pred
The linear model above has two parameters, a weight and a bias. The purpose of training is to find values for the weight and bias that predict the linear function above. We also need a loss function whose value gradually decreases during training. Here we use the mean squared error between the predicted values and the true values.
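def p_loss(theta, x, y):
    # Mean squared error between the true and predicted values
    pred = linear(theta, x)
    loss = jnp.mean((y - pred)**2)
    return loss

@jit
def update_step(theta, x, y, lr):
    # Compute the loss and its gradient with respect to theta in one call
    loss, gradient = jax.value_and_grad(p_loss)(theta, x, y)
    # Plain gradient descent update
    updated_theta = theta - lr * gradient
    return updated_theta, loss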
The update step uses jax.value_and_grad to compute the loss and its gradient with respect to theta, and then moves theta a small step against the gradient. The lr parameter is the learning rate, for which we can pick a relatively small value. With these functions in place, we can train a machine learning model.
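To make the role of jax.value_and_grad concrete, here is a minimal check on a tiny hand-picked data set (the three x values are an illustrative assumption). It compares the gradient JAX returns with the analytic gradient of the mean squared error, which is mean(-2x(y - pred)) for the weight and mean(-2(y - pred)) for the bias.

```python
import jax
import jax.numpy as jnp

def linear(theta, x):
    weight, bias = theta
    return x * weight + bias

def p_loss(theta, x, y):
    return jnp.mean((y - linear(theta, x)) ** 2)

# Small hand-picked data set (illustrative assumption): y = 3x + 5
x = jnp.array([[0.0], [1.0], [2.0]])
y = 3.0 * x + 5.0
theta = jnp.array([0.0, 0.0])

# Returns the loss value and its gradient with respect to theta
loss, gradient = jax.value_and_grad(p_loss)(theta, x, y)

# Analytic gradient of the mean squared error for comparison
residual = y - linear(theta, x)
expected = jnp.array([
    jnp.mean(-2.0 * x * residual),  # d(loss)/d(weight)
    jnp.mean(-2.0 * residual),      # d(loss)/d(bias)
])
```

With theta at zero, both gradient entries are negative, so the update theta - lr * gradient increases the weight and the bias toward their true values.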
weight = 0.0
bias = 0.0
theta = jnp.array([weight, bias])
epochs = 20000
for item in range(epochs):
    theta, loss_p = update_step(theta, xs, ys, 1e-4)
    # Report the loss every 1000 steps
    if item % 1000 == 0 and item != 0:
        print(f"item {item} | loss {float(loss_p):.4f}")
We initialize the weight and bias to zero and use a for loop to train the model so that the loss gets smaller and smaller. Here we print the loss every 1000 steps.
item 1000 | loss 23.4526
item 2000 | loss 15.4000
item 3000 | loss 10.1152
item 4000 | loss 6.6459
item 5000 | loss 4.3678
item 6000 | loss 2.8714
item 7000 | loss 1.8883
item 8000 | loss 1.2422
item 9000 | loss 0.8174
item 10000 | loss 0.5380
item 11000 | loss 0.3543
item 12000 | loss 0.2333
item 13000 | loss 0.1538
item 14000 | loss 0.1013
item 15000 | loss 0.0668
item 16000 | loss 0.0441
item 17000 | loss 0.0291
item 18000 | loss 0.0192
item 19000 | loss 0.0127
The output above shows that the loss decreases steadily, which indicates that our linear model is learning. We can also plot the model's predictions after 20,000 training steps.
plt.scatter(xs, ys, label="true")
plt.scatter(xs, linear(theta, xs), label="pred")
plt.legend()
plt.show()
As training proceeds, the loss gradually decreases. After 20,000 steps, the learned y=ax+b line almost coincides with the original data, and you can increase the number of training steps to shrink the loss further.
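Besides running more steps, a larger learning rate also drives the loss down faster. The sketch below repeats the training setup in self-contained form with a learning rate of 0.1 and 500 steps; both values are assumptions chosen for illustration, not settings from the run above.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def linear(theta, x):
    weight, bias = theta
    return x * weight + bias

def p_loss(theta, x, y):
    return jnp.mean((y - linear(theta, x)) ** 2)

@jax.jit
def update_step(theta, x, y, lr):
    loss, gradient = jax.value_and_grad(p_loss)(theta, x, y)
    return theta - lr * gradient, loss

# Same data as in the article: y = 3x + 5
key = random.PRNGKey(56)
xs = random.normal(key, shape=(128, 1))
ys = 3.0 * xs + 5.0

theta = jnp.array([0.0, 0.0])
for _ in range(500):  # assumed step count
    theta, loss = update_step(theta, xs, ys, 0.1)  # assumed learning rate

# The learned parameters should end up close to a = 3 and b = 5
weight, bias = theta
```

A learning rate that is too large makes gradient descent diverge, so it is worth watching the printed loss when changing it.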
Although JAX is not currently regarded as a neural network framework, competition among frameworks has intensified with the arrival of PyTorch, PaddlePaddle, and MindSpore, and Google may yet develop JAX into the next generation of neural network frameworks.