1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/graph_jni.h"
17
18 #include <limits>
19 #include <memory>
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/java/src/main/native/exception_jni.h"
22 #include "tensorflow/java/src/main/native/utils_jni.h"
23
24 namespace {
25 template <class T>
requireHandleImpl(JNIEnv * env,jlong handle)26 T* requireHandleImpl(JNIEnv* env, jlong handle) {
27 static_assert(sizeof(jlong) >= sizeof(T*),
28 "Cannot package C object pointers as a Java long");
29 if (handle == 0) {
30 throwException(env, kIllegalStateException,
31 "close() has been called on the Graph");
32 return nullptr;
33 }
34 return reinterpret_cast<T*>(handle);
35 }
36
requireHandle(JNIEnv * env,jlong handle)37 TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
38 return requireHandleImpl<TF_Graph>(env, handle);
39 }
40
requireOperationHandle(JNIEnv * env,jlong handle)41 TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
42 return requireHandleImpl<TF_Operation>(env, handle);
43 }
44 } // namespace
45
Java_org_tensorflow_Graph_allocate(JNIEnv *,jclass)46 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv*, jclass) {
47 return reinterpret_cast<jlong>(TF_NewGraph());
48 }
49
Java_org_tensorflow_Graph_delete(JNIEnv *,jclass,jlong handle)50 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
51 jlong handle) {
52 if (handle == 0) return;
53 TF_DeleteGraph(reinterpret_cast<TF_Graph*>(handle));
54 }
55
Java_org_tensorflow_Graph_operation(JNIEnv * env,jclass clazz,jlong handle,jstring name)56 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
57 jclass clazz,
58 jlong handle,
59 jstring name) {
60 TF_Graph* g = requireHandle(env, handle);
61 if (g == nullptr) return 0;
62 const char* cname = env->GetStringUTFChars(name, nullptr);
63 TF_Operation* op = TF_GraphOperationByName(g, cname);
64 env->ReleaseStringUTFChars(name, cname);
65 return reinterpret_cast<jlong>(op);
66 }
67
Java_org_tensorflow_Graph_nextOperation(JNIEnv * env,jclass clazz,jlong handle,jint position)68 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
69 JNIEnv* env, jclass clazz, jlong handle, jint position) {
70 TF_Graph* g = requireHandle(env, handle);
71 if (g == nullptr) return nullptr;
72
73 size_t pos = static_cast<size_t>(position);
74 TF_Operation* operation = TF_GraphNextOperation(g, &pos);
75 if (operation == nullptr) return nullptr;
76
77 jlong handle_and_position[2];
78 handle_and_position[0] = reinterpret_cast<jlong>(operation);
79 handle_and_position[1] = static_cast<jlong>(pos);
80
81 jlongArray rhett = env->NewLongArray(2);
82 env->SetLongArrayRegion(rhett, 0, 2, handle_and_position);
83 return rhett;
84 }
85
Java_org_tensorflow_Graph_importGraphDef(JNIEnv * env,jclass clazz,jlong handle,jbyteArray graph_def,jstring prefix)86 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
87 JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
88 jstring prefix) {
89 TF_Graph* g = requireHandle(env, handle);
90 if (g == nullptr) return;
91
92 TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
93
94 jboolean is_copy;
95 const char* cprefix = env->GetStringUTFChars(prefix, &is_copy);
96 TF_ImportGraphDefOptionsSetPrefix(opts, cprefix);
97 env->ReleaseStringUTFChars(prefix, cprefix);
98
99 static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
100 jbyte* bytes = env->GetByteArrayElements(graph_def, &is_copy);
101 TF_Buffer* buf =
102 TF_NewBufferFromString(bytes, env->GetArrayLength(graph_def));
103 TF_Status* status = TF_NewStatus();
104
105 TF_GraphImportGraphDef(g, buf, opts, status);
106 throwExceptionIfNotOK(env, status);
107 // Continue cleaning up resources even if an exception was thrown.
108
109 TF_DeleteStatus(status);
110 TF_DeleteBuffer(buf);
111 env->ReleaseByteArrayElements(graph_def, bytes, JNI_ABORT);
112
113 TF_DeleteImportGraphDefOptions(opts);
114 }
115
116 JNIEXPORT jbyteArray JNICALL
Java_org_tensorflow_Graph_toGraphDef(JNIEnv * env,jclass clazz,jlong handle)117 Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
118 jbyteArray ret = nullptr;
119 TF_Graph* g = requireHandle(env, handle);
120 if (g == nullptr) return ret;
121
122 TF_Buffer* buf = TF_NewBuffer();
123 TF_Status* status = TF_NewStatus();
124 TF_GraphToGraphDef(g, buf, status);
125 if (throwExceptionIfNotOK(env, status)) {
126 // sizeof(jsize) is less than sizeof(size_t) on some platforms.
127 if (buf->length > std::numeric_limits<jint>::max()) {
128 throwException(env, kIndexOutOfBoundsException,
129 "GraphDef is too large to serialize into a byte[] array");
130 } else {
131 static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
132 jint ret_len = static_cast<jint>(buf->length);
133 ret = env->NewByteArray(ret_len);
134 env->SetByteArrayRegion(ret, 0, ret_len,
135 static_cast<const jbyte*>(buf->data));
136 }
137 }
138 TF_DeleteStatus(status);
139 TF_DeleteBuffer(buf);
140 return ret;
141 }
142
Java_org_tensorflow_Graph_addGradients(JNIEnv * env,jclass clazz,jlong handle,jstring prefix,jlongArray y_handles,jintArray y_indices,jlongArray x_handles,jintArray x_indices,jlongArray dx_handles,jintArray dx_indices)143 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
144 JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
145 jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
146 jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
147 TF_Graph* g = requireHandle(env, handle);
148 if (g == nullptr) return nullptr;
149
150 const jint ny = env->GetArrayLength(y_handles);
151 const jint nx = env->GetArrayLength(x_handles);
152
153 std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
154 std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
155 std::unique_ptr<TF_Output[]> dx(nullptr);
156 std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
157
158 resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
159 resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
160 if (dx_handles != nullptr) {
161 if (env->GetArrayLength(dx_handles) != ny) {
162 throwException(env, kIllegalArgumentException,
163 "expected %d, got %d dx handles", ny,
164 env->GetArrayLength(dx_handles));
165 }
166 dx.reset(new TF_Output[ny]);
167 resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
168 }
169 if (env->ExceptionCheck()) return nullptr;
170
171 const char* cprefix = nullptr;
172 if (prefix != nullptr) {
173 cprefix = env->GetStringUTFChars(prefix, nullptr);
174 }
175 TF_Status* status = TF_NewStatus();
176 TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
177 status, dy.get());
178 if (prefix != nullptr) {
179 env->ReleaseStringUTFChars(prefix, cprefix);
180 }
181 if (!throwExceptionIfNotOK(env, status)) {
182 TF_DeleteStatus(status);
183 return nullptr;
184 }
185 TF_DeleteStatus(status);
186
187 // returned array contains both op handles and output indices, in pair
188 jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
189 jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
190 for (int i = 0, j = nx; i < nx; ++i, ++j) {
191 TF_Output dy_output = dy.get()[i];
192 dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
193 dy_elems[j] = static_cast<jlong>(dy_output.index);
194 }
195 env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
196
197 return dy_handles_and_indices;
198 }
199
200 // helper function for while loop -- constructs conditional or body subgraph
buildSubgraph(JNIEnv * env,jclass clazz,jobject subgraph_builder,TF_Graph * const subgraph,const TF_Output * const inputs,const TF_Output * const outputs,const int ninputs,const int noutputs)201 jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
202 TF_Graph* const subgraph,
203 const TF_Output* const inputs,
204 const TF_Output* const outputs, const int ninputs,
205 const int noutputs) {
206 jmethodID build_subgraph_method_id = env->GetStaticMethodID(
207 clazz, "buildSubgraph",
208 "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
209 if (build_subgraph_method_id == nullptr) return nullptr;
210
211 jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
212
213 jlongArray input_handles = env->NewLongArray(ninputs);
214 jintArray input_indices = env->NewIntArray(ninputs);
215 jlongArray output_handles = env->NewLongArray(noutputs);
216 jintArray output_indices = env->NewIntArray(noutputs);
217
218 jlong* input_handles_elems =
219 env->GetLongArrayElements(input_handles, nullptr);
220 jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
221 jlong* output_handles_elems =
222 env->GetLongArrayElements(output_handles, nullptr);
223 jint* output_indices_elems =
224 env->GetIntArrayElements(output_indices, nullptr);
225
226 for (int i = 0; i < ninputs; ++i) {
227 input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
228 input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
229 }
230
231 for (int i = 0; i < noutputs; ++i) {
232 output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
233 output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
234 }
235
236 env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
237 env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
238 env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
239 env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
240
241 // call Java code to construct the subgraph
242 jlongArray output_handles_and_indices =
243 (jlongArray)env->CallStaticObjectMethod(
244 clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
245 input_handles, input_indices, output_handles, output_indices);
246
247 if (env->ExceptionOccurred()) {
248 env->ExceptionDescribe();
249 return nullptr;
250 }
251
252 // returned array contains both op handles and output indices, in pair
253 return output_handles_and_indices;
254 }
255
Java_org_tensorflow_Graph_whileLoop(JNIEnv * env,jclass clazz,jlong handle,jlongArray input_handles,jintArray input_indices,jstring name,jobject cond_graph_builder,jobject body_graph_builder)256 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
257 JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
258 jintArray input_indices, jstring name, jobject cond_graph_builder,
259 jobject body_graph_builder) {
260 TF_Graph* g = requireHandle(env, handle);
261 TF_Status* status = TF_NewStatus();
262 if (g == nullptr) return nullptr;
263
264 int ninputs = env->GetArrayLength(input_handles);
265
266 std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
267 resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
268 ninputs);
269 if (env->ExceptionCheck()) return nullptr;
270
271 // initialize while params
272 TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
273 throwExceptionIfNotOK(env, status);
274
275 // build conditional subgraph
276 jlongArray cond_output_handles_and_indices =
277 buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
278 params.cond_inputs, ¶ms.cond_output, params.ninputs, 1);
279
280 // build body subgraph
281 jlongArray body_output_handles_and_indices = buildSubgraph(
282 env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
283 params.body_outputs, params.ninputs, params.ninputs);
284
285 if (cond_output_handles_and_indices == nullptr ||
286 body_output_handles_and_indices == nullptr)
287 return nullptr;
288
289 // set cond_output param to output of the conditional subgraph
290 jlong* cond_output_elems =
291 env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
292 TF_Operation* cond_output_op =
293 requireOperationHandle(env, cond_output_elems[0]);
294 params.cond_output = {cond_output_op,
295 static_cast<jint>(cond_output_elems[1])};
296 env->ReleaseLongArrayElements(cond_output_handles_and_indices,
297 cond_output_elems, 0);
298
299 // set body_outputs param to outputs of the body subgraph
300 jlong* body_output_elems =
301 env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
302 for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
303 TF_Operation* body_output_op =
304 requireOperationHandle(env, body_output_elems[i]);
305 params.body_outputs[i] = {body_output_op,
306 static_cast<jint>(body_output_elems[j])};
307 }
308 env->ReleaseLongArrayElements(body_output_handles_and_indices,
309 body_output_elems, 0);
310
311 // set loop name param
312 params.name = env->GetStringUTFChars(name, nullptr);
313
314 // build the while loop, storing loop outputs in `outputs`
315 std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
316 TF_FinishWhile(¶ms, status, outputs.get());
317
318 throwExceptionIfNotOK(env, status);
319 TF_DeleteStatus(status);
320
321 env->ReleaseStringUTFChars(name, params.name);
322
323 // returned array contains both op handles and output indices, in pair
324 jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
325 jlong* output_elems =
326 env->GetLongArrayElements(output_handles_and_indices, nullptr);
327 for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
328 TF_Output output = outputs.get()[i];
329 output_elems[i] = reinterpret_cast<jlong>(output.oper);
330 output_elems[j] = static_cast<jlong>(output.index);
331 }
332 env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
333
334 return output_handles_and_indices;
335 }
336