Skip to content

Commit 62ab6d1

Browse files
committed
Example, clean up session api
Signed-off-by: Ryan Nett <[email protected]>
1 parent cffce11 commit 62ab6d1

File tree

3 files changed

+110
-49
lines changed

3 files changed

+110
-49
lines changed

tensorflow-core-kotlin/tensorflow-core-kotlin-api/pom.xml

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474

7575
<build>
7676
<sourceDirectory>${project.basedir}/src/main/kotlin</sourceDirectory>
77-
<!-- <testSourceDirectory>${project.basedir}/src/test/kotlin</testSourceDirectory>-->
77+
<testSourceDirectory>${project.basedir}/src/test/kotlin</testSourceDirectory>
7878
<plugins>
7979
<plugin>
8080
<groupId>org.codehaus.mojo</groupId>
@@ -217,42 +217,42 @@
217217
<!-- additional 3rd party ruleset(s) can be specified here -->
218218
</dependencies>
219219
</plugin>
220-
<plugin>
221-
<groupId>org.apache.maven.plugins</groupId>
222-
<artifactId>maven-compiler-plugin</artifactId>
223-
<version>3.5.1</version>
224-
<configuration>
225-
<proc>none</proc>
226-
<source>1.6</source>
227-
<target>1.6</target>
228-
</configuration>
229-
<executions>
230-
<!-- Replacing default-compile as it is treated specially by maven -->
231-
<execution>
232-
<id>default-compile</id>
233-
<phase>none</phase>
234-
</execution>
235-
<!-- Replacing default-testCompile as it is treated specially by maven -->
236-
<execution>
237-
<id>default-testCompile</id>
238-
<phase>none</phase>
239-
</execution>
240-
<execution>
241-
<id>java-compile</id>
242-
<phase>compile</phase>
243-
<goals>
244-
<goal>compile</goal>
245-
</goals>
246-
</execution>
247-
<execution>
248-
<id>java-test-compile</id>
249-
<phase>test-compile</phase>
250-
<goals>
251-
<goal>testCompile</goal>
252-
</goals>
253-
</execution>
254-
</executions>
255-
</plugin>
220+
<!-- <plugin>-->
221+
<!-- <groupId>org.apache.maven.plugins</groupId>-->
222+
<!-- <artifactId>maven-compiler-plugin</artifactId>-->
223+
<!-- <version>3.5.1</version>-->
224+
<!-- <configuration>-->
225+
<!-- <proc>none</proc>-->
226+
<!-- <source>1.6</source>-->
227+
<!-- <target>1.6</target>-->
228+
<!-- </configuration>-->
229+
<!-- <executions>-->
230+
<!-- &lt;!&ndash; Replacing default-compile as it is treated specially by maven &ndash;&gt;-->
231+
<!-- <execution>-->
232+
<!-- <id>default-compile</id>-->
233+
<!-- <phase>none</phase>-->
234+
<!-- </execution>-->
235+
<!-- &lt;!&ndash; Replacing default-testCompile as it is treated specially by maven &ndash;&gt;-->
236+
<!-- <execution>-->
237+
<!-- <id>default-testCompile</id>-->
238+
<!-- <phase>none</phase>-->
239+
<!-- </execution>-->
240+
<!-- <execution>-->
241+
<!-- <id>java-compile</id>-->
242+
<!-- <phase>compile</phase>-->
243+
<!-- <goals>-->
244+
<!-- <goal>compile</goal>-->
245+
<!-- </goals>-->
246+
<!-- </execution>-->
247+
<!-- <execution>-->
248+
<!-- <id>java-test-compile</id>-->
249+
<!-- <phase>test-compile</phase>-->
250+
<!-- <goals>-->
251+
<!-- <goal>testCompile</goal>-->
252+
<!-- </goals>-->
253+
<!-- </execution>-->
254+
<!-- </executions>-->
255+
<!-- </plugin>-->
256256
<!-- <plugin>-->
257257
<!-- <groupId>org.bytedeco</groupId>-->
258258
<!-- <artifactId>javacpp</artifactId>-->

tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow/SessionHelpers.kt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public inline fun <R> Session.kotlinRunner(options: RunOptions? = null, block: K
4242
}
4343

