Skip to content

Commit 0834037

Browse files
committed
Generate name_map
Signed-off-by: Ryan Nett <[email protected]>
1 parent 0fe34e6 commit 0834037

File tree

7 files changed

+147
-36
lines changed

7 files changed

+147
-36
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE
2+
3+
package org.tensorflow.internal.c_api;
4+
5+
import org.bytedeco.javacpp.BytePointer;
6+
import org.bytedeco.javacpp.Loader;
7+
import org.bytedeco.javacpp.Pointer;
8+
import org.bytedeco.javacpp.annotation.ByRef;
9+
import org.bytedeco.javacpp.annotation.ByVal;
10+
import org.bytedeco.javacpp.annotation.Const;
11+
import org.bytedeco.javacpp.annotation.Index;
12+
import org.bytedeco.javacpp.annotation.MemberGetter;
13+
import org.bytedeco.javacpp.annotation.Name;
14+
import org.bytedeco.javacpp.annotation.NoOffset;
15+
import org.bytedeco.javacpp.annotation.Properties;
16+
import org.bytedeco.javacpp.annotation.StdString;
17+
18+
@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>")
19+
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
20+
public class NameMap extends Pointer {
21+
22+
static {
23+
Loader.load();
24+
}
25+
26+
/**
27+
* Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
28+
*/
29+
public NameMap(Pointer p) {
30+
super(p);
31+
}
32+
33+
public NameMap() {
34+
allocate();
35+
}
36+
37+
private native void allocate();
38+
39+
public native @Name("operator =")
40+
@ByRef
41+
NameMap put(@ByRef NameMap x);
42+
43+
public boolean empty() {
44+
return size() == 0;
45+
}
46+
47+
public native long size();
48+
49+
@Index
50+
public native Node get(@StdString BytePointer i);
51+
52+
public native NameMap put(@StdString BytePointer i, Node value);
53+
54+
public native @ByVal
55+
Iterator begin();
56+
57+
public native @ByVal
58+
Iterator end();
59+
60+
@NoOffset
61+
@Name("iterator")
62+
public static class Iterator extends Pointer {
63+
64+
public Iterator(Pointer p) {
65+
super(p);
66+
}
67+
68+
public Iterator() {
69+
}
70+
71+
public native @Name("operator ++")
72+
@ByRef
73+
Iterator increment();
74+
75+
public native @Name("operator ==")
76+
boolean equals(@ByRef Iterator it);
77+
78+
public native @Name("operator *().first")
79+
@MemberGetter
80+
@StdString
81+
BytePointer first();
82+
83+
public native @Name("operator *().second")
84+
@MemberGetter
85+
@Const
86+
Node second();
87+
}
88+
}
89+

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_Graph.java

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,43 @@
22

33
package org.tensorflow.internal.c_api;
44

5-
import java.nio.*;
6-
import org.bytedeco.javacpp.*;
7-
import org.bytedeco.javacpp.annotation.*;
8-
9-
import static org.tensorflow.internal.c_api.global.tensorflow.*;
5+
import org.bytedeco.javacpp.Loader;
6+
import org.bytedeco.javacpp.Pointer;
7+
import org.bytedeco.javacpp.annotation.ByRef;
8+
import org.bytedeco.javacpp.annotation.MemberGetter;
9+
import org.bytedeco.javacpp.annotation.NoOffset;
10+
import org.bytedeco.javacpp.annotation.Properties;
1011

1112
// Parsed from tensorflow/c/c_api_internal.h
1213

