Commit 4741e39

Inline createSlots into applyGradients

Create optimizer variables directly at the place they are used.

1 parent 0191501

9 files changed: 45 additions, 111 deletions
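The same pattern repeats across all nine files below: the separate createSlots() pass and the name-based getSlot() lookup are replaced by creating each slot variable at its point of use inside applyGradients(). A minimal, hypothetical Kotlin sketch of the before/after shape (plain strings stand in for KotlinDL's KGraph/Ops/Variable machinery; MiniOptimizer is illustrative, not part of the library):

class Slot(val name: String)

class MiniOptimizer {
    private val slots = mutableMapOf<String, Slot>()

    // Before this commit: a separate pass created every slot up front...
    fun createSlots(weights: List<String>) {
        weights.forEach { slots["$it/accumulator"] = Slot("$it/accumulator") }
    }

    // ...and applyGradients looked each slot up again by name.
    fun applyGradientsOld(weights: List<String>): List<Slot> =
        weights.map { slots.getValue("$it/accumulator") }

    // After this commit: one pass, each slot created exactly where it is consumed.
    fun applyGradientsNew(weights: List<String>): List<Slot> =
        weights.map { Slot("$it/accumulator") }
}

fun main() {
    val weights = listOf("dense_1/kernel", "dense_1/bias")
    val optimizer = MiniOptimizer()
    optimizer.createSlots(weights)                                // old: two-pass protocol
    println(optimizer.applyGradientsOld(weights).map { it.name })
    println(optimizer.applyGradientsNew(weights).map { it.name }) // new: one pass
}

The new shape removes the implicit ordering contract between the two passes and the possibility of a failed name lookup.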

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaDelta.kt
1 addition, 8 deletions

@@ -74,10 +74,7 @@ public class AdaDelta(
         epsilonConstant = tf.constant(epsilon, getDType())

         for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
-
-            val accumSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
-            val accumUpdateSlot: Variable<Float> = getSlot(varName, ACCUMULATOR_UPDATE)
+            val (accumSlot, accumUpdateSlot) = createAdaDeltaSlot(graph, tf, variable.asOutput())

             targets.add(
                 tf.train.applyAdadelta(
@@ -107,10 +104,6 @@ public class AdaDelta(
         return accumulator to accumulatorUpdate
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        return variables.flatMap { createAdaDeltaSlot(graph, tf, it.asOutput()).toList() }
-    }
-
     override val optimizerName: String get() = "Adadelta"

     override val isRunningOnGPU: Boolean get() = true
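The one-line replacement works because createAdaDeltaSlot returns a Pair, as the retained `return accumulator to accumulatorUpdate` shows, and Kotlin destructures a Pair positionally. A self-contained sketch of the idiom (names are illustrative):

fun createSlotPair(varName: String): Pair<String, String> =
    "$varName/accum" to "$varName/accum_update" // `to` builds the Pair

fun main() {
    // component1/component2 are pulled out positionally:
    val (accumSlot, accumUpdateSlot) = createSlotPair("conv2d_1/kernel")
    println(accumSlot)       // conv2d_1/kernel/accum
    println(accumUpdateSlot) // conv2d_1/kernel/accum_update
}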

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGrad.kt
1 addition, 7 deletions

@@ -64,9 +64,7 @@ public class AdaGrad(
         learningRateConst = tf.constant(learningRate, getDType())

         for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
-
-            val slot: Variable<Float> = getSlot(varName, ACCUMULATOR)
+            val slot = createAdaGradSlot(graph, tf, variable.asOutput())

             targets.add(
                 tf.train.applyAdagrad(
@@ -90,10 +88,6 @@ public class AdaGrad(
         return createSlot(graph, tf, v.asOutput(), ACCUMULATOR, initializer)
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        return variables.map { createAdaGradSlot(graph, tf, it.asOutput()) }
-    }
-
     override val optimizerName: String get() = "Adagrad"

     override val isRunningOnGPU: Boolean get() = true

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/AdaGradDA.kt
7 additions, 14 deletions

@@ -76,12 +76,14 @@ public class AdaGradDA(
         l1StrengthConst = tf.constant(l1Strength, getDType())
         l2StrengthConst = tf.constant(l2Strength, getDType())

-        for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
-
-            val gradSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
-            val gradSquaredSlot: Variable<Float> = getSlot(varName, SQUARED_ACCUMULATOR)
+        globalStep = tf.withName(GLOBAL_STEP).variable(Shape.scalar(), getDType())
+        val globalStepAssignName = defaultAssignOpName(GLOBAL_STEP)
+        val globalStepInit: Assign<*> = tf.withName(globalStepAssignName)
+            .assign(globalStep, tf.withName(defaultInitializerOpName(GLOBAL_STEP)).constant(0.0f))
+        graph.addOptimizerVariableInitializer(globalStepInit)

+        for ((i, variable) in weights.withIndex()) {
+            val (gradSlot, gradSquaredSlot) = createAdaGradDASlot(graph, tf, variable.asOutput())
             targets.add(
                 tf.train.applyAdagradDa(
                     variable,
@@ -117,15 +119,6 @@ public class AdaGradDA(
         return accumulator to squaredAccumulator
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        globalStep = tf.withName(GLOBAL_STEP).variable(Shape.scalar(), getDType())
-        val globalStepAssignName = defaultAssignOpName(GLOBAL_STEP)
-        val globalStepInit: Assign<*> = tf.withName(globalStepAssignName)
-            .assign(globalStep, tf.withName(defaultInitializerOpName(GLOBAL_STEP)).constant(0.0f))
-        graph.addOptimizerVariableInitializer(globalStepInit)
-        return variables.flatMap { createAdaGradDASlot(graph, tf, it.asOutput()).toList() }
-    }
-
     override val optimizerName: String get() = "AdaGradDA"

     override val isRunningOnGPU: Boolean get() = true
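AdaGradDA's createSlots also carried one-time state (globalStep), so inlining hoists that setup above the per-weight loop rather than into it; Adam and Adamax get the same treatment for their beta-power variables below. A minimal, hypothetical sketch of the hoisting (strings stand in for graph variables):

class MiniAdaGradDA {
    lateinit var globalStep: String

    fun applyGradients(weights: List<String>): List<String> {
        // One-time setup per call, previously hidden inside createSlots():
        globalStep = "global_step"

        // Per-weight slots, created in place inside the loop:
        return weights.map { w ->
            val gradSlot = "$w/accumulator"
            val gradSquaredSlot = "$w/squared_accumulator"
            "applyAdagradDa($w, $gradSlot, $gradSquaredSlot, $globalStep)"
        }
    }
}

fun main() {
    println(MiniAdaGradDA().applyGradients(listOf("dense/kernel")))
}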

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adam.kt
18 additions, 29 deletions

@@ -80,12 +80,26 @@ public class Adam(
         learningRateConst = tf.constant(learningRate, getDType())
         epsilonConstant = tf.constant(epsilon, getDType())

-        for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
+        betaOnePower = tf.withName(FIRST_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
+        val betaOnePowerAssignName = defaultAssignOpName(FIRST_BETA_POWER_NAME)
+        val betaOnePowerInit: Assign<*> = tf.withName(betaOnePowerAssignName)
+            .assign(
+                betaOnePower,
+                tf.withName(defaultInitializerOpName(FIRST_BETA_POWER_NAME)).constant(beta1, getDType())
+            )
+        graph.addOptimizerVariableInitializer(betaOnePowerInit)

-            val firstMomentSlot: Variable<Float> = getSlot(varName, FIRST_MOMENT)
-            val secondMomentSlot: Variable<Float> = getSlot(varName, SECOND_MOMENT)
+        betaTwoPower = tf.withName(SECOND_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
+        val betaTwoPowerAssignName = defaultAssignOpName(SECOND_BETA_POWER_NAME)
+        val betaTwoPowerInit: Assign<*> = tf.withName(betaTwoPowerAssignName)
+            .assign(
+                betaTwoPower,
+                tf.withName(defaultInitializerOpName(SECOND_BETA_POWER_NAME)).constant(beta2, getDType())
+            )
+        graph.addOptimizerVariableInitializer(betaTwoPowerInit)

+        for ((i, variable) in weights.withIndex()) {
+            val (firstMomentSlot, secondMomentSlot) = createAdamSlot(graph, tf, variable.asOutput())
             targets.add(
                 tf.train.applyAdam(
                     variable,
@@ -132,31 +146,6 @@ public class Adam(
         return firstMoment to secondMoment
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        betaOnePower = tf.withName(FIRST_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
-
-        val betaOnePowerAssignName = defaultAssignOpName(FIRST_BETA_POWER_NAME)
-        val betaOnePowerInit: Assign<*> = tf.withName(betaOnePowerAssignName)
-            .assign(
-                betaOnePower,
-                tf.withName(defaultInitializerOpName(FIRST_BETA_POWER_NAME)).constant(beta1, getDType())
-            )
-        graph.addOptimizerVariableInitializer(betaOnePowerInit)
-
-
-        betaTwoPower = tf.withName(SECOND_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
-
-        val betaTwoPowerAssignName = defaultAssignOpName(SECOND_BETA_POWER_NAME)
-        val betaTwoPowerInit: Assign<*> = tf.withName(betaTwoPowerAssignName)
-            .assign(
-                betaTwoPower,
-                tf.withName(defaultInitializerOpName(SECOND_BETA_POWER_NAME)).constant(beta2, getDType())
-            )
-        graph.addOptimizerVariableInitializer(betaTwoPowerInit)
-
-        return variables.flatMap { createAdamSlot(graph, tf, it.asOutput()).toList() }
-    }
-
     override val optimizerName: String get() = "Adam"

     override val isRunningOnGPU: Boolean get() = true

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Adamax.kt
12 additions, 22 deletions

@@ -79,14 +79,19 @@ public class Adamax(
         learningRateConst = tf.constant(learningRate, getDType())
         epsilonConstant = tf.constant(epsilon, getDType())

+        betaOnePower = tf.withName(FIRST_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
+        val betaOnePowerAssignName = defaultAssignOpName(FIRST_BETA_POWER_NAME)
+        val betaOnePowerInit: Assign<*> = tf.withName(betaOnePowerAssignName)
+            .assign(
+                betaOnePower,
+                tf.withName(defaultInitializerOpName(FIRST_BETA_POWER_NAME)).constant(beta1, getDType())
+            )
+        graph.addOptimizerVariableInitializer(betaOnePowerInit)
+
         val scope = Scope(graph.tfGraph)

         for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
-
-            val firstMomentSlot: Variable<Float> = getSlot(varName, FIRST_MOMENT)
-            val secondMomentSlot: Variable<Float> = getSlot(varName, SECOND_MOMENT)
-
+            val (firstMomentSlot, secondMomentSlot) = createAdamaxSlot(graph, tf, variable.asOutput())
             targets.add(
                 ApplyAdaMax.create(
                     scope,
@@ -104,10 +109,9 @@ public class Adamax(
             )
         }

-        val betaOnePowerInit = tf
-            .assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))
+        val betaOnePowerInit2 = tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))

-        graph.addOptimizerVariableInitializer(betaOnePowerInit)
+        graph.addOptimizerVariableInitializer(betaOnePowerInit2)
         graph.addOptimizerVariable(betaOnePower)

         return targets
@@ -127,20 +131,6 @@ public class Adamax(
         return firstMoment to secondMoment
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        betaOnePower = tf.withName(FIRST_BETA_POWER_NAME).variable(Shape.scalar(), getDType())
-        val betaOnePowerAssignName = defaultAssignOpName(FIRST_BETA_POWER_NAME)
-
-        val betaOnePowerInit: Assign<*> = tf.withName(betaOnePowerAssignName)
-            .assign(
-                betaOnePower,
-                tf.withName(defaultInitializerOpName(FIRST_BETA_POWER_NAME)).constant(beta1, getDType())
-            )
-        graph.addOptimizerVariableInitializer(betaOnePowerInit)
-
-        return variables.flatMap { createAdamaxSlot(graph, tf, it.asOutput()).toList() }
-    }
-
     override val optimizerName: String get() = "Adamax"

     override val isRunningOnGPU: Boolean get() = false
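The betaOnePowerInit2 rename in the middle hunk appears to be forced by the inlining: applyGradients now already declares a local `val betaOnePowerInit` for the initial assignment, and Kotlin rejects a second `val` of the same name in the same scope. A tiny demonstration of the constraint:

fun demo() {
    val init = "initial assign"
    // val init = "per-step update"  // does not compile: conflicting declarations
    val init2 = "per-step update"    // a fresh name is required
    println("$init / $init2")
}

fun main() = demo()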

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Ftrl.kt
1 addition, 7 deletions

@@ -91,10 +91,8 @@ public class Ftrl(
         learningRatePowerConst = tf.constant(learningRatePower, getDType())

         for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
+            val (accumSlot, linearSlot) = createFtrlSlot(graph, tf, variable.asOutput())

-            val accumSlot: Variable<Float> = getSlot(varName, ACCUMULATOR)
-            val linearSlot: Variable<Float> = getSlot(varName, LINEAR_ACCUMULATOR)
             val options = ApplyFtrl.useLocking(true)

             targets.add(
@@ -130,10 +128,6 @@ public class Ftrl(
         return accumulator to linearAccumulator
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        return variables.flatMap { createFtrlSlot(graph, tf, it.asOutput()).toList() }
-    }
-
     override val optimizerName: String get() = "Ftrl"

     override val isRunningOnGPU: Boolean get() = false

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Momentum.kt
1 addition, 5 deletions

@@ -50,7 +50,7 @@ public class Momentum(
         momentumConst = tf.constant(momentum)

         for ((i, variable) in weights.withIndex()) {
-            val slot = getSlot(variable.ref().op().name(), MOMENTUM)
+            val slot = createMomentumSlot(graph, tf, variable.asOutput())

             targets.add(
                 tf.train.applyMomentum(
@@ -74,10 +74,6 @@ public class Momentum(
         return createSlot(graph, tf, v.asOutput(), MOMENTUM, initializer)
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        return variables.map { createMomentumSlot(graph, tf, it.asOutput()) }
-    }
-
     override val optimizerName: String get() = "Momentum"

     override val isRunningOnGPU: Boolean get() = true

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/Optimizer.kt
0 additions, 10 deletions

@@ -46,9 +46,6 @@ public abstract class Optimizer(public val clipGradient: ClipGradientAction) {
         slots = mutableMapOf()

         val gradients: Gradients = computeGradients(tf, loss, weights)
-
-        createSlots(graph, tf, weights.map { it.asOutput() }) // empty action if not overridden
-
         return applyGradients(graph, tf, weights, gradients)
     }

@@ -77,13 +74,6 @@ public abstract class Optimizer(public val clipGradient: ClipGradientAction) {
         return tf.gradients(loss, weights)
    }

-    /**
-     * No-op slot creation method.
-     *
-     * @param variables The variables to create slots for.
-     */
-    protected open fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> = emptyList()
-
     /** Returns optimizer name. */
     public abstract val optimizerName: String
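With the hook gone, the base class's preparation path reduces to compute-gradients-then-apply, and each subclass owns its slot state entirely inside applyGradients(). A simplified, hypothetical model of the resulting flow (BaseOptimizer/MiniSGD are illustrative, not KotlinDL classes):

abstract class BaseOptimizer {
    // No createSlots() hook anymore: preparation is compute-then-apply.
    fun prepareTargets(gradients: List<Float>): List<Float> =
        applyGradients(gradients) // subclasses create any slot state in here

    protected abstract fun applyGradients(gradients: List<Float>): List<Float>
}

class MiniSGD(private val learningRate: Float) : BaseOptimizer() {
    override fun applyGradients(gradients: List<Float>): List<Float> =
        gradients.map { -learningRate * it } // stateless: needs no slots at all
}

fun main() {
    println(MiniSGD(0.1f).prepareTargets(listOf(1.0f, -2.0f))) // [-0.1, 0.2]
}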

tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/optimizer/RMSProp.kt
4 additions, 9 deletions

@@ -65,13 +65,12 @@ public class RMSProp(
         epsilonConstant = tf.constant(epsilon, getDType())

         for ((i, variable) in weights.withIndex()) {
-            val varName = variable.ref().op().name()
-
-            val rmsSlot: Variable<Float> = getSlot(varName, RMS)
-            val momentumSlot: Variable<Float> = getSlot(varName, MOMENTUM)
+            val slots = createRMSPropSlot(graph, tf, variable.asOutput())
+            val rmsSlot: Variable<Float> = slots[0]
+            val momentumSlot: Variable<Float> = slots[1]

             if (centered) {
-                val mgSlot: Variable<Float> = getSlot(varName, MG)
+                val mgSlot: Variable<Float> = slots[2]
                 targets.add(
                     tf.train.applyCenteredRmsProp(
                         variable,
@@ -130,10 +129,6 @@ public class RMSProp(
         return listOf(rms, momentum)
     }

-    override fun createSlots(graph: KGraph, tf: Ops, variables: List<Output<Float>>): List<Variable<Float>> {
-        return variables.flatMap { createRMSPropSlot(graph, tf, it.asOutput()) }
-    }
-
     override val optimizerName: String get() = "RMSProp"

     override val isRunningOnGPU: Boolean get() = true
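Unlike the Pair-returning helpers above, createRMSPropSlot returns a List, so the optional centered-mode slot fits in: index 0 is rms, index 1 is momentum, and, evidently given the `slots[2]` access, a third mg entry is appended when centered (the retained hunk only shows the two-slot return). A hypothetical sketch of that convention:

fun createRmsPropSlots(varName: String, centered: Boolean): List<String> {
    val slots = mutableListOf("$varName/rms", "$varName/momentum")
    if (centered) slots.add("$varName/mg") // third slot only in centered mode
    return slots
}

fun main() {
    val slots = createRmsPropSlots("dense/kernel", centered = true)
    val rmsSlot = slots[0]
    val momentumSlot = slots[1]
    val mgSlot = slots[2] // only safe to read when centered
    println("$rmsSlot, $momentumSlot, $mgSlot")
}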
