From 035a52fa831db6dd40eddde385e5821c2f83ffc0 Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Wed, 17 Sep 2025 19:25:50 +0000 Subject: [PATCH 1/6] Add intro to JAX --- .../solutions/1_introduction_to_jax.ipynb | 691 ++++++++++++++++++ 1 file changed, 691 insertions(+) create mode 100644 asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb new file mode 100644 index 000000000..2d2586232 --- /dev/null +++ b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -0,0 +1,691 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a74e284c", + "metadata": {}, + "source": [ + "# Getting started with JAX\n", + "\n", + "**Learning Objectives:**\n", + "* Practice defining and performing basic operations on JAX arrays.\n", + "* Understand JAX's functional programming paradigm (immutability).\n", + "* Use JAX's automatic differentiation capability (`jax.grad`).\n", + "* Learn how to train a linear regression from scratch with JAX." + ] + }, + { + "cell_type": "markdown", + "id": "94ebc9ac", + "metadata": {}, + "source": [ + "This notebook will cover basic JAX operations, automatic differentiation, and training a linear regression." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "418e43b5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "0c2f952f", + "metadata": {}, + "source": [ + "## Operations on JAX Arrays" + ] + }, + { + "cell_type": "markdown", + "id": "90f4c211", + "metadata": {}, + "source": [ + "### JAX Arrays (Constants)" + ] + }, + { + "cell_type": "markdown", + "id": "19e840c7", + "metadata": {}, + "source": [ + "JAX arrays are immutable and are similar to `tf.constant` in TensorFlow. This means that once created, their values cannot be changed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca60260c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "x = jnp.array([2, 3, 4])\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "id": "bafe3ec7", + "metadata": {}, + "source": [ + "### Point-wise operations" + ] + }, + { + "cell_type": "markdown", + "id": "617b5220", + "metadata": {}, + "source": [ + "JAX offers a comprehensive suite of point-wise operations, similar to what you'd find in NumPy or TensorFlow." + ] + }, + { + "cell_type": "markdown", + "id": "f513cf92", + "metadata": {}, + "source": [ + "**Exercise:** Create two JAX arrays `a = jnp.array([5, 3, 8])` and `b = jnp.array([3, -1, 2])`. Then, compute:\n", + "1. The sum of `a` and `b` using `jnp.add` and `+`.\n", + "2. The product of `a` and `b` using `jnp.multiply` and `*`.\n", + "3. The exponential of `a` using `jnp.exp`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e738e41", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "a = jnp.array([5, 3, 8])\n", + "b = jnp.array([3, -1, 2])\n", + "\n", + "sum_add = jnp.add(a, b)\n", + "sum_plus = a + b\n", + "print(f\"Sum using jnp.add: {sum_add}\")\n", + "print(f\"Sum using +: {sum_plus}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8d18a7a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "a = jnp.array([5, 3, 8])\n", + "b = jnp.array([3, -1, 2])\n", + "\n", + "prod_multiply = jnp.multiply(a, b)\n", + "prod_star = a * b\n", + "print(f\"Product using jnp.multiply: {prod_multiply}\")\n", + "print(f\"Product using *: {prod_star}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec006262", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "a = jnp.array([5, 3, 8])\n", + "\n", + "exp_a = jnp.exp(a)\n", + "print(f\"Exponential of a: {exp_a}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bb69a0b6", + "metadata": {}, + "source": [ + "### NumPy Interoperability" + ] + }, + { + "cell_type": "markdown", + "id": "5db4737a", + "metadata": {}, + "source": [ + "JAX operations can seamlessly accept native Python types (like lists and scalars) and NumPy arrays as inputs. Conversely, JAX arrays can be converted to NumPy arrays using the standard `np.array()` constructor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc63aefc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# numpy arrays\n", + "a_np = np.array([1, 2])\n", + "b_np = np.array([3, 4])\n", + "# jax sum\n", + "print(f\"Sum of numpy arrays: {jnp.add(a_np, b_np)}\")\n", + "\n", + "# jax arrays\n", + "a_jax = jnp.array([1, 2])\n", + "b_jax = jnp.array([3, 4])\n", + "# jax sum\n", + "print(f\"Sum of jax arrays: {jnp.add(a_jax, b_jax)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d30f69c3-94b0-4854-a3f4-263322a1b59a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Convert JAX array to NumPy array\n", + "a_jax_to_np = np.array(a_jax)\n", + "print(f\"JAX array converted to NumPy: {a_jax_to_np}, type: {type(a_jax_to_np)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "2bfa1e94", + "metadata": {}, + "source": [ + "## Linear Regression" + ] + }, + { + "cell_type": "markdown", + "id": "e72a7b29", + "metadata": {}, + "source": [ + "Now let's use JAX operations to implement linear regression. Later in the course, you'll see abstracted ways to do this using high-level libraries like Equinox or Flax." 
+ ] + }, + { + "cell_type": "markdown", + "id": "143a6a8f", + "metadata": {}, + "source": [ + "### Toy Dataset" + ] + }, + { + "cell_type": "markdown", + "id": "304481a4", + "metadata": {}, + "source": [ + "We'll model the following function: $y = 2x + 10$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "389cfbf4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X_train = jnp.array(range(10), dtype=jnp.float32)\n", + "Y_train = 2 * X_train + 10\n", + "print(f\"X_train: {X_train}\")\n", + "print(f\"Y_train: {Y_train}\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2f41c5c", + "metadata": {}, + "source": [ + "Let's also create a test dataset to evaluate our models:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0801051b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X_test = jnp.array(range(10, 20), dtype=jnp.float32)\n", + "Y_test = 2 * X_test + 10\n", + "print(f\"X_test: {X_test}\")\n", + "print(f\"Y_test: {Y_test}\")" + ] + }, + { + "cell_type": "markdown", + "id": "1b060494", + "metadata": {}, + "source": [ + "#### Loss Function" + ] + }, + { + "cell_type": "markdown", + "id": "04fa8155", + "metadata": {}, + "source": [ + "A common baseline model is to predict the mean of the training target values. Let's calculate the Mean Squared Error (MSE) for this baseline on the test set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd3b16b1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "y_mean = Y_train.mean()  # baseline is fit on training targets only\n", + "\n", + "\n", + "def predict_mean(X):\n", + " return jnp.full_like(X, y_mean)\n", + "\n", + "\n", + "Y_hat_baseline = predict_mean(X_test)\n", + "baseline_errors = (Y_hat_baseline - Y_test) ** 2\n", + "baseline_loss = jnp.mean(baseline_errors)\n", + "print(f\"Baseline MSE Loss (predicting mean): {baseline_loss}\")" + ] + }, + { + "cell_type": "markdown", + "id": "920daa97", + "metadata": {}, + "source": [ + "Now, if $\\hat{Y}$ represents the vector containing our model's predictions when we use a linear regression model $\\hat{Y} = w_0X + w_1$, we can write a loss function taking as arguments the model parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da9b391f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def loss_mse(params, X, Y):\n", + " w0, w1 = params\n", + " Y_hat = w0 * X + w1\n", + " errors = (Y_hat - Y) ** 2\n", + " return jnp.mean(errors)" + ] + }, + { + "cell_type": "markdown", + "id": "999586e5", + "metadata": {}, + "source": [ + "### Gradient Function" + ] + }, + { + "cell_type": "markdown", + "id": "e1cb575b", + "metadata": {}, + "source": [ + "To use gradient descent, we need to take the partial derivatives of the loss function with respect to each of the weights. With JAX's automatic differentiation capability (`jax.grad`), we don't have to compute them manually! `jax.grad` transforms a function into a new function that computes its gradient. The `argnums` parameter specifies with respect to which argument(s) the gradient should be computed."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3598ae6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# We want gradients with respect to params (arg 0)\n", + "grad_fn = jax.grad(loss_mse, argnums=0)\n", + "\n", + "# The loss function takes (params, X, Y)\n", + "\n", + "\n", + "def compute_gradients(params, X, Y):\n", + " return grad_fn(params, X, Y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03cda5ca", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "initial_params = [0.0, 0.0] # w0, w1 as a list or tuple\n", + "dw0, dw1 = compute_gradients(initial_params, X_train, Y_train)\n", + "\n", + "print(f\"Initial d_w0: {dw0}\")\n", + "print(f\"Initial d_w1: {dw1}\")" + ] + }, + { + "cell_type": "markdown", + "id": "52d487dc", + "metadata": {}, + "source": [ + "### Training Loop" + ] + }, + { + "cell_type": "markdown", + "id": "458abedc", + "metadata": {}, + "source": [ + "Here we have a very simple training loop. Note we are ignoring best practices like batching and random weight initialization for simplicity." + ] + }, + { + "cell_type": "markdown", + "id": "da391285", + "metadata": {}, + "source": [ + "**Exercise:** Complete the `for` loop below to train a linear regression.\n", + "1. Use `compute_gradients` to compute `dw0` and `dw1`.\n", + "2. Update `w0` and `w1` using the computed gradients and the `LEARNING_RATE`. Remember JAX arrays are immutable, so you'll create new arrays for the updated parameters.\n", + "3. For every 100th step, compute and print the `loss` using the `loss_mse` function." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8edde47", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import jax.numpy as jnp # ensure jnp is used for arrays\n", + "\n", + "STEPS = 1000\n", + "LEARNING_RATE = 0.02\n", + "MSG = \"STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\\n\"\n", + "\n", + "# Initialize parameters\n", + "params = [0.0, 0.0] # w0, w1\n", + "\n", + "for step in range(1, STEPS + 1):\n", + " grad_w0, grad_w1 = compute_gradients(params, X_train, Y_train)\n", + "\n", + " # Remember JAX arrays are immutable.\n", + " new_w0 = params[0] - LEARNING_RATE * grad_w0\n", + " new_w1 = params[1] - LEARNING_RATE * grad_w1\n", + " params = [new_w0, new_w1]\n", + "\n", + " if step % 100 == 0:\n", + " current_loss = loss_mse(params, X_train, Y_train)\n", + " print(\n", + " MSG.format(step=step, loss=current_loss, w0=params[0], w1=params[1])\n", + " )\n", + "\n", + "print(f\"Final parameters: w0={params[0]}, w1={params[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "493b2a51", + "metadata": {}, + "source": [ + "Now let's compare the test loss for this linear regression to the test loss from the baseline model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ad127ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "final_w0, final_w1 = params\n", + "test_loss = loss_mse(params, X_test, Y_test)\n", + "print(f\"Test MSE Loss (linear regression): {test_loss}\")\n", + "print(f\"Baseline MSE Loss (predicting mean): {baseline_loss}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bafd7c4c", + "metadata": {}, + "source": [ + "This is indeed much better!" 
+ ] + }, + { + "cell_type": "markdown", + "id": "8f27f078", + "metadata": {}, + "source": [ + "## Bonus" + ] + }, + { + "cell_type": "markdown", + "id": "4ab27042", + "metadata": {}, + "source": [ + "Try modelling a non-linear function such as: $y=xe^{-x^2}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a9c4913", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "X = jnp.array(np.linspace(0, 2, 1000), dtype=jnp.float32)\n", + "Y = X * jnp.exp(-(X**2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6054012", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "plt.plot(X, Y)\n", + "plt.title(\"Non-linear function: y = x * exp(-x^2)\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "799eb02a", + "metadata": {}, + "source": [ + "To model this with a linear model, we need to engineer features. Let's create a function `make_features`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02375444", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def make_features(X):\n", + " f1 = jnp.ones_like(X) # Bias feature\n", + " f2 = X\n", + " f3 = X**2\n", + " f4 = X**3\n", + " f5 = jnp.sqrt(X)\n", + " f6 = jnp.exp(X)\n", + " # Stack them column-wise\n", + " return jnp.stack([f1, f2, f3, f4, f5, f6], axis=1)" + ] + }, + { + "cell_type": "markdown", + "id": "f63263c4", + "metadata": {}, + "source": [ + "The loss is still the mean squared error, but the weights are now a vector, so we need a prediction function based on matrix multiplication of features and weights, and we redefine `loss_mse` and `compute_gradients` on top of it."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4797934", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def predict(W, X):\n", + " return jnp.dot(X, W)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c26fb0c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def loss_mse(W, X, Y_true):\n", + " Y_hat = predict(W, X)\n", + " errors = (Y_hat - Y_true) ** 2\n", + " return jnp.mean(errors)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ae8337b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def compute_gradients(params_w, X_features, Y_true):\n", + " return jax.grad(loss_mse, argnums=0)(params_w, X_features, Y_true)" + ] + }, + { + "cell_type": "markdown", + "id": "67caf8e1", + "metadata": {}, + "source": [ + "Now, let's train our linear model on these engineered features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2036fda", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "STEPS = 2000\n", + "LEARNING_RATE = 0.02\n", + "\n", + "Xf = make_features(X)\n", + "n_features = Xf.shape[1]\n", + "\n", + "W = jnp.zeros(n_features)\n", + "\n", + "for step in range(1, STEPS + 1):\n", + " grads = compute_gradients(W, Xf, Y)\n", + " W = W - LEARNING_RATE * grads\n", + "\n", + " if step % 100 == 0:\n", + " current_loss = loss_mse(W, Xf, Y)\n", + " print(f\"Step: {step}, Loss: {current_loss}\")\n", + "\n", + "plt.plot(X, Y, label=\"Actual\")\n", + "plt.plot(X, predict(W, Xf), label=\"Predicted\")\n", + "plt.xlabel(\"x\")\n", + "plt.ylabel(\"y\")\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "47914b23", + "metadata": {}, + "source": [ + "Copyright 2025 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb1d6f81-9a3b-402c-9715-58ec40612d3b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "environment": { + "kernel": "conda-base-py", + "name": "workbench-notebooks.m133", + "type": "gcloud", + "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m133" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel) (Local)", + "language": "python", + "name": "conda-base-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 207c1cf794462d64638bc95292b49c1137713571 Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Thu, 18 Sep 2025 15:19:13 -0400 Subject: [PATCH 2/6] Update notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../solutions/1_introduction_to_jax.ipynb | 2 -- 1 file changed, 2 deletions(-) diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb index 2d2586232..0d75e80f5 100644 --- a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb +++ 
b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -128,8 +128,6 @@ }, "outputs": [], "source": [ - "a = jnp.array([5, 3, 8])\n", - "b = jnp.array([3, -1, 2])\n", "\n", "prod_multiply = jnp.multiply(a, b)\n", "prod_star = a * b\n", From c6d5f9e579c547333d09ffeba22c01cfa7b478bb Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Thu, 18 Sep 2025 15:19:40 -0400 Subject: [PATCH 3/6] Update notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../solutions/1_introduction_to_jax.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb index 0d75e80f5..9adae0d5d 100644 --- a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb +++ b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -144,7 +144,6 @@ }, "outputs": [], "source": [ - "a = jnp.array([5, 3, 8])\n", "\n", "exp_a = jnp.exp(a)\n", "print(f\"Exponential of a: {exp_a}\")" From 0ab67be0fffdf55dc4499e691730f00a05490bcb Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Thu, 18 Sep 2025 15:20:04 -0400 Subject: [PATCH 4/6] Update notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../solutions/1_introduction_to_jax.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb index 9adae0d5d..5c689dff6 100644 --- a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb +++ 
b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -420,7 +420,6 @@ }, "outputs": [], "source": [ - "import jax.numpy as jnp # ensure jnp is used for arrays\n", "\n", "STEPS = 1000\n", "LEARNING_RATE = 0.02\n", From 6427454f45db8adc5b51a67af5e1bf7121e5cc25 Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Thu, 18 Sep 2025 19:26:19 +0000 Subject: [PATCH 5/6] fixed format --- .../solutions/1_introduction_to_jax.ipynb | 5 ----- 1 file changed, 5 deletions(-) diff --git a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb index 5c689dff6..3d726f3ba 100644 --- a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb +++ b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -128,7 +128,6 @@ }, "outputs": [], "source": [ - "\n", "prod_multiply = jnp.multiply(a, b)\n", "prod_star = a * b\n", "print(f\"Product using jnp.multiply: {prod_multiply}\")\n", @@ -144,7 +143,6 @@ }, "outputs": [], "source": [ - "\n", "exp_a = jnp.exp(a)\n", "print(f\"Exponential of a: {exp_a}\")" ] @@ -361,8 +359,6 @@ "# We want gradients with respect to params (arg 0)\n", "grad_fn = jax.grad(loss_mse, argnums=0)\n", "\n", - "# The loss function takes (params, X, Y)\n", - "\n", "\n", "def compute_gradients(params, X, Y):\n", " return grad_fn(params, X, Y)" @@ -420,7 +416,6 @@ }, "outputs": [], "source": [ - "\n", "STEPS = 1000\n", "LEARNING_RATE = 0.02\n", "MSG = \"STEP {step} - loss: {loss}, w0: {w0}, w1: {w1}\\n\"\n", From 03380c3055f17dd6d984a3e64825226b13315b6b Mon Sep 17 00:00:00 2001 From: Takumi Ohyama Date: Mon, 23 Feb 2026 05:14:40 +0000 Subject: [PATCH 6/6] update kernelspec --- .../solutions/1_introduction_to_jax.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb index 3d726f3ba..5d43a0a09 100644 --- a/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb +++ b/asl_core/notebooks/introduction_to_tensorflow/solutions/1_introduction_to_jax.ipynb @@ -660,9 +660,9 @@ "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m133" }, "kernelspec": { - "display_name": "Python 3 (ipykernel) (Local)", + "display_name": "ASL Core", "language": "python", - "name": "conda-base-py" + "name": "asl_core" }, "language_info": { "codemirror_mode": { @@ -674,7 +674,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.18" + "version": "3.12.12" } }, "nbformat": 4,