13-
@NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
14+
@NoOffset
15+
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
1416
public class TF_Graph extends org.tensorflow.internal.c_api.AbstractTF_Graph {
15-
static { Loader.load(); }
16-
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
17-
public TF_Graph(Pointer p) { super(p); }
17+
18+
static {
19+
Loader.load();
20+
}
21+
22+
/**
23+
* Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}.
24+
*/
25+
public TF_Graph(Pointer p) {
26+
super(p);
27+
}
1828

1929

20-
21-
public native @MemberGetter @ByRef NativeGraphPointer graph();
30+
public native @MemberGetter
31+
@ByRef
32+
NativeGraphPointer graph();
2233

2334
// Runs shape inference.
24-
35+
2536

2637
// Maps from name of an operation to the Node* in 'graph'.
27-
38+
public native @ByRef
39+
NameMap name_map();
40+
41+
public native TF_Graph name_map(NameMap setter);
2842

2943
// The keys of this map are all the active sessions using this graph. Each
3044
// value records whether the graph has been mutated since the corresponding
@@ -39,11 +53,16 @@ public class TF_Graph extends org.tensorflow.internal.c_api.AbstractTF_Graph {
3953
//
4054
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
4155
// status, this should be reverted when possible.
42-
43-
// set true by TF_DeleteGraph
56+
57+
// set true by TF_DeleteGraph
4458

4559
// Used to link graphs contained in TF_WhileParams to the parent graph that
4660
// will eventually contain the full while loop.
47-
public native TF_Graph parent(); public native TF_Graph parent(TF_Graph setter);
48-
public native TF_Output parent_inputs(); public native TF_Graph parent_inputs(TF_Output setter);
61+
public native TF_Graph parent();
62+
63+
public native TF_Graph parent(TF_Graph setter);
64+
65+
public native TF_Output parent_inputs();
66+
67+
public native TF_Graph parent_inputs(TF_Output setter);
4968
}

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ public class tensorflow extends org.tensorflow.internal.c_api.presets.tensorflow
7878

7979
// Targeting ../NativeOutputVector.java
8080

81+
// Targeting ../NameMap.java
82+
8183
// Parsed from tensorflow/core/platform/ctstring_internal.h
8284

8385
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,6 @@ public static GraphOperation getGraphOp(Graph g, Node node) {
7878
}
7979

8080
public static void useDangerousLockedBuilders(Graph g, boolean dangerous) {
81-
g.setDangerousOpBuilder(dangerous);
81+
g.setDangerousGradientBuilder(dangerous);
8282
}
8383
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ public GraphOperationBuilder opBuilder(String type, String name) {
380380
if (!isOpEnabled(type)) {
381381
throw new IllegalArgumentException("Op " + type + " is not valid in graph mode.");
382382
}
383-
return new GraphOperationBuilder(this, type, name, dangerousOpBuilder);
383+
return new GraphOperationBuilder(this, type, name, dangerousGradientBuilder);
384384
}
385385

386386
@Override
@@ -711,12 +711,12 @@ synchronized SaverDef saverDef() {
711711
private int refcount = 0;
712712
private SaverDef saverDef;
713713

714-
private boolean dangerousOpBuilder;
714+
private boolean dangerousGradientBuilder;
715715

716716
private final List<Op> initializers = new ArrayList<>();
717717

718-
synchronized void setDangerousOpBuilder(boolean dangerous) {
719-
dangerousOpBuilder = dangerous;
718+
synchronized void setDangerousGradientBuilder(boolean dangerous) {
719+
dangerousGradientBuilder = dangerous;
720720
}
721721

722722
// Related native objects (such as the TF_Operation object backing an Operation instance)

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@
6161
*/
6262
public final class GraphOperationBuilder implements OperationBuilder {
6363

64-
GraphOperationBuilder(Graph graph, String type, String name, boolean dangerousBuilder) {
64+
GraphOperationBuilder(Graph graph, String type, String name, boolean dangerousGradientBuilder) {
6565
this.graph = graph;
66-
this.dangerousBuilder = dangerousBuilder;
66+
this.dangerousGradientBuilder = dangerousGradientBuilder;
6767
Graph.Reference r = graph.ref();
6868
try {
69-
if (dangerousBuilder) {
70-
this.unsafeNativeHandle = allocateDangerous(r.nativeHandle(), type, name);
69+
if (dangerousGradientBuilder) {
70+
this.unsafeNativeHandle = allocateDangerousGradient(r.nativeHandle(), type, name);
7171
} else {
7272
this.unsafeNativeHandle = allocate(r.nativeHandle(), type, name);
7373
}
@@ -83,19 +83,16 @@ public final class GraphOperationBuilder implements OperationBuilder {
8383
*/
8484
@Override
8585
public GraphOperation build() {
86-
Graph.Reference r = graph.ref();
87-
try {
86+
try (Graph.Reference r = graph.ref()) {
8887
TF_Operation built;
89-
if (dangerousBuilder) {
90-
built = finishDangerous(unsafeNativeHandle);
88+
if (dangerousGradientBuilder) {
89+
built = finishDangerousGradient(r.nativeHandle(), unsafeNativeHandle);
9190
} else {
9291
built = finish(unsafeNativeHandle);
9392
}
9493
GraphOperation op = new GraphOperation(graph, built);
9594
unsafeNativeHandle = null;
9695
return op;
97-
} finally {
98-
r.close();
9996
}
10097
}
10198

@@ -360,7 +357,7 @@ public GraphOperationBuilder setAttr(String name, String[] value) {
360357

361358
private TF_OperationDescription unsafeNativeHandle;
362359
private final Graph graph;
363-
private final boolean dangerousBuilder;
360+
private final boolean dangerousGradientBuilder;
364361

365362
private static void requireHandle(Pointer handle) {
366363
if (handle == null || handle.isNull()) {
@@ -389,7 +386,8 @@ private static TF_OperationDescription allocate(TF_Graph graphHandle, String typ
389386
return TF_NewOperation(graphHandle, type, name);
390387
}
391388

392-
private static TF_OperationDescription allocateDangerous(TF_Graph graphHandle, String type,
389+
private static TF_OperationDescription allocateDangerousGradient(TF_Graph graphHandle,
390+
String type,
393391
String name) {
394392
if (graphHandle == null || graphHandle.isNull()) {
395393
throw new IllegalStateException("close() has been called on the Graph");
@@ -408,13 +406,14 @@ private static TF_Operation finish(TF_OperationDescription handle) {
408406
}
409407
}
410408

411-
private static TF_Operation finishDangerous(TF_OperationDescription handle) {
409+
private static TF_Operation finishDangerousGradient(TF_Graph g, TF_OperationDescription handle) {
412410
requireHandle(handle);
413411

414412
try (PointerScope scope = new PointerScope()) {
415413
TF_Status status = TF_Status.newStatus();
416414
TF_Operation op = TF_FinishOperationLocked(handle, status);
417415
status.throwExceptionIfNotOK();
416+
// g.name_map().put(TF_OperationName(op), null);
418417
return op;
419418
}
420419
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,10 @@ public void map(InfoMap infoMap) {
282282
.put(new Info("tensorflow::Graph").javaNames("NativeGraphPointer"))
283283
.put(new Info("TF_Graph::graph")
284284
.javaText("public native @MemberGetter @ByRef NativeGraphPointer graph();"))
285-
.put(new Info("TF_Graph::refiner", "TF_Graph::mu", "TF_Graph::name_map",
285+
.put(new Info("TF_Graph::refiner", "TF_Graph::mu",
286286
"TF_Graph::sessions", "TF_Graph::delete_requested").skip())
287+
.put(new Info("std::unordered_map<tensorflow::string,tensorflow::Node*>")
288+
.pointerTypes("NameMap").define())
287289
.put(new Info("TF_ImportGraphDefOptions").pointerTypes("TF_ImportGraphDefOptions")
288290
.base("org.tensorflow.internal.c_api.AbstractTF_ImportGraphDefOptions"))
289291
.put(new Info("TF_WhileParams", "TFE_MonitoringCounterCell", "TFE_MonitoringSamplerCell",

0 commit comments

Comments
 (0)