Skip to content

Commit 00f8123

Browse files
committed
WIP Session/Runner API. Java API needs updates
Signed-off-by: Ryan Nett <[email protected]>
1 parent 425ac21 commit 00f8123

File tree

1 file changed

+279
-0
lines changed
  • tensorflow-core-kotlin/tensorflow-core-kotlin-api/src/main/kotlin/org/tensorflow

1 file changed

+279
-0
lines changed
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
package org.tensorflow
16+
17+
import org.tensorflow.ndarray.Shape
18+
import org.tensorflow.op.Op
19+
import org.tensorflow.op.kotlin.tf
20+
import org.tensorflow.proto.framework.RunOptions
21+
import org.tensorflow.types.TInt32
22+
import org.tensorflow.types.family.TType
23+
import kotlin.contracts.InvocationKind
24+
import kotlin.contracts.contract
25+
import kotlin.reflect.KProperty
26+
27+
internal sealed class FetchSpec {
28+
data class OperationFetch(val operation: String, val index: Int?) : FetchSpec()
29+
data class OperandFetch(val operand: Operand<*>) : FetchSpec()
30+
data class OutputFetch(val output: Output<*>) : FetchSpec()
31+
32+
companion object {
33+
operator fun invoke(operation: String) = OperationFetch(operation, null)
34+
operator fun invoke(operation: String, index: Int) = OperationFetch(operation, index)
35+
operator fun invoke(operand: Operand<*>) = OperandFetch(operand)
36+
operator fun invoke(output: Output<*>) = OutputFetch(output)
37+
}
38+
}
39+
40+
public fun Session.kotlinRunner(options: RunOptions? = null): KotlinRunner = KotlinRunner(this, options)
41+
42+
public inline fun <R> Session.kotlinRunner(options: RunOptions? = null, block: KotlinRunner.() -> R): R {
43+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
44+
return kotlinRunner(options).run(block)
45+
}
46+
47+
public fun Session.kotlinRunner(feeds: Map<String, Tensor<*>>, fetches: List<String> = emptyList(), options: RunOptions? = null): KotlinRunner = kotlinRunner(options).apply {
48+
feed(feeds)
49+
fetch(fetches)
50+
}
51+
52+
@JvmName("kotlinRunnerOutput")
53+
public fun Session.kotlinRunner(feeds: Map<Output<*>, Tensor<*>>, fetches: List<Output<*>> = emptyList(), options: RunOptions? = null): KotlinRunner = kotlinRunner(options).apply {
54+
feed(feeds)
55+
fetch(fetches)
56+
}
57+
58+
@JvmName("kotlinRunnerOperand")
59+
public fun Session.kotlinRunner(feeds: Map<Operand<*>, Tensor<*>>, fetches: List<Operand<*>> = emptyList(), options: RunOptions? = null): KotlinRunner = kotlinRunner(options).apply {
60+
feed(feeds)
61+
fetch(fetches)
62+
}
63+
64+
public inline fun <R> Session.kotlinRunner(feeds: Map<String, Tensor<*>>, fetches: List<String> = emptyList(), options: RunOptions? = null, block: KotlinRunner.() -> R): R {
65+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
66+
return kotlinRunner(feeds, fetches, options).run(block)
67+
}
68+
69+
@JvmName("kotlinRunnerOutput")
70+
public inline fun <R> Session.kotlinRunner(feeds: Map<Output<*>, Tensor<*>>, fetches: List<Output<*>> = emptyList(), options: RunOptions? = null, block: KotlinRunner.() -> R): R {
71+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
72+
return kotlinRunner(feeds, fetches, options).run(block)
73+
}
74+
75+
@JvmName("kotlinRunnerOperand")
76+
public inline fun <R> Session.kotlinRunner(feeds: Map<Operand<*>, Tensor<*>>, fetches: List<Operand<*>> = emptyList(), options: RunOptions? = null, block: KotlinRunner.() -> R): R {
77+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
78+
return kotlinRunner(feeds, fetches, options).run(block)
79+
}
80+
81+
//TODO return Map or KotlinRun?
82+
public fun Session.run(feeds: Map<String, Tensor<*>>, fetches: List<String>, options: RunOptions? = null): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
83+
84+
@JvmName("runOutput")
85+
public fun Session.run(feeds: Map<Output<*>, Tensor<*>>, fetches: List<Output<*>>, options: RunOptions? = null): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
86+
87+
@JvmName("runOperand")
88+
public fun Session.run(feeds: Map<Operand<*>, Tensor<*>>, fetches: List<Operand<*>>, options: RunOptions? = null): KotlinRunner.Run = kotlinRunner(feeds, fetches, options).run()
89+
90+
public class KotlinRunner internal constructor(private val session: Session, options: RunOptions?) {
91+
private val runner = session.runner().let {
92+
if(options != null)
93+
it.setOptions(options)
94+
else
95+
it
96+
}
97+
98+
// feeding
99+
100+
public fun feed(operation: String, t: Tensor<*>){
101+
runner.feed(operation, t)
102+
}
103+
104+
public fun feed(operation: String, index: Int, t: Tensor<*>){
105+
runner.feed(operation, index, t)
106+
}
107+
108+
public fun <T: TType> feed(operand: Operand<in T>, t: Tensor<in T>){
109+
runner.feed(operand, t)
110+
}
111+
112+
public fun <T: TType> feed(output: Output<in T>, t: Tensor<in T>){
113+
runner.feed(output, t)
114+
}
115+
116+
public fun feed(vararg operations: Pair<String, Tensor<*>>): Unit = operations.forEach { feed(it.first, it.second) }
117+
118+
@JvmName("feedOperands")
119+
public fun feed(vararg operands: Pair<Operand<*>, Tensor<*>>): Unit = operands.forEach { feed(it.first, it.second) }
120+
121+
@JvmName("feedOutputs")
122+
public fun feed(vararg operands: Pair<Output<*>, Tensor<*>>): Unit = operands.forEach { feed(it.first, it.second) }
123+
124+
public fun feed(operations: Map<String, Tensor<*>>): Unit = operations.forEach { feed(it.key, it.value) }
125+
126+
@JvmName("feedOperands")
127+
public fun feed(operands: Map<Operand<*>, Tensor<*>>): Unit = operands.forEach { feed(it.key, it.value) }
128+
129+
@JvmName("feedOutputs")
130+
public fun feed(operands: Map<Output<*>, Tensor<*>>): Unit = operands.forEach { feed(it.key, it.value) }
131+
132+
@JvmName("operandFeed")
133+
public fun <T: TType> Operand<T>.feed(t: Tensor<T>): Unit = feed(this, t)
134+
135+
@JvmName("outputFeed")
136+
public fun <T: TType> Output<T>.feed(t: Tensor<T>): Unit = feed(this, t)
137+
138+
public operator fun set(operation: String, t: Tensor<*>): Unit = feed(operation, t)
139+
140+
public operator fun set(operation: String, index: Int, t: Tensor<*>): Unit = feed(operation, index, t)
141+
142+
public operator fun <T: TType> set(operand: Operand<T>, t: Tensor<T>): Unit = feed(operand, t)
143+
144+
public operator fun <T: TType> set(output: Output<T>, t: Tensor<T>): Unit = feed(output, t)
145+
146+
// targeting
147+
148+
public fun addTarget(operation: String){
149+
runner.addTarget(operation)
150+
}
151+
152+
public fun addTarget(operation: Operation){
153+
runner.addTarget(operation)
154+
}
155+
156+
public fun addTarget(op: Op){
157+
runner.addTarget(op)
158+
}
159+
160+
// fetching
161+
162+
public inner class FetchKey<T : TType> internal constructor(public val index: Int)
163+
164+
private var currentKey = 0
165+
private val fetchMap = mutableMapOf<FetchSpec, FetchKey<*>>()
166+
167+
private fun <T : TType> newKey(spec: FetchSpec): FetchKey<T> {
168+
if(spec in fetchMap)
169+
return fetchMap[spec] as FetchKey<T>
170+
171+
return FetchKey<T>(currentKey++).also { fetchMap[spec] = it }
172+
}
173+
174+
public fun findKey(operation: String): FetchKey<*> = fetchMap[FetchSpec(operation)] ?: error("Operation $operation was not fetched")
175+
public fun findKey(operation: String, index: Int): FetchKey<*> = fetchMap[FetchSpec(operation, index)] ?: error("Index $index of Operation $operation was not fetched")
176+
public fun <T : TType> findKey(operand: Operand<T>): FetchKey<T> = fetchMap[FetchSpec(operand)] as? FetchKey<T>? ?: error("Operand $operand was not fetched")
177+
public fun <T : TType> findKey(output: Output<T>): FetchKey<T> = fetchMap[FetchSpec(output)] as? FetchKey<T>? ?: error("Output $output was not fetched")
178+
179+
public fun fetch(operation: String): FetchKey<*> =
180+
newKey<TType>(FetchSpec(operation)).also { runner.fetch(operation) }
181+
182+
public fun fetch(operation: String, index: Int): FetchKey<*> =
183+
newKey<TType>(FetchSpec(operation, index)).also { runner.fetch(operation, index) }
184+
185+
public fun <T : TType> fetch(output: Output<T>): FetchKey<*> =
186+
newKey<TType>(FetchSpec(output)).also { runner.fetch(output) }
187+
188+
public fun <T : TType> fetch(operand: Operand<T>): FetchKey<*> =
189+
newKey<TType>(FetchSpec(operand)).also { runner.fetch(operand) }
190+
191+
public fun fetch(vararg operations: String): List<FetchKey<*>> = operations.map { fetch(it) }
192+
193+
public fun fetch(vararg outputs: Output<*>): List<FetchKey<*>> = outputs.map { fetch(it) }
194+
195+
public fun fetch(vararg operands: Operand<*>): List<FetchKey<*>> = operands.map { fetch(it) }
196+
197+
@JvmName("fetchStrings")
198+
public fun fetch(operations: List<String>): List<FetchKey<*>> = operations.map { fetch(it) }
199+
200+
@JvmName("fetchOutputs")
201+
public fun fetch(outputs: List<Output<*>>): List<FetchKey<*>> = outputs.map { fetch(it) }
202+
203+
@JvmName("fetchOperands")
204+
public fun fetch(operands: List<Operand<*>>): List<FetchKey<*>> = operands.map { fetch(it) }
205+
206+
// running
207+
208+
public inner class Run internal constructor(public val output: List<Tensor<*>>): AutoCloseable {
209+
public operator fun <T : TType> get(key: FetchKey<T>): Tensor<T> {
210+
if (key.index < 0 || key.index > output.lastIndex)
211+
error("Invalid key: key's index is ${key.index}, but there are only ${output.size} outputs.")
212+
return output[key.index] as Tensor<T>
213+
}
214+
215+
public operator fun get(operation: String): Tensor<*> = this[findKey(operation)]
216+
public operator fun get(operation: String, index: Int): Tensor<*> = this[findKey(operation, index)]
217+
public operator fun <T: TType> get(output: Output<T>): Tensor<T> = this[findKey(output)]
218+
public operator fun <T: TType> get(operand: Operand<T>): Tensor<T> = this[findKey(operand)]
219+
220+
@JvmName("keyGet")
221+
public fun <T: TType> FetchKey<T>.get(): Tensor<T> = this@Run[this]
222+
223+
@JvmName("operandGet")
224+
public fun <T: TType> Operand<T>.get(): Tensor<T> = this@Run[this]
225+
226+
@JvmName("outputGet")
227+
public fun <T: TType> Output<T>.get(): Tensor<T> = this@Run[this]
228+
229+
public operator fun <T: TType> FetchKey<T>.getValue(thisRef: Any?, property: KProperty<*>): Tensor<T> = this.get()
230+
231+
override fun close() {
232+
output.forEach { it.close() }
233+
}
234+
}
235+
236+
private var latestRun: Run? = null
237+
238+
public fun run(): Run = Run(runner.run()).also {
239+
latestRun = it
240+
}
241+
242+
public fun <R> run(freeTensors: Boolean = true, block: Run.() -> R): R {
243+
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
244+
return if(freeTensors) run().use(block) else run().run(block)
245+
}
246+
247+
//TODO Unsure if the nicer API is worth the weird run call requirements
248+
public operator fun <T: TType> FetchKey<T>.getValue(thisRef: Any?, property: KProperty<*>): Tensor<T> = latestRun?.get(this) ?: error("Runner has not yet been ran, can not get fetched value.")
249+
}
250+
251+
252+
public fun test() {
253+
Graph {
254+
with(tf) {
255+
val a = placeholder(TInt32.DTYPE, Shape.of(1))
256+
val b = constant(2)
257+
val c = math.add(a, b)
258+
259+
withSession {
260+
val aIn = Tensor.of(TInt32.DTYPE, Shape.of(1))
261+
262+
it.kotlinRunner{
263+
this[a] = aIn
264+
265+
val cOut by fetch(c)
266+
267+
run {
268+
val cOut2 = this[c]
269+
cOut
270+
}
271+
}
272+
273+
val cOut = it.run(mapOf(a to aIn), listOf(c))[c]
274+
275+
}
276+
}
277+
278+
}
279+
}

0 commit comments

Comments
 (0)