diff --git a/src/python/tensorflow_cloud/core/experimental/tests/examples/running_custom_task_experiment_from_tf_model_garden_on_gcp_with_tf_cloud.ipynb b/src/python/tensorflow_cloud/core/experimental/tests/examples/running_custom_task_experiment_from_tf_model_garden_on_gcp_with_tf_cloud.ipynb
new file mode 100644
index 00000000..81d11e22
--- /dev/null
+++ b/src/python/tensorflow_cloud/core/experimental/tests/examples/running_custom_task_experiment_from_tf_model_garden_on_gcp_with_tf_cloud.ipynb
@@ -0,0 +1,454 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Running custom task experiments from TF Model Garden on GCP with TF Cloud",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cIG5d4Kvls6m"
+ },
+ "source": [
+ "##### Copyright 2021 The TensorFlow Cloud Authors.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "eR70XKMMmC8I",
+ "cellView": "form"
+ },
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wKcTRRxsAmDl"
+ },
+ "source": [
+ "# Running custom task experiments from TF Model Garden on GCP with TF Cloud"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FAUbwFuJB3bw"
+ },
+ "source": [
+ "In this example we use [run_experiment_cloud](https://github.com/tensorflow/cloud/blob/690c3eee65dadee8af260a19341ff23f42f1f070/src/python/tensorflow_cloud/core/experimental/models.py#L230) from the experimental module of TF Cloud to train a ResNet model from [TF Model Garden](https://github.com/tensorflow/models/tree/master/official) on an image classification task with the CIFAR-10 dataset from TFDS. The experiment is defined as a custom configuration registered in this notebook."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EFCSAVDbC8-W"
+ },
+ "source": [
+ "## Install Packages\n",
+ "\n",
+ "We need tensorflow-cloud and the nightly build of TF Model Garden (tf-models-nightly), both installed with pip."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "umK22_FY9iJw"
+ },
+ "source": [
+ "!pip install -q tensorflow-cloud tf-models-nightly"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1BjT1di_uVQi",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "e1a3fec4-f557-48d3-ddf4-3eca050ce089"
+ },
+ "source": [
+ "import tensorflow_cloud as tfc\n",
+ "print(tfc.__version__)"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0.1.16\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iZ-0PtcKhIqz"
+ },
+ "source": [
+ "## Import required modules"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "S31gRAUTfOTM"
+ },
+ "source": [
+ "import copy\n",
+ "import os\n",
+ "import sys\n",
+ "\n",
+ "from tensorflow_cloud.core.experimental.models import run_experiment_cloud\n",
+ "\n",
+ "from official.core import config_definitions as cfg\n",
+ "from official.core import exp_factory\n",
+ "from official.core import task_factory\n",
+ "from official.modeling import optimization\n",
+ "from official.vision.beta.configs.backbones import Backbone\n",
+ "from official.vision.beta.configs.backbones import ResNet\n",
+ "from official.vision.beta.configs import common\n",
+ "from official.vision.beta.configs import image_classification"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "siarj0TEhMzb"
+ },
+ "source": [
+ "## Project Configurations\n",
+ "Set the project parameters. For more details on Google Cloud specific parameters, please refer to [Google Cloud Project Setup Instructions](https://www.kaggle.com/nitric/google-cloud-project-setup-instructions/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "YCeDOU9Ufny9"
+ },
+ "source": [
+ "# Set Google Cloud Specific parameters\n",
+ "\n",
+ "# TODO: Please set GCP_PROJECT_ID to your own Google Cloud project ID.\n",
+ "GCP_PROJECT_ID = 'YOUR_PROJECT_ID' #@param {type:\"string\"}\n",
+ "\n",
+ "# TODO: set GCS_BUCKET to your own Google Cloud Storage (GCS) bucket.\n",
+ "GCS_BUCKET = 'YOUR_BUCKET_NAME' #@param {type:\"string\"}\n",
+ "\n",
+ "# DO NOT CHANGE: Currently only the 'us-central1' region is supported.\n",
+ "REGION = 'us-central1'\n",
+ "\n",
+ "# OPTIONAL: You can change the job name to any string.\n",
+ "JOB_NAME = 'cifar10_resnet' #@param {type:\"string\"}\n",
+ "\n",
+ "# Set the location where training logs and checkpoints will be stored\n",
+ "GCS_BASE_PATH = f'gs://{GCS_BUCKET}/{JOB_NAME}'\n",
+ "MODEL_DIR = os.path.join(GCS_BASE_PATH, \"model\")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mo6Wvg10DraI"
+ },
+ "source": [
+ "## Authenticating the notebook to use your Google Cloud Project\n",
+ "\n",
+ "This code authenticates the notebook by checking for valid Google Cloud credentials and identity. It is inside the `if not tfc.remote()` block so that it runs only in the notebook and not when the notebook code is sent to Google Cloud.\n",
+ "\n",
+ "Note: For Kaggle Notebooks click on \"Add-ons\"->\"Google Cloud SDK\" before running the cell below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "QeAmVS5KDtlR"
+ },
+ "source": [
+ "if not tfc.remote():\n",
+ "\n",
+ "    # Authentication for Kaggle Notebooks\n",
+ "    if \"kaggle_secrets\" in sys.modules:\n",
+ "        from kaggle_secrets import UserSecretsClient\n",
+ "        UserSecretsClient().set_gcloud_credentials(project=GCP_PROJECT_ID)\n",
+ "\n",
+ "    # Authentication for Colab Notebooks\n",
+ "    if \"google.colab\" in sys.modules:\n",
+ "        from google.colab import auth\n",
+ "        auth.authenticate_user()\n",
+ "        os.environ[\"GOOGLE_CLOUD_PROJECT\"] = GCP_PROJECT_ID"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "E1Fy7xadhSRi"
+ },
+ "source": [
+ "## Set Up TF Model Garden Experiment\n",
+ "\n",
+ "We now set up the TF Model Garden experiment that we want to run. In this case, we train a ResNet model on an image classification task with the CIFAR-10 dataset from TFDS."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IKu5XD9MhbMm"
+ },
+ "source": [
+ "@exp_factory.register_config_factory('resnet_cifar10')\n",
+ "def image_classification_cifar10() -> cfg.ExperimentConfig:\n",
+ "    \"\"\"Image classification on cifar10 with resnet.\"\"\"\n",
+ "    tfds_name = 'cifar10'\n",
+ "    train_examples = 50000\n",
+ "    val_examples = 10000\n",
+ "    train_batch_size = 256\n",
+ "    eval_batch_size = 256\n",
+ "    num_classes = 10\n",
+ "    steps_per_epoch = train_examples // train_batch_size\n",
+ "    config = cfg.ExperimentConfig(\n",
+ "        task=image_classification.ImageClassificationTask(\n",
+ "            model=image_classification.ImageClassificationModel(\n",
+ "                num_classes=num_classes,\n",
+ "                input_size=[224, 224, 3],\n",
+ "                backbone=Backbone(\n",
+ "                    type='resnet', resnet=ResNet(model_id=50)),\n",
+ "                norm_activation=common.NormActivation(\n",
+ "                    norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),\n",
+ "            losses=image_classification.Losses(l2_weight_decay=1e-4),\n",
+ "            train_data=image_classification.DataConfig(\n",
+ "                tfds_name=tfds_name,\n",
+ "                tfds_split='train',\n",
+ "                is_training=True,\n",
+ "                global_batch_size=train_batch_size),\n",
+ "            validation_data=image_classification.DataConfig(\n",
+ "                tfds_name=tfds_name,\n",
+ "                tfds_split='test',\n",
+ "                is_training=False,\n",
+ "                global_batch_size=eval_batch_size)),\n",
+ "        trainer=cfg.TrainerConfig(\n",
+ "            steps_per_loop=steps_per_epoch,\n",
+ "            summary_interval=steps_per_epoch,\n",
+ "            checkpoint_interval=steps_per_epoch,\n",
+ "            train_steps=90 * steps_per_epoch,\n",
+ "            validation_steps=val_examples // eval_batch_size,\n",
+ "            validation_interval=steps_per_epoch,\n",
+ "            optimizer_config=optimization.OptimizationConfig({\n",
+ "                'optimizer': {\n",
+ "                    'type': 'sgd',\n",
+ "                    'sgd': {\n",
+ "                        'momentum': 0.9\n",
+ "                    }\n",
+ "                },\n",
+ "                'learning_rate': {\n",
+ "                    'type': 'stepwise',\n",
+ "                    'stepwise': {\n",
+ "                        'boundaries': [\n",
+ "                            30 * steps_per_epoch, 60 * steps_per_epoch,\n",
+ "                            80 * steps_per_epoch\n",
+ "                        ],\n",
+ "                        'values': [\n",
+ "                            0.1 * train_batch_size / 256,\n",
+ "                            0.01 * train_batch_size / 256,\n",
+ "                            0.001 * train_batch_size / 256,\n",
+ "                            0.0001 * train_batch_size / 256,\n",
+ "                        ]\n",
+ "                    }\n",
+ "                },\n",
+ "                'warmup': {\n",
+ "                    'type': 'linear',\n",
+ "                    'linear': {\n",
+ "                        'warmup_steps': 5 * steps_per_epoch,\n",
+ "                        'warmup_learning_rate': 0\n",
+ "                    }\n",
+ "                }\n",
+ "            })),\n",
+ "        restrictions=[\n",
+ "            'task.train_data.is_training != None',\n",
+ "            'task.validation_data.is_training != None'\n",
+ "        ])\n",
+ "\n",
+ "    return config\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HZUJ9AvUM4Hv"
+ },
+ "source": [
+ "\n",
+ "Once the experiment config is ready, we can store all of the parameters in a dictionary. For more details refer to [run_experiment on GitHub](https://github.com/tensorflow/models/blob/7c2ff1afc4423266223bcd50cba0ed55aca826c8/official/core/train_lib.py#L35).\n",
+ "\n",
+ "Note: run_experiment requires a distribution_strategy parameter. However, run_experiment_cloud selects the distribution strategy based on the cloud configuration. Therefore, you should not pass this parameter as part of run_experiment_kwargs. For more information on distribution strategies check [Running model experiments from TF Model Garden on GCP with TF Cloud](https://github.com/tensorflow/cloud/blob/master/src/python/tensorflow_cloud/core/experimental/tests/examples/running_model_experiments_from_tf_model_garden_on_gcp_with_tf_cloud.ipynb)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qGhFG2bJAWjE"
+ },
+ "source": [
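+ "# Importing the task module registers ImageClassificationTask with task_factory,\n",
+ "# so that task_factory.get_task(config.task) below can find it.\n",
+ "from official.vision.beta.tasks.image_classification import ImageClassificationTask\n",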
+ "\n",
+ "config = exp_factory.get_exp_config('resnet_cifar10')\n",
+ "\n",
+ "run_experiment_kwargs = dict(\n",
+ "    params=config,\n",
+ "    task=task_factory.get_task(config.task),\n",
+ "    mode=\"train_and_eval\",\n",
+ "    model_dir=MODEL_DIR,\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "E94YK3sGikkx"
+ },
+ "source": [
+ "## Set up TensorFlow Cloud run\n",
+ "\n",
+ "Set up the parameters for tfc.run(). Here chief_config selects the `T4_4X` machine configuration; worker_count and worker_config are not set. For more details refer to the [TensorFlow Cloud overview tutorial](https://colab.research.google.com/github/tensorflow/cloud/blob/master/g3doc/tutorials/overview.ipynb)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "C8VWQ3AANj3V"
+ },
+ "source": [
+ "with open('requirements.txt', 'w') as f:\n",
+ "    f.write('tf-models-nightly\\n')\n",
+ "\n",
+ "run_kwargs = dict(\n",
+ "    requirements_txt='requirements.txt',\n",
+ "    docker_config=tfc.DockerConfig(\n",
+ "        parent_image=\"gcr.io/deeplearning-platform-release/tf2-gpu.2-5\",\n",
+ "        image_build_bucket=GCS_BUCKET\n",
+ "    ),\n",
+ "    chief_config=tfc.COMMON_MACHINE_CONFIGS[\"T4_4X\"],\n",
+ "    job_labels={'job': JOB_NAME}\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tWhfuseFPnBa"
+ },
+ "source": [
+ "## Run remote experiment\n",
+ "\n",
+ "With run_experiment_kwargs and run_kwargs complete, we can now call run_experiment_cloud to run the experiment on GCP."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "jRuqBEKdREeB"
+ },
+ "source": [
+ "run_experiment_cloud(run_experiment_kwargs, run_kwargs)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZLFHnIFUF8Fx"
+ },
+ "source": [
+ "# Training Results\n",
+ "## Reconnect your Colab instance\n",
+ "Most remote training jobs are long running. If you are using Colab, it may time out before the training results are available. In that case, rerun the following sections in order to reconnect and configure your Colab instance to access the training results:\n",
+ "\n",
+ "1. Import required modules\n",
+ "2. Project Configurations\n",
+ "3. Authenticating the notebook to use your Google Cloud Project"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tNBMCk0MFi1A"
+ },
+ "source": [
+ "## Load your trained model\n",
+ "\n",
+ "Once training is complete, you can retrieve your model from the GCS Bucket you specified above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "B7vpXm3l9iiT"
+ },
+ "source": [
+ "import tensorflow as tf\n",
+ "\n",
+ "trained_model = tf.keras.models.load_model(MODEL_DIR)\n",
+ "trained_model.summary()"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file