diff --git a/Ch12_Optimization_Algorithms/Adadelta.ipynb b/Ch12_Optimization_Algorithms/Adadelta.ipynb new file mode 100644 index 00000000..3f966c44 --- /dev/null +++ b/Ch12_Optimization_Algorithms/Adadelta.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Adadelta" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In addition to RMSProp, Adadelta is another common optimization algorithm that\n", + "helps improve the chances of finding useful solutions at later stages of\n", + "iteration, which is difficult to do when using the Adagrad algorithm for the\n", + "same purpose :cite:`Zeiler.2012`. The interesting thing is that there is no learning rate\n", + "hyperparameter in the Adadelta algorithm." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Algorithm\n", + "\n", + "Like RMSProp, the Adadelta algorithm uses the variable $\\boldsymbol{s}_t$, which is an EWMA on the squares of elements in mini-batch stochastic gradient $\\boldsymbol{g}_t$. At time step 0, all the elements are initialized to 0.\n", + "Given the hyperparameter $0 \\leq \\rho < 1$ (counterpart of $\\gamma$ in RMSProp), at time step $t>0$, compute using the same method as RMSProp:\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$$\\boldsymbol{s}_t \\leftarrow \\rho \\boldsymbol{s}_{t-1} + (1 - \\rho) \\boldsymbol{g}_t \\odot \\boldsymbol{g}_t. $$\n", + "\n", + "Unlike RMSProp, Adadelta maintains an additional state variable, $\\Delta\\boldsymbol{x}_t$ the elements of which are also initialized to 0 at time step 0. We use $\\Delta\\boldsymbol{x}_{t-1}$ to compute the variation of the independent variable:\n", + "\n", + "$$ \\boldsymbol{g}_t' \\leftarrow \\sqrt{\\frac{\\Delta\\boldsymbol{x}_{t-1} + \\epsilon}{\\boldsymbol{s}_t + \\epsilon}} \\odot \\boldsymbol{g}_t, $$\n", + "\n", + "Here, $\\epsilon$ is a constant added to maintain the numerical stability, such as $10^{-5}$. Next, we update the independent variable:\n", + "\n", + "$$\\boldsymbol{x}_t \\leftarrow \\boldsymbol{x}_{t-1} - \\boldsymbol{g}'_t. $$\n", + "\n", + "Finally, we use $\\Delta\\boldsymbol{x}$ to record the EWMA on the squares of elements in $\\boldsymbol{g}'$, which is the variation of the independent variable.\n", + "\n", + "$$\\Delta\\boldsymbol{x}_t \\leftarrow \\rho \\Delta\\boldsymbol{x}_{t-1} + (1 - \\rho) \\boldsymbol{g}'_t \\odot \\boldsymbol{g}'_t. $$\n", + "\n", + "As we can see, if the impact of $\\epsilon$ is not considered here, Adadelta differs from RMSProp in its replacement of the hyperparameter $\\eta$ with $\\sqrt{\\Delta\\boldsymbol{x}_{t-1}}$.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implementation from Scratch\n", + "\n", + "Adadelta needs to maintain two state variables for each independent variable, $\\boldsymbol{s}_t$ and $\\Delta\\boldsymbol{x}_t$. We use the formula from the algorithm to implement Adadelta." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import d2l\n", + "from d2l import load_array\n", + "\n", + "def init_adadelta_states(feature_dim):\n", + " s_w, s_b = torch.zeros((feature_dim, 1)), torch.zeros(1)\n", + " delta_w, delta_b = torch.zeros((feature_dim, 1)), torch.zeros(1)\n", + " return ((s_w, delta_w), (s_b, delta_b))\n", + "\n", + "def adadelta(params, states, hyperparams):\n", + " rho, eps = hyperparams['rho'], 1e-5\n", + " for p, (s, delta) in zip(params, states):\n", + " p = p.type(torch.FloatTensor)\n", + " s[:] = rho * s + ((1 - rho) * p* p)\n", + " g = ((delta + eps).sqrt() / (s + eps).sqrt()) * (p)\n", + " p[:] -= g\n", + " delta[:] = rho * delta + (1 - rho) * g * g" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we train the model with the hyperparameter $\\rho=0.9$." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss: 0.259, 0.075 sec/epoch\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAC5CAYAAAA/IBqYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGdBJREFUeJzt3Xd4VWW2+PHvSkiRhNASekeKBCGANEEpKjKKyMyAIIKgoMNcZ2w/G2ObuVfnOjg614IiAlLEQSxgF2Y0wUGa9CpIlyagCERqyPr9cXbgACk7ZWefc1if58nD2e11vR6y2O19l6gqxhhTkCi/AzDGhAdLFsYYVyxZGGNcsWRhjHHFkoUxxhVLFsYYVzxNFiLSU0TWi8hGEXkkl+0jRGSViCwXkbki0sxZX09Ejjrrl4vImKBj2jjHbBSRF0VEvOyDMSZAvHrPQkSigQ3ANcAO4BvgZlVdG7RPkqoecj73Bv5LVXuKSD3gY1Vtnku7i4B7gAXAp8CLqvqZJ50wxpzm5ZlFO2Cjqm5W1RPANODG4B1yEoUjAcg3c4lIdSBJVedrIMtNBvqUbNjGmNx4mSxqAt8HLe9w1p1FRO4SkU3AKODuoE31RWSZiMwRkSuC2txRUJvGmJJXxsO2c7uXcN6Zg6qOBkaLyEDgMWAIsBuoo6o/ikgbYKaIpLptE0BE7gTuBIiPj29Tp06dovUiRGRnZxMVVfzcfjIbdmVmExcN1RJK7/52ScXvp0jow4YNG/arakpRjvUyWewAagct1wJ25bP/NOBVAFU9Dhx3Pi9xzjwaO23WctOmqo4FxgI0adJE169fX7RehIiMjAy6du1aIm1NW7SdR95fxQPXXcIdVzYokTYLUpLx+yUS+iAi24p6rJdp8hugkYjUF5FYYADwYfAOItIoaPF64DtnfYpzgxQRaQA0Ajar6m7gsIh0cJ6C3Ap84GEfIlL/trW5NrUqo2Z9y5pdB/0Ox4QJz5KFqmYBfwBmAeuA6aq6RkT+23nyAfAHEVkjIsuB+wlcggBcCawUkRXAu8AIVf3J2fZ7YBywEdgE2JOQQhIRnvlNCyolxHLPtOUcPXHK75BMGPDyMgRV/ZTA483gdU8Efb4nj+PeA97LY9ti4LxHqqZwKibE8ly/NAaNX8hfP13H//Sx/6Umf+F9t8YUS+dGydxxRX2mLNjGF+t+8DscE+IsWVzgHri2CZdUT+Khd1ey9/Axv8MxIcySxQUurkw0Lw5II/N4Fg++sxKbOc3kxZKFoVHVcjx2/SXM2bCPSfO2+h2OCVGWLAwAgzrUpXvTKvz1s29Zv+ew3+GYEGTJwgCBx6mj+rYgKb4M90xbxrGT9jjVnM2ShTktOTGOZ/u15Ns9hxn1eXi/8WpKniULc5ZuTaow9PJ6TPh6C3M27PM7HBNCLFmY8zzyq6Y0rprIA++s4MAvJ/wOx4QISxbmPPEx0fyjfxo/HznBYzNX2+NUA1iyMHlIrVGee69uzCerdvPhivwGC5sLhSULk6cRXRrSpm5FHp+5mt0Hj/odjvGZJQuTp+go4bl+LcnKVh58ZyXZ2XY5ciGzZGHyVS85gUevv4S5G/czZUGR500xEcCShSnQwHZ16NYkhf/9bB2b9mX6HY7xiSULUyAR4W+/bUF8TDT3v72ck6ey/Q7J+MCShXGlSlI8T/e5lBU7DjI6faPf4RgfhGpFsmtEZImzbYmIdA86JsNpM6daWRUv+2DOuL5Fdfqk1eClLzey4vuf/Q7HlDLPkoUz4e5o4FdAM+DmnGQQ5C1VvVRV0wjUDXneWb8fuEFVLyUwL+eUc467RVXTnJ+9XvXBnO8vNzanSrk47pu+3AabXWBCsiKZqi5T1Zw3gdYA8SIS52GsxqXyF8Xw934t2bzvF5757Fu/wzGlKJQrkuX4LbDMqSWS4w3nEuRxK4xc+jpdnMzQy+sxcd5W5n633+9wTCnxsjByP+BaVR3uLA8G2qnqH/PYf6Cz/5CgdakEao30UNVNzrqaqrpTRMoRmAH8TVWdnEt7pyuSpaSktJk+fXrJdrCUZWZmkpiY6HcYpx0/pTw57ygnTsH/dLqIhJj8c3aoxV8UkdCHbt26LVHVy4p0sKp68gN0BGYFLY8ERuazfxRwMGi5FoEq7J3yOWYo8HJBsTRu3FjDXXp6ut8hnGf59gPaYOQneu+0ZQXuG4rxF1Yk9AFYrEX8nQ7VimQVgE+c5PJ10P5lRCTZ+RwD9AJWe9gHk4+WtSvwx+4XM2PZTj5dtdvvcIzHQrUi2R+Ai4HHz3lEGgfMEpGVwHJgJ/C6V30wBbur28W0rFWeP81Yxd5DVkogkoVqRbKngKfyaLZNiQVoii0mOorn+6dx3Qv/4aH3VvLG0LbYPefIZG9wmmJrmJLIyF81JWP9Pt60wWYRy5KFKRG3dqxH1yYp/PmjtXxm9y8ikiULUyKiooTRA1uTVrsCd09bxpffWu3USGPJwpSYhLgyvHFbW5pWS2LEm0vtha0IY8nClKik+Bgm396OBskJ3DF5MYu2/OR3SKaEWLIwJa5iQixThrWnRoV4bp/4DctthGpEsGRhPJFSLo6pwztQKSGWW8cvZNshG6Ea7ixZGM9UKx/P1OHtSYwrw9+/OcZ3P1jB5XBmycJ4qnalsky9owNRUcLAcQvZsv8Xv0MyRWTJwniufnICD7WN51S2csvrC9hx4IjfIZkisGRhSkXNxCimDGtH5vEsBr6+kD0HbRxJuLFkYUpNao3yTB7Wnp9+OcEt4xawP/N4wQeZkGHJwpSqtNoVmDC0LTt/PsqgcQv5+YhVaQ8XlixMqWtXvxLjbm3L5v2/cOuERRw6dtLvkIwLliyMLzo3SmbMoNas232IweMXceAXO8MIdZYsjG+6N63KK7e0Yd3uQ/R7bb5Vag9xliyMr65pVpXJt7djz8Fj9H11vtVSDWEhWZHM2TbSOW69iFzrtk0Tfjo0qMy0OztwPOsU/cbMZ+UOG0sSikKyIpmz3wAgFegJvCIi0S7bNGGoec3yvDPicsrGRnPz2AU2vD0EhWRFMme/aap6XFW3ABud9gps04Sv+skJvPf7y6lVsSy3T/zGZgwPMaFakSyvY121acJX1aR4pv+uIy1qleeut5YydaHN6RkqvJzdO7cpns8rf6aqo4HRTkWyxwiUA8jr2NySW64l1c6pSEZGRoa7qENUZmZmWPehsPHf0Vg5eSSaR2esZsnq9dzQIMb3WcPD/TsoLi+TxQ6gdtByLWBXHvtC4JLiVRfHumpTVccCYwGaNGmiXbt2dRt3SMrIyCCc+1CU+Lt3zebhd1fy/rKdJKXU5IlezYiK8i9hhPt3UFwhWZHM2W+AiMSJSH2gEbDITZsmcsRER/H3fi0Z1rk+E+dt5f7pyzl5KtvvsC5Ynp1ZqGqWiORUJIsGJuRUJCNQb/FDAhXJrgZOAgdwKpI5+00H1gJZwF2qegogtza96oPxX1SU8Nj1l1ApIZZnZ63n56MneeWW1pSN9bQ+lslFSFYkc7Y9DTztpk0T2USEu7pdTKWEWB6dsYpB4xYyYWhbKpSN9Tu0C4q9wWnCxs3t6vDKLa1ZvfMQv3llHhtsmr5SZcnChJWezavz5vD2HDqWxY0vf80Hy3f6HdIFw5KFCTvt6lfik7s707xmEvdMW86TH6zmRJbd+PSaJQsTlqomxfPWHR0Y3rk+k+Zvo//Y+ez62UateslVshCRe0QkSQLGi8hSEenhdXDG5CcmOorHejXjlVtas2HPYXq9NJevN9qYEq+4PbO43RnH0QNIAW4DnvEsKmMK4bpLq/PhHztTOSGWweMXMjp9I9nZub7Ya4rBbbLIeW3uOuANVV1B7q9kG+OLhimJzLyrE71a1ODZWeu5c8piDh6x6fpKkttksUREZhNIFrNEpBxgd5RMSEmIK8MLA9L4S+9UMtbv44aX57Jm10G/w4oYbpPFMOARoK2qHgFiCFyKGBNSRIQhl9fj7d915ERWNr95ZR7TF39f8IGmQG6TRUdgvar+LCKDCIwOtZRtQlabuhX5+O7OtKlbkYfeXcnI91dy7KQVZy4Ot8niVeCIiLQEHgK2AZM9i8qYEpCcGMeUYe25q1tD/rnoe3r+31dMmb+VIyey/A4tLLlNFlmqqgRmpXpBVV8AynkXljElIzpKePDapky8rS3lL4rh8Q/W0PF/v2TU59/ywyEroVgYbgeSHRaRkcBg4ApnLswY78IypmR1bVKFLo1TWLLtAOP+s4Uxczbx+n82c0OLGtzeuT7Na5b3O8SQ5zZZ9AcGEnjfYo+I1AGe9S4sY0qeiHBZvUpcVq8S2388whvztjD9m+95f9lOOjSoxPDODejetIqvE+yEMleXIaq6B5gKlBeRXsAxVbV7FiZs1alclidvSGXeyKv403VN2f7jEYZPXsxVz89hyoJtdl8jF25f976JwExV/YCbgIUi0tfLwIwpDeUviuHOKxsy56FuvHhzK5Liy/D4zNVc/kzgvsZeu69xmtvLkEcJvGOxF0BEUoB/A+96FZgxpSkmOoreLWtwQ4vqp+9rvDpnE+PnbmHI5fUY0aWh3yH6zm2yiMpJFI4fsRGrJgIF39fY9uMvvPjFRsb9ZzNvLdzO1bWFNh1OUi7+wry37/YX/nMRmSUiQ0VkKPAJLqa2c1G+8H4RWSsiK0XkCxGp66zv5pQ0zPk5JiJ9nG0TRWRL0LY09901xr26lRN47qaWzLr3Sq5olMzMjSe5clQ6r3+1+YJ8wcvtDc4HCUyr3wJoCYxV1YfzO8ZlqcFlwGWq2oLAJc0o57+XrqppTlnD7sARYHbQcQ/mbFfV5W76YExRNapajlcHteHJjvFcWqsCT3+6ji7PpvPmgm0X1Gzjri8lVPU9Vb1fVe9T1RkuDnFTvjDdGWsCsIBAHZBz9QU+C9rPGF/ULx/N5NvbMe3ODtSuWJbHZq7mqufmMGPZDk5dAEPiJfBiZh4bRQ6Te8UvAVRVk/I5ti/QU1WHO8uDgfaq+oc89n8Z2KOqT52z/kvgeVX92FmeSGCsynHgC+ARVT2eS3vBFcnaTJ8+Pc9+hoPMzEwSExP9DqPIwj1+OLsPqsrK/ad4b8NJth/Opmai8JtGsbSuEu175bT8dOvWbYmqXlakg1XVkx8Cj1nHBS0PBl7KY99BBM4s4s5ZXx3YB8Scs06AOGAS8ERBsTRu3FjDXXp6ut8hFEu4x6+aex9OncrWj1bs1G7Ppmvdhz/W3i/P1Yz1ezU7O7v0A3SBQM2eIv1Oe/lEw1X5QqfI0KNAbz3/DOEmYIaqnp7FRFV3O/0+DrxB4HLHGF9ERQm9WtRg9n1XMqpvC/YfPs6QCYvo9dJcPlqxi6wIuqfhd/nCVsBrBBLF3lzauBn45znHVHf+FKAPsNqD2I0plDLRUdx0WW2+fKALo37bgqMnT/HHfy6j+3OBN0Ij4emJZ8lCVbOAnFKD64Dp6pQvFJHezm7PAonAO85j0NPJRETqETgzmXNO01NFZBWwCkgGnsKYEBFXJpqb2tbm3/d1YcygNlRKiOXxmavp/LcvGZ2+Mayn+vO7fOHV+Ry7FaiZy/ruJRiiMZ6IihJ6Nq/GtalVWbjlJ8bM2cSzs9bzSvpGBravw7DODahWPt7vMAvFqssa4yERoUODynRoUJm1uw7x2lebmPD1VibO20qftJr8rksDLq4SHlPD2CvbxpSSZjWSeGFAKzIe6MrAdnX4aOUurn7+K+6YvJh5m/aH/LsadmZhTCmrXaksf7mxOXdf1YhJ87Yyaf42/rX2ByolxHJV0ypcm1qNzo2SiY+J9jvUs1iyMMYnlRPjuL9HE0Z0bUj6t/uYvXYPn6/ZwztLdnBRTDRdGqfQI7Uq3ZtWoULZWL/DtWRhjN/Kxpbh+hbVub5FdU5kZbNwy4/MXvPD6eQRHSW0r1+JHs2q0iO1GjUqXORLnJYsjAkhsWWiuKJRClc0SuEvvVNZufMgs9fsYfbaH/jzR2v580drubRmeXo0q0rP5tVoVLX0bo5asjAmREVFCWm1K5BWuwIP9WzKpn2Z/GvtD8xes4fn/rWB5/61gdQaSfy6VU16t6xBlSRvH8VasjAmTDRMSaRhl0RGdGnI3kPH+GTVbmYu28lTn6zjr5+uo9PFyfRJq8m1zauRGFfyv9qWLIwJQ1WS4rmtU31u61SfTfsy+WDZTmYs38n/e2cFj85cRY9m1fh1q5pc0SiZMtEl84aEJQtjwlzDlETu79GE+65pzNLtB5ixbCcfr9zNhyt2UTkhlhta1uDXrWrSolbxaqNYsjAmQogIbepWok3dSjzRK5U5G/YxY9kO3lq0nYnzttIgOaFY7VuyMCYCxZaJ4ppmVbmmWVUOHj3J56t3M2PZzmK1aa97GxPhyl8UQ/+2dZh2Z8ditWPJwhjjiiULY4wrliyMMa5YsjDGuOJpsihqRTJn26mgqmPB0+3VF5GFIvKdiLztzO9pjPGYZ8miOBXJHEf1TNWx3kHr/wb8Q1UbAQeAYV71wRhzhpdnFiVVkew0Z0bv7pyp3j6JwAzfxhiPeflSVk3g+6DlHUD7fPYfBnwWtBwvIouBLOAZVZ0JVAZ+dmYOz2nzvEl94byKZGRkZBSlDyEjMzMzrPsQ7vFDZPShOLxMFrnVcMt1kkERGQRcBnQJWl1HVXeJSAPgS2f6/0Nu21TVsQSKOdOkSRPt2rVrIUIPPRkZGYRzH8I9foiMPhRHyFYkU9Vdzp+bgQygFbAfqCAiOUku1zaNMSUvJCuSiUhFEYlzPicDnYC1Tq3GdAKV1QGGAB942AdjjCNUK5JdAiwWkRUEksMzqrrW2fYwcL+IbCRwD2O8V30wxpwRkhXJVHUecGke2zZjxZCNKXX2BqcxxhVLFsYYVyxZGGNcsWRhjHHFkoUxxhVLFsYYVyxZGGNcsWRhjHHFkoUxxhVLFsYYVyxZGGNcsWRhjHHFkoUxxhVLFsYYVyxZGGNcsWRhjHHFkoUxxpWQrEgmImkiMl9E1jjb+gcdM1FEtgRVK0vzsg/GmIBQrUh2BLhVVVOBnsD/iUiFoOMeDKpWttyrPhhjzgjJimSqukFVv3M+7wL2AikexmqMKYCXySK3imS5Vg9znFuRDAARaQfEApuCVj/tXJ78I6dkgDHGW6FckQwRqQ5MAYaoarazeiSwh0ACGUugNMB/59KmlS8MIeEeP0RGH4pFVT35AToCs4KWRwIjc9nvagJ1Raqcsz4JWAr0y+e/0RX4uKBYGjdurOEuPT3d7xCKJdzjV42MPgCLtYi/06FakSwWmAFMVtV3zjmmuvOnEKigvtrDPhhjHJ5dhqhqlojkVCSLBiaoU5GMQHb7kLMrkgFsV9XewE3AlUBlERnqNDlUA08+popICoHLnOXACK/6YIw5I1Qrkr0JvJnHtu4lGaMxxh0JXMZENhE5DKz3O45iSiZQRT5chXv8EBl9aKKq5YpyoKdnFiFkvape5ncQxSEii8O5D+EeP0ROH4p6rI0NMca4YsnCGOPKhZIsxvodQAkI9z6Ee/xwgffhgrjBaYwpvgvlzMIYU0wRlSxczJ8RJyJvO9sXiki90o8yby7iHyoi+4Lm8hjuR5z5EZEJIrJXRHJ9s1YCXnT6uFJEWpd2jPlxEX9XETkY9B08kdt+fhKR2iKSLiLrnDlh7slln8J/D0V9TzzUfgi8JboJaEBgkNkKoNk5+/wXMMb5PAB42++4Cxn/UOBlv2MtoB9XAq2B1Xlsv47A6GIBOgAL/Y65kPF3xcV4JJ/7UB1o7XwuB2zI5e9Sob+HSDqzKHD+DGd5kvP5XeAqZ4xJKHATf8hT1a+An/LZ5UYCY35UVRcAFXLG+4QCF/GHPFXdrapLnc+HCQzUPHd6iEJ/D5GULNzMn3F6H1XNAg4ClUsluoK5nf/jt85p47siUrt0QitRhZ3nJBR1FJEVIvKZiKT6HUx+nEvtVsDCczYV+nuIpGThZv4M13Ns+MBNbB8B9TQwDeG/OXOWFE5C+TtwYylQV1VbAi8BM32OJ08ikgi8B9yrqofO3ZzLIfl+D5GULHYAwf/S1gJ25bWPiJQByhM6p5wFxq+qP6rqcWfxdaBNKcVWktx8TyFLVQ+paqbz+VMgRkSSfQ7rPCISQyBRTFXV93PZpdDfQyQliwLnz3CWhzif+wJfqnO3JwS4mf8j+JqyN4Fr0XDzIXCrcze+A3BQVXf7HZRbIlIt5z6XM+VjFPCjv1GdzYlvPLBOVZ/PY7dCfw8RM5BM3c2fMR6YIiIbCZxRDPAv4rO5jP9uEekNZBGIf6hvAedBRP5J4IlBsojsAJ4EYgBUdQyBKQuuAzYSmMX9Nn8izZ2L+PsCvxeRLOAoMCCE/sHJ0QkYDKwSkZzZ7/8E1IGifw/2BqcxxpVIugwxxnjIkoUxxhVLFsYYVyxZGGNcsWRhjHHFkoUJGc6Izo/9jsPkzpKFMcYVSxam0ERkkIgscuZzeE1EokUkU0SeE5GlIvKFUwgKEUkTkQXO4LcZIlLRWX+xiPzbGZC1VEQaOs0nOoPkvhWRqSE0KviCZ8nCFIqIXAL0BzqpahpwCrgFSACWqmprYA6BNx8BJgMPO4PfVgWtnwqMdgZkXQ7kvGrcCrgXaEZgbo9OnnfKuBIxr3ubUnMVgQFs3zj/6F8E7AWygbedfd4E3heR8kAFVZ3jrJ9EoFRlOaCmqs4AUNVjAE57i1R1h7O8HKgHzPW+W6YglixMYQkwSVVHnrVS5PFz9stvHEF+lxbHgz6fwv6Ohgy7DDGF9QXQV0SqAIhIJRGpS+DvUl9nn4HAXFU9CBwQkSuc9YOBOc7cCjtEpI/TRpyIlC3VXphCs6xtCkVV14rIY8BsEYkCTgJ3Ab8AqSKyhMAMZP2dQ4YAY5xksJkzoxsHA685o2pPAv1KsRumCGzUqSkRIpKpqol+x2G8Y5chxhhX7MzCGOOKnVkYY1yxZGGMccWShTHGFUsWxhhXLFkYY1yxZGGMceX/A9z+PW77bMwbAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data_iter, feature_dim = d2l.get_data_ch10(batch_size=10)\n", + "d2l.train_ch10(torch.optim.Adadelta, {'rho': 0.9}, data_iter, feature_dim);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "* Adadelta has no learning rate hyperparameter, it uses an EWMA on the squares of elements in the variation of the independent variable to replace the learning rate.\n", + "\n", + "## Exercises\n", + "\n", + "* Adjust the value of $\\rho$ and observe the experimental results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}