diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb new file mode 100644 index 000000000..5d43a0a09 --- /dev/null +++ b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -0,0 +1,682 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a74e284c", + "metadata": {}, + "source": [ + "# Getting started with JAX\n", + "\n", + "**Learning Objectives:**\n", + "* Practice defining and performing basic operations on JAX arrays.\n", + "* Understand JAX's functional programming paradigm (immutability).\n", + "* Use JAX's automatic differentiation capability (`jax.grad`).\n", + "* Learn how to train a linear regression from scratch with JAX." + ] + }, + { + "cell_type": "markdown", + "id": "94ebc9ac", + "metadata": {}, + "source": [ + "This notebook will cover basic JAX operations, automatic differentiation, and training a linear regression." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "418e43b5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "0c2f952f", + "metadata": {}, + "source": [ + "## Operations on JAX Arrays" + ] + }, + { + "cell_type": "markdown", + "id": "90f4c211", + "metadata": {}, + "source": [ + "### JAX Arrays (Constants)" + ] + }, + { + "cell_type": "markdown", + "id": "19e840c7", + "metadata": {}, + "source": [ + "JAX arrays are immutable and are similar to `tf.constant` in TensorFlow. This means that once created, their values cannot be changed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca60260c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "x = jnp.array([2, 3, 4])\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "id": "bafe3ec7", + "metadata": {}, + "source": [ + "### Point-wise operations" + ] + }, + { + "cell_type": "markdown", + "id": "617b5220", + "metadata": {}, + "source": [ + "JAX offers a comprehensive suite of point-wise operations, similar to what you'd find in NumPy or TensorFlow." + ] + }, + { + "cell_type": "markdown", + "id": "f513cf92", + "metadata": {}, + "source": [ + "**Exercise:** Create two JAX arrays `a = jnp.array([5, 3, 8])` and `b = jnp.array([3, -1, 2])`. Then, compute:\n", + "1. The sum of `a` and `b` using `jnp.add` and `+`.\n", + "2. The product of `a` and `b` using `jnp.multiply` and `*`.\n", + "3. The exponential of `a` using `jnp.exp`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e738e41", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "a = jnp.array([5, 3, 8])\n", + "b = jnp.array([3, -1, 2])\n", + "\n", + "sum_add = jnp.add(a, b)\n", + "sum_plus = a + b\n", + "print(f\"Sum using jnp.add: {sum_add}\")\n", + "print(f\"Sum using +: {sum_plus}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8d18a7a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prod_multiply = jnp.multiply(a, b)\n", + "prod_star = a * b\n", + "print(f\"Product using jnp.multiply: {prod_multiply}\")\n", + "print(f\"Product using *: {prod_star}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec006262", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "exp_a = jnp.exp(a)\n", + "print(f\"Exponential of a: {exp_a}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bb69a0b6", + "metadata": {}, + "source": [ + "### NumPy Interoperability" + ] + }, + { + "cell_type": "markdown", + "id": "5db4737a", + "metadata": {}, + "source": [ + "JAX operations can seamlessly accept native Python types (like lists and scalars) and NumPy arrays as inputs. Conversely, JAX arrays can be converted to NumPy arrays using the standard `np.array()` constructor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc63aefc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# numpy arrays\n", + "a_np = np.array([1, 2])\n", + "b_np = np.array([3, 4])\n", + "# jax sum\n", + "print(f\"Sum of numpy arrays: {jnp.add(a_np, b_np)}\")\n", + "\n", + "# jax arrays\n", + "a_jax = jnp.array([1, 2])\n", + "b_jax = jnp.array([3, 4])\n", + "# jax sum\n", + "print(f\"Sum of jax arrays: {jnp.add(a_jax, b_jax)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d30f69c3-94b0-4854-a3f4-263322a1b59a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Convert JAX array to NumPy array\n", + "a_jax_to_np = np.array(a_jax)\n", + "print(f\"JAX array converted to NumPy: {a_jax_to_np}, type: {type(a_jax_to_np)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2bfa1e94", + "metadata": {}, + "source": [ + "## Linear Regression" + ] + }, + { + "cell_type": "markdown", + "id": "e72a7b29", + "metadata": {}, + "source": [ + "Now let's use JAX operations to implement linear regression. Later in the course, you'll see abstracted ways to do this using high-level libraries like Equinox or Flax." + ] + }, + { + "cell_type": "markdown", + "id": "143a6a8f", + "metadata": {}, + "source": [ + "### Toy Dataset" + ] + }, + { + "cell_type": "markdown", + "id": "304481a4", + "metadata": {}, + "source": [ + "We'll model the following function: $y = 2x + 10$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "389cfbf4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X_train = jnp.array(range(10), dtype=jnp.float32)\n", + "Y_train = 2 * X_train + 10\n", + "print(f\"X_train: {X_train}\")\n", + "print(f\"Y_train: {Y_train}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2f41c5c", + "metadata": {}, + "source": [ + "Let's also create a test dataset to evaluate our models:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0801051b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X_test = jnp.array(range(10, 20), dtype=jnp.float32)\n", + "Y_test = 2 * X_test + 10\n", + "print(f\"X_test: {X_test}\")\n", + "print(f\"Y_test: {Y_test}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1b060494", + "metadata": {}, + "source": [ + "#### Loss Function" + ] + }, + { + "cell_type": "markdown", + "id": "04fa8155", + "metadata": {}, + "source": [ + "A common baseline model is to predict the mean of the training target values. Let's calculate the Mean Squared Error (MSE) for this baseline on the test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd3b16b1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "y_mean = Y_test.mean()\n", + "\n", + "\n", + "def predict_mean(X):\n", + " return jnp.full_like(X, y_mean)\n", + "\n", + "\n", + "Y_hat_baseline = predict_mean(X_test)\n", + "baseline_errors = (Y_hat_baseline - Y_test) ** 2\n", + "baseline_loss = jnp.mean(baseline_errors)\n", + "print(f\"Baseline MSE Loss (predicting mean): {baseline_loss}\")" + ] + }, + { + "cell_type": "markdown", + "id": "920daa97", + "metadata": {}, + "source": [ + "Now, if $\\hat{Y}$ represents the vector containing our model's predictions when we use a linear regression model $\\hat{Y} = w_0X + w_1$, we can write a loss function taking as arguments the model parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da9b391f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def loss_mse(params, X, Y):\n", + " w0, w1 = params\n", + " Y_hat = w0 * X + w1\n", + " errors = (Y_hat - Y) ** 2\n", + " return jnp.mean(errors)" + ] + }, + { + "cell_type": "markdown", + "id": "999586e5", + "metadata": {}, + "source": [ + "### Gradient Function" + ] + }, + { + "cell_type": "markdown", + "id": "e1cb575b", + "metadata": {}, + "source": [ + "To use gradient descent, we need to take the partial derivatives of the loss function with respect to each of the weights. With JAX's automatic differentiation capability (`jax.grad`), we don't have to compute them manually! `jax.grad` transforms a function into a new function that computes its gradient. The `argnums` parameter specifies with respect to which argument(s) the gradient should be computed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3598ae6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# We want gradients with respect to params (arg 0)\n", + "grad_fn = jax.grad(loss_mse, argnums=0)\n", + "\n", + "\n", + "def compute_gradients(params, X, Y):\n", + " return grad_fn(params, X, Y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03cda5ca", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "initial_params = [0.0, 0.0] # w0, w1 as a list or tuple\n", + "dw0, dw1 = compute_gradients(initial_params, X_train, Y_train)\n", + "\n", + "print(f\"Initial d_w0: {dw0}\")\n", + "print(f\"Initial d_w1: {dw1}\")" + ] + }, + { + "cell_type": "markdown", + "id": "52d487dc", + "metadata": {}, + "source": [ + "### Training Loop" + ] + }, + { + "cell_type": "markdown", + "id": "458abedc", + "metadata": {}, + "source": [ + "Here we have a very simple training loop. Note we are ignoring best practices like batching and random weight initialization for simplicity." + ] + }, + { + "cell_type": "markdown", + "id": "da391285", + "metadata": {}, + "source": [ + "**Exercise:** Complete the `for` loop below to train a linear regression.\n", + "1. Use `compute_gradients` to compute `dw0` and `dw1`.\n", + "2. Update `w0` and `w1` using the computed gradients and the `LEARNING_RATE`. Remember JAX arrays are immutable, so you'll create new arrays for the updated parameters.\n", + "3. For every 100th step, compute and print the `loss` using the `loss_mse` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8edde47", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "STEPS = 1000\n", + "LEARNING_RATE = 0.02\n", + "MSG = \"STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\\n\"\n", + "\n", + "# Initialize parameters\n", + "params = [0.0, 0.0] # w0, w1\n", + "\n", + "for step in range(1, STEPS + 1):\n", + " grad_w0, grad_w1 = compute_gradients(params, X_train, Y_train)\n", + "\n", + " # Remember JAX arrays are immutable.\n", + " new_w0 = params[0] - LEARNING_RATE * grad_w0\n", + " new_w1 = params[1] - LEARNING_RATE * grad_w1\n", + " params = [new_w0, new_w1]\n", + "\n", + " if step % 100 == 0:\n", + " current_loss = loss_mse(params, X_train, Y_train)\n", + " print(\n", + " MSG.format(step=step, loss=current_loss, w0=params[0], w1=params[1])\n", + " )\n", + "\n", + "print(f\"Final parameters: w0={params[0]}, w1={params[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "493b2a51", + "metadata": {}, + "source": [ + "Now let's compare the test loss for this linear regression to the test loss from the baseline model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ad127ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "final_w0, final_w1 = params\n", + "test_loss = loss_mse(params, X_test, Y_test)\n", + "print(f\"Test MSE Loss (linear regression): {test_loss}\")\n", + "print(f\"Baseline MSE Loss (predicting mean): {baseline_loss}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bafd7c4c", + "metadata": {}, + "source": [ + "This is indeed much better!" + ] + }, + { + "cell_type": "markdown", + "id": "8f27f078", + "metadata": {}, + "source": [ + "## Bonus" + ] + }, + { + "cell_type": "markdown", + "id": "4ab27042", + "metadata": {}, + "source": [ + "Try modelling a non-linear function such as: $y=xe^{-x^2}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a9c4913", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X = jnp.array(np.linspace(0, 2, 1000), dtype=jnp.float32)\n", + "Y = X * jnp.exp(-(X**2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6054012", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "plt.plot(X, Y)\n", + "plt.title(\"Non-linear function: y = x * exp(-x^2)\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "799eb02a", + "metadata": {}, + "source": [ + "To model this with a linear model, we need to engineer features. Let's create a function `make_features`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02375444", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def make_features(X):\n", + " f1 = jnp.ones_like(X) # Bias feature\n", + " f2 = X\n", + " f3 = X**2\n", + " f4 = X**3\n", + " f5 = jnp.sqrt(X)\n", + " f6 = jnp.exp(X)\n", + " # Stack them column-wise\n", + " return jnp.stack([f1, f2, f3, f4, f5, f6], axis=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f63263c4", + "metadata": {}, + "source": [ + "We can reuse our `loss_mse` function, but we need a prediction function that works with matrix multiplication for features and weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4797934", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def predict(W, X):\n", + " return jnp.dot(X, W)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c26fb0c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def loss_mse(W, X, Y_true):\n", + " Y_hat = predict(W, X)\n", + " errors = (Y_hat - Y_true) ** 2\n", + " return jnp.mean(errors)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ae8337b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def compute_gradients(params_w, X_features, Y_true):\n", + " return jax.grad(loss_mse, argnums=0)(params_w, X_features, Y_true)" + ] + }, + { + "cell_type": "markdown", + "id": "67caf8e1", + "metadata": {}, + "source": [ + "Now, let's train our linear model on these engineered features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2036fda", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "STEPS = 2000\n", + "LEARNING_RATE = 0.02\n", + "\n", + "Xf = make_features(X)\n", + "n_features = Xf.shape[1]\n", + "\n", + "W = jnp.zeros(n_features)\n", + "\n", + "for step in range(1, STEPS + 1):\n", + " grads = compute_gradients(W, Xf, Y)\n", + " W = W - LEARNING_RATE * grads\n", + "\n", + " if step % 100 == 0:\n", + " current_loss = loss_mse(W, Xf, Y)\n", + " print(f\"Step: {step}, Loss: {current_loss}\")\n", + "\n", + "plt.plot(X, Y, label=\"Actual\")\n", + "plt.plot(X, predict(W, Xf), label=\"Predicted\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "47914b23", + "metadata": {}, + "source": [ + "Copyright 2025 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb1d6f81-9a3b-402c-9715-58ec40612d3b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "environment": { + "kernel": "conda-base-py", + "name": "workbench-notebooks.m133", + "type": "gcloud", + "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m133" + }, + "kernelspec": { + "display_name": "ASL Core", + "language": "python", + "name": "asl_core" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}