4444
public fun Session.kotlinRunner(
45-
feeds: Map<String, Tensor<*>>,
45+
feeds: Map<String, Tensor<*>> = emptyMap(),
4646
fetches: List<String> = emptyList(),
4747
options: RunOptions? = null
4848
): KotlinRunner = kotlinRunner(options).apply {
@@ -52,7 +52,7 @@ public fun Session.kotlinRunner(
5252

5353
@JvmName("kotlinRunnerOutput")
5454
public fun Session.kotlinRunner(
55-
feeds: Map<Output<*>, Tensor<*>>,
55+
feeds: Map<Output<*>, Tensor<*>> = emptyMap(),
5656
fetches: List<Output<*>> = emptyList(),
5757
options: RunOptions? = null
5858
): KotlinRunner = kotlinRunner(options).apply {
@@ -62,7 +62,7 @@ public fun Session.kotlinRunner(
6262

6363
@JvmName("kotlinRunnerOperand")
6464
public fun Session.kotlinRunner(
65-
feeds: Map<Operand<*>, Tensor<*>>,
65+
feeds: Map<Operand<*>, Tensor<*>> = emptyMap(),
6666
fetches: List<Operand<*>> = emptyList(),
6767
options: RunOptions? = null
6868
): KotlinRunner = kotlinRunner(options).apply {
@@ -71,7 +71,7 @@ public fun Session.kotlinRunner(
7171
}
7272

7373
public inline fun <R> Session.kotlinRunner(
74-
feeds: Map<String, Tensor<*>>,
74+
feeds: Map<String, Tensor<*>> = emptyMap(),
7575
fetches: List<String> = emptyList(),
7676
options: RunOptions? = null,
7777
block: KotlinRunner.() -> R
@@ -82,7 +82,7 @@ public inline fun <R> Session.kotlinRunner(
8282

8383
@JvmName("kotlinRunnerOutput")
8484
public inline fun <R> Session.kotlinRunner(
85-
feeds: Map<Output<*>, Tensor<*>>,
85+
feeds: Map<Output<*>, Tensor<*>> = emptyMap(),
8686
fetches: List<Output<*>> = emptyList(),
8787
options: RunOptions? = null,
8888
block: KotlinRunner.() -> R
@@ -93,7 +93,7 @@ public inline fun <R> Session.kotlinRunner(
9393

9494
@JvmName("kotlinRunnerOperand")
9595
public inline fun <R> Session.kotlinRunner(
96-
feeds: Map<Operand<*>, Tensor<*>>,
96+
feeds: Map<Operand<*>, Tensor<*>> = emptyMap(),
9797
fetches: List<Operand<*>> = emptyList(),
9898
options: RunOptions? = null,
9999
block: KotlinRunner.() -> R
@@ -104,22 +104,22 @@ public inline fun <R> Session.kotlinRunner(
104104

105105
// TODO return Map or KotlinRun?
106106
public fun Session.run(
107-
feeds: Map<String, Tensor<*>>,
108-
fetches: List<String>,
107+
feeds: Map<String, Tensor<*>> = emptyMap(),
108+
fetches: List<String> = emptyList(),
109109
options: RunOptions? = null
110110
): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
111111

112112
@JvmName("runOutput")
113113
public fun Session.run(
114-
feeds: Map<Output<*>, Tensor<*>>,
115-
fetches: List<Output<*>>,
114+
feeds: Map<Output<*>, Tensor<*>> = emptyMap(),
115+
fetches: List<Output<*>> = emptyList(),
116116
options: RunOptions? = null
117117
): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
118118

119119
@JvmName("runOperand")
120120
public fun Session.run(
121-
feeds: Map<Operand<*>, Tensor<*>>,
122-
fetches: List<Operand<*>>,
121+
feeds: Map<Operand<*>, Tensor<*>> = emptyMap(),
122+
fetches: List<Operand<*>> = emptyList(),
123123
options: RunOptions? = null
124124
): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
125125

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow
18+
19+
import org.junit.jupiter.api.Test
20+
import org.tensorflow.ndarray.Shape
21+
import org.tensorflow.ndarray.get
22+
import org.tensorflow.op.kotlin.KotlinOps
23+
import org.tensorflow.op.kotlin.tf
24+
import org.tensorflow.op.kotlin.withSubScope
25+
import org.tensorflow.types.TFloat32
26+
27+
public fun KotlinOps.DenseLayer(
28+
name: String,
29+
x: Operand<TFloat32>,
30+
n: Int,
31+
activation: KotlinOps.(Operand<TFloat32>) -> Operand<TFloat32> = { tf.nn.relu(it) }
32+
): Operand<TFloat32> = tf.withSubScope(name) {
33+
val inputDims = x.shape()[1]
34+
val W = tf.variable(tf.math.add(tf.zeros(tf.array(inputDims.toInt(), n), TFloat32.DTYPE), constant(1f)))
35+
val b = tf.variable(tf.math.add(tf.zeros(tf.array(n), TFloat32.DTYPE), constant(1f)))
36+
activation(tf.math.add(tf.linalg.matMul(x, W), b))
37+
}
38+
39+
public class Example {
40+
@Test
41+
private fun mnistExample() {
42+
Graph {
43+
val input = tf.placeholderWithDefault(
44+
tf.math.add(tf.zeros(tf.array(1, 28, 28, 3), TFloat32.DTYPE), tf.constant(1f)),
45+
Shape.of(-1, 28, 28, 3)
46+
)
47+
48+
val output = with(tf) {
49+
var x: Operand<TFloat32> = tf.reshape(input, tf.array(-1))
50+
x = DenseLayer("Layer1", x, 256)
51+
x = DenseLayer("Layer2", x, 64)
52+
DenseLayer("OutputLayer", x, 10) { tf.math.sigmoid(x) }
53+
}
54+
55+
withSession {
56+
val outputValue = it.run(fetches = listOf(output))[output]
57+
println(outputValue.data())
58+
}
59+
}
60+
}
61+
}

0 commit comments

Comments
 (0)