From ab27cb64a091461e19132db2086b136937e23a5f Mon Sep 17 00:00:00 2001 From: David <2524133+dcato98@users.noreply.github.com> Date: Fri, 7 Aug 2020 18:26:33 +0100 Subject: [PATCH] allow on-the-fly regularization changes --- src/nn.ts | 17 +++++++---------- src/playground.ts | 10 ++++++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/nn.ts b/src/nn.ts index e92a13de..5c4d0ebe 100644 --- a/src/nn.ts +++ b/src/nn.ts @@ -136,7 +136,7 @@ export class Activations { }; } -/** Build-in regularization functions */ +/** Built-in regularization functions */ export class RegularizationFunction { public static L1: RegularizationFunction = { output: w => Math.abs(w), @@ -176,12 +176,10 @@ export class Link { * @param regularization The regularization function that computes the * penalty for this weight. If null, there will be no regularization. */ - constructor(source: Node, dest: Node, - regularization: RegularizationFunction, initZero?: boolean) { + constructor(source: Node, dest: Node, initZero?: boolean) { this.id = source.id + "-" + dest.id; this.source = source; this.dest = dest; - this.regularization = regularization; if (initZero) { this.weight = 0; } @@ -204,7 +202,6 @@ export class Link { export function buildNetwork( networkShape: number[], activation: ActivationFunction, outputActivation: ActivationFunction, - regularization: RegularizationFunction, inputIds: string[], initZero?: boolean): Node[][] { let numLayers = networkShape.length; let id = 1; @@ -230,7 +227,7 @@ export function buildNetwork( // Add links from nodes in the previous layer to this node. for (let j = 0; j < network[layerIdx - 1].length; j++) { let prevNode = network[layerIdx - 1][j]; - let link = new Link(prevNode, node, regularization, initZero); + let link = new Link(prevNode, node, initZero); prevNode.outputs.push(link); node.inputLinks.push(link); } @@ -333,7 +330,7 @@ export function backProp(network: Node[][], target: number, * derivatives. */ export function updateWeights(network: Node[][], learningRate: number, - regularizationRate: number) { + regularization: RegularizationFunction, regularizationRate: number) { for (let layerIdx = 1; layerIdx < network.length; layerIdx++) { let currentLayer = network[layerIdx]; for (let i = 0; i < currentLayer.length; i++) { @@ -350,8 +347,8 @@ export function updateWeights(network: Node[][], learningRate: number, if (link.isDead) { continue; } - let regulDer = link.regularization ? - link.regularization.der(link.weight) : 0; + let regulDer = regularization ? + regularization.der(link.weight) : 0; if (link.numAccumulatedDers > 0) { // Update the weight based on dE/dw. link.weight = link.weight - @@ -359,7 +356,7 @@ export function updateWeights(network: Node[][], learningRate: number, // Further update the weight based on regularization. let newLinkWeight = link.weight - (learningRate * regularizationRate) * regulDer; - if (link.regularization === RegularizationFunction.L1 && + if (regularization === RegularizationFunction.L1 && link.weight * newLinkWeight < 0) { // The weight crossed 0 due to the regularization term. Set it to 0. link.weight = 0; diff --git a/src/playground.ts b/src/playground.ts index aeac0f9c..1dc62481 100644 --- a/src/playground.ts +++ b/src/playground.ts @@ -341,7 +341,8 @@ function makeGUI() { function() { state.regularization = regularizations[this.value]; parametersChanged = true; - reset(); + state.serialize(); + userHasInteracted(); }); regularDropdown.property("value", getKeyFromValue(regularizations, state.regularization)); @@ -349,7 +350,8 @@ function makeGUI() { let regularRate = d3.select("#regularRate").on("change", function() { state.regularizationRate = +this.value; parametersChanged = true; - reset(); + state.serialize(); + userHasInteracted(); }); regularRate.property("value", state.regularizationRate); @@ -913,7 +915,7 @@ function oneStep(): void { nn.forwardProp(network, input); nn.backProp(network, point.label, nn.Errors.SQUARE); if ((i + 1) % state.batchSize === 0) { - nn.updateWeights(network, state.learningRate, state.regularizationRate); + nn.updateWeights(network, state.learningRate, state.regularization, state.regularizationRate); } }); // Compute the loss. @@ -956,7 +958,7 @@ function reset(onStartup=false) { let outputActivation = (state.problem === Problem.REGRESSION) ? nn.Activations.LINEAR : nn.Activations.TANH; network = nn.buildNetwork(shape, state.activation, outputActivation, - state.regularization, constructInputIds(), state.initZero); + constructInputIds(), state.initZero); lossTrain = getLoss(network, trainData); lossTest = getLoss(network, testData); drawNetwork(network);