xref: /aosp_15_r20/external/tensorflow/tensorflow/java/src/main/native/graph_jni.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/graph_jni.h"
17 
18 #include <limits>
19 #include <memory>
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/java/src/main/native/exception_jni.h"
22 #include "tensorflow/java/src/main/native/utils_jni.h"
23 
24 namespace {
25 template <class T>
requireHandleImpl(JNIEnv * env,jlong handle)26 T* requireHandleImpl(JNIEnv* env, jlong handle) {
27   static_assert(sizeof(jlong) >= sizeof(T*),
28                 "Cannot package C object pointers as a Java long");
29   if (handle == 0) {
30     throwException(env, kIllegalStateException,
31                    "close() has been called on the Graph");
32     return nullptr;
33   }
34   return reinterpret_cast<T*>(handle);
35 }
36 
requireHandle(JNIEnv * env,jlong handle)37 TF_Graph* requireHandle(JNIEnv* env, jlong handle) {
38   return requireHandleImpl<TF_Graph>(env, handle);
39 }
40 
requireOperationHandle(JNIEnv * env,jlong handle)41 TF_Operation* requireOperationHandle(JNIEnv* env, jlong handle) {
42   return requireHandleImpl<TF_Operation>(env, handle);
43 }
44 }  // namespace
45 
Java_org_tensorflow_Graph_allocate(JNIEnv *,jclass)46 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_allocate(JNIEnv*, jclass) {
47   return reinterpret_cast<jlong>(TF_NewGraph());
48 }
49 
Java_org_tensorflow_Graph_delete(JNIEnv *,jclass,jlong handle)50 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_delete(JNIEnv*, jclass,
51                                                         jlong handle) {
52   if (handle == 0) return;
53   TF_DeleteGraph(reinterpret_cast<TF_Graph*>(handle));
54 }
55 
Java_org_tensorflow_Graph_operation(JNIEnv * env,jclass clazz,jlong handle,jstring name)56 JNIEXPORT jlong JNICALL Java_org_tensorflow_Graph_operation(JNIEnv* env,
57                                                             jclass clazz,
58                                                             jlong handle,
59                                                             jstring name) {
60   TF_Graph* g = requireHandle(env, handle);
61   if (g == nullptr) return 0;
62   const char* cname = env->GetStringUTFChars(name, nullptr);
63   TF_Operation* op = TF_GraphOperationByName(g, cname);
64   env->ReleaseStringUTFChars(name, cname);
65   return reinterpret_cast<jlong>(op);
66 }
67 
Java_org_tensorflow_Graph_nextOperation(JNIEnv * env,jclass clazz,jlong handle,jint position)68 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_nextOperation(
69     JNIEnv* env, jclass clazz, jlong handle, jint position) {
70   TF_Graph* g = requireHandle(env, handle);
71   if (g == nullptr) return nullptr;
72 
73   size_t pos = static_cast<size_t>(position);
74   TF_Operation* operation = TF_GraphNextOperation(g, &pos);
75   if (operation == nullptr) return nullptr;
76 
77   jlong handle_and_position[2];
78   handle_and_position[0] = reinterpret_cast<jlong>(operation);
79   handle_and_position[1] = static_cast<jlong>(pos);
80 
81   jlongArray rhett = env->NewLongArray(2);
82   env->SetLongArrayRegion(rhett, 0, 2, handle_and_position);
83   return rhett;
84 }
85 
Java_org_tensorflow_Graph_importGraphDef(JNIEnv * env,jclass clazz,jlong handle,jbyteArray graph_def,jstring prefix)86 JNIEXPORT void JNICALL Java_org_tensorflow_Graph_importGraphDef(
87     JNIEnv* env, jclass clazz, jlong handle, jbyteArray graph_def,
88     jstring prefix) {
89   TF_Graph* g = requireHandle(env, handle);
90   if (g == nullptr) return;
91 
92   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
93 
94   jboolean is_copy;
95   const char* cprefix = env->GetStringUTFChars(prefix, &is_copy);
96   TF_ImportGraphDefOptionsSetPrefix(opts, cprefix);
97   env->ReleaseStringUTFChars(prefix, cprefix);
98 
99   static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
100   jbyte* bytes = env->GetByteArrayElements(graph_def, &is_copy);
101   TF_Buffer* buf =
102       TF_NewBufferFromString(bytes, env->GetArrayLength(graph_def));
103   TF_Status* status = TF_NewStatus();
104 
105   TF_GraphImportGraphDef(g, buf, opts, status);
106   throwExceptionIfNotOK(env, status);
107   // Continue cleaning up resources even if an exception was thrown.
108 
109   TF_DeleteStatus(status);
110   TF_DeleteBuffer(buf);
111   env->ReleaseByteArrayElements(graph_def, bytes, JNI_ABORT);
112 
113   TF_DeleteImportGraphDefOptions(opts);
114 }
115 
116 JNIEXPORT jbyteArray JNICALL
Java_org_tensorflow_Graph_toGraphDef(JNIEnv * env,jclass clazz,jlong handle)117 Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
118   jbyteArray ret = nullptr;
119   TF_Graph* g = requireHandle(env, handle);
120   if (g == nullptr) return ret;
121 
122   TF_Buffer* buf = TF_NewBuffer();
123   TF_Status* status = TF_NewStatus();
124   TF_GraphToGraphDef(g, buf, status);
125   if (throwExceptionIfNotOK(env, status)) {
126     // sizeof(jsize) is less than sizeof(size_t) on some platforms.
127     if (buf->length > std::numeric_limits<jint>::max()) {
128       throwException(env, kIndexOutOfBoundsException,
129                      "GraphDef is too large to serialize into a byte[] array");
130     } else {
131       static_assert(sizeof(jbyte) == 1, "unexpected size of the jbyte type");
132       jint ret_len = static_cast<jint>(buf->length);
133       ret = env->NewByteArray(ret_len);
134       env->SetByteArrayRegion(ret, 0, ret_len,
135                               static_cast<const jbyte*>(buf->data));
136     }
137   }
138   TF_DeleteStatus(status);
139   TF_DeleteBuffer(buf);
140   return ret;
141 }
142 
Java_org_tensorflow_Graph_addGradients(JNIEnv * env,jclass clazz,jlong handle,jstring prefix,jlongArray y_handles,jintArray y_indices,jlongArray x_handles,jintArray x_indices,jlongArray dx_handles,jintArray dx_indices)143 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
144     JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
145     jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
146     jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
147   TF_Graph* g = requireHandle(env, handle);
148   if (g == nullptr) return nullptr;
149 
150   const jint ny = env->GetArrayLength(y_handles);
151   const jint nx = env->GetArrayLength(x_handles);
152 
153   std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
154   std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
155   std::unique_ptr<TF_Output[]> dx(nullptr);
156   std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
157 
158   resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
159   resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
160   if (dx_handles != nullptr) {
161     if (env->GetArrayLength(dx_handles) != ny) {
162       throwException(env, kIllegalArgumentException,
163                      "expected %d, got %d dx handles", ny,
164                      env->GetArrayLength(dx_handles));
165     }
166     dx.reset(new TF_Output[ny]);
167     resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
168   }
169   if (env->ExceptionCheck()) return nullptr;
170 
171   const char* cprefix = nullptr;
172   if (prefix != nullptr) {
173     cprefix = env->GetStringUTFChars(prefix, nullptr);
174   }
175   TF_Status* status = TF_NewStatus();
176   TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
177                             status, dy.get());
178   if (prefix != nullptr) {
179     env->ReleaseStringUTFChars(prefix, cprefix);
180   }
181   if (!throwExceptionIfNotOK(env, status)) {
182     TF_DeleteStatus(status);
183     return nullptr;
184   }
185   TF_DeleteStatus(status);
186 
187   // returned array contains both op handles and output indices, in pair
188   jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
189   jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
190   for (int i = 0, j = nx; i < nx; ++i, ++j) {
191     TF_Output dy_output = dy.get()[i];
192     dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
193     dy_elems[j] = static_cast<jlong>(dy_output.index);
194   }
195   env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
196 
197   return dy_handles_and_indices;
198 }
199 
200 // helper function for while loop -- constructs conditional or body subgraph
buildSubgraph(JNIEnv * env,jclass clazz,jobject subgraph_builder,TF_Graph * const subgraph,const TF_Output * const inputs,const TF_Output * const outputs,const int ninputs,const int noutputs)201 jlongArray buildSubgraph(JNIEnv* env, jclass clazz, jobject subgraph_builder,
202                          TF_Graph* const subgraph,
203                          const TF_Output* const inputs,
204                          const TF_Output* const outputs, const int ninputs,
205                          const int noutputs) {
206   jmethodID build_subgraph_method_id = env->GetStaticMethodID(
207       clazz, "buildSubgraph",
208       "(Lorg/tensorflow/Graph$WhileSubgraphBuilder;J[J[I[J[I)[J");
209   if (build_subgraph_method_id == nullptr) return nullptr;
210 
211   jlong subgraph_handle = reinterpret_cast<jlong>(subgraph);
212 
213   jlongArray input_handles = env->NewLongArray(ninputs);
214   jintArray input_indices = env->NewIntArray(ninputs);
215   jlongArray output_handles = env->NewLongArray(noutputs);
216   jintArray output_indices = env->NewIntArray(noutputs);
217 
218   jlong* input_handles_elems =
219       env->GetLongArrayElements(input_handles, nullptr);
220   jint* input_indices_elems = env->GetIntArrayElements(input_indices, nullptr);
221   jlong* output_handles_elems =
222       env->GetLongArrayElements(output_handles, nullptr);
223   jint* output_indices_elems =
224       env->GetIntArrayElements(output_indices, nullptr);
225 
226   for (int i = 0; i < ninputs; ++i) {
227     input_handles_elems[i] = reinterpret_cast<jlong>((inputs[i]).oper);
228     input_indices_elems[i] = static_cast<jint>((inputs[i]).index);
229   }
230 
231   for (int i = 0; i < noutputs; ++i) {
232     output_handles_elems[i] = reinterpret_cast<jlong>((outputs[i]).oper);
233     output_indices_elems[i] = static_cast<jint>((outputs[i]).index);
234   }
235 
236   env->ReleaseLongArrayElements(input_handles, input_handles_elems, 0);
237   env->ReleaseIntArrayElements(input_indices, input_indices_elems, 0);
238   env->ReleaseLongArrayElements(output_handles, output_handles_elems, 0);
239   env->ReleaseIntArrayElements(output_indices, output_indices_elems, 0);
240 
241   // call Java code to construct the subgraph
242   jlongArray output_handles_and_indices =
243       (jlongArray)env->CallStaticObjectMethod(
244           clazz, build_subgraph_method_id, subgraph_builder, subgraph_handle,
245           input_handles, input_indices, output_handles, output_indices);
246 
247   if (env->ExceptionOccurred()) {
248     env->ExceptionDescribe();
249     return nullptr;
250   }
251 
252   // returned array contains both op handles and output indices, in pair
253   return output_handles_and_indices;
254 }
255 
Java_org_tensorflow_Graph_whileLoop(JNIEnv * env,jclass clazz,jlong handle,jlongArray input_handles,jintArray input_indices,jstring name,jobject cond_graph_builder,jobject body_graph_builder)256 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_whileLoop(
257     JNIEnv* env, jclass clazz, jlong handle, jlongArray input_handles,
258     jintArray input_indices, jstring name, jobject cond_graph_builder,
259     jobject body_graph_builder) {
260   TF_Graph* g = requireHandle(env, handle);
261   TF_Status* status = TF_NewStatus();
262   if (g == nullptr) return nullptr;
263 
264   int ninputs = env->GetArrayLength(input_handles);
265 
266   std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
267   resolveOutputs(env, "inputs", input_handles, input_indices, inputs.get(),
268                  ninputs);
269   if (env->ExceptionCheck()) return nullptr;
270 
271   // initialize while params
272   TF_WhileParams params = TF_NewWhile(g, inputs.get(), ninputs, status);
273   throwExceptionIfNotOK(env, status);
274 
275   // build conditional subgraph
276   jlongArray cond_output_handles_and_indices =
277       buildSubgraph(env, clazz, cond_graph_builder, params.cond_graph,
278                     params.cond_inputs, &params.cond_output, params.ninputs, 1);
279 
280   // build body subgraph
281   jlongArray body_output_handles_and_indices = buildSubgraph(
282       env, clazz, body_graph_builder, params.body_graph, params.body_inputs,
283       params.body_outputs, params.ninputs, params.ninputs);
284 
285   if (cond_output_handles_and_indices == nullptr ||
286       body_output_handles_and_indices == nullptr)
287     return nullptr;
288 
289   // set cond_output param to output of the conditional subgraph
290   jlong* cond_output_elems =
291       env->GetLongArrayElements(cond_output_handles_and_indices, nullptr);
292   TF_Operation* cond_output_op =
293       requireOperationHandle(env, cond_output_elems[0]);
294   params.cond_output = {cond_output_op,
295                         static_cast<jint>(cond_output_elems[1])};
296   env->ReleaseLongArrayElements(cond_output_handles_and_indices,
297                                 cond_output_elems, 0);
298 
299   // set body_outputs param to outputs of the body subgraph
300   jlong* body_output_elems =
301       env->GetLongArrayElements(body_output_handles_and_indices, nullptr);
302   for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
303     TF_Operation* body_output_op =
304         requireOperationHandle(env, body_output_elems[i]);
305     params.body_outputs[i] = {body_output_op,
306                               static_cast<jint>(body_output_elems[j])};
307   }
308   env->ReleaseLongArrayElements(body_output_handles_and_indices,
309                                 body_output_elems, 0);
310 
311   // set loop name param
312   params.name = env->GetStringUTFChars(name, nullptr);
313 
314   // build the while loop, storing loop outputs in `outputs`
315   std::unique_ptr<TF_Output[]> outputs(new TF_Output[ninputs]);
316   TF_FinishWhile(&params, status, outputs.get());
317 
318   throwExceptionIfNotOK(env, status);
319   TF_DeleteStatus(status);
320 
321   env->ReleaseStringUTFChars(name, params.name);
322 
323   // returned array contains both op handles and output indices, in pair
324   jlongArray output_handles_and_indices = env->NewLongArray(ninputs * 2);
325   jlong* output_elems =
326       env->GetLongArrayElements(output_handles_and_indices, nullptr);
327   for (int i = 0, j = ninputs; i < ninputs; ++i, ++j) {
328     TF_Output output = outputs.get()[i];
329     output_elems[i] = reinterpret_cast<jlong>(output.oper);
330     output_elems[j] = static_cast<jlong>(output.index);
331   }
332   env->ReleaseLongArrayElements(output_handles_and_indices, output_elems, 0);
333 
334   return output_handles_and_indices;
335 }
336