xref: /aosp_15_r20/external/tensorflow/tensorflow/c/c_test_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_test_util.h"
17 
18 #include "tensorflow/c/c_api_experimental.h"
19 #include "tensorflow/core/framework/function.pb.h"
20 #include "tensorflow/core/framework/op_def.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/strcat.h"
24 #include "tensorflow/core/public/session_options.h"
25 
26 using tensorflow::GraphDef;
27 using tensorflow::NodeDef;
28 
// Deallocator handed to TF_NewTensor for buffers allocated with `new bool[]`.
// The byte length and the user-supplied argument are not needed.
static void BoolDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  auto* const buffer = static_cast<bool*>(data);
  delete[] buffer;
}
32 
// Deallocator handed to TF_NewTensor for buffers allocated with
// `new int32_t[]`. The byte length and the user-supplied argument are unused.
static void Int32Deallocator(void* data, size_t /*len*/, void* /*arg*/) {
  auto* const buffer = static_cast<int32_t*>(data);
  delete[] buffer;
}
36 
// Deallocator handed to TF_NewTensor for buffers allocated with
// `new double[]`. The byte length and the user-supplied argument are unused.
static void DoubleDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  auto* const buffer = static_cast<double*>(data);
  delete[] buffer;
}
40 
// Deallocator handed to TF_NewTensor for buffers allocated with
// `new float[]`. The byte length and the user-supplied argument are unused.
static void FloatDeallocator(void* data, size_t /*len*/, void* /*arg*/) {
  auto* const buffer = static_cast<float*>(data);
  delete[] buffer;
}
44 
BoolTensor(bool v)45 TF_Tensor* BoolTensor(bool v) {
46   const int num_bytes = sizeof(bool);
47   bool* values = new bool[1];
48   values[0] = v;
49   return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
50                       nullptr);
51 }
52 
Int8Tensor(const int64_t * dims,int num_dims,const char * values)53 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
54   int64_t num_values = 1;
55   for (int i = 0; i < num_dims; ++i) {
56     num_values *= dims[i];
57   }
58   TF_Tensor* t =
59       TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
60   memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
61   return t;
62 }
63 
Int32Tensor(const int64_t * dims,int num_dims,const int32_t * values)64 TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
65                        const int32_t* values) {
66   int64_t num_values = 1;
67   for (int i = 0; i < num_dims; ++i) {
68     num_values *= dims[i];
69   }
70   TF_Tensor* t =
71       TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
72   memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
73   return t;
74 }
75 
Int32Tensor(const std::vector<int32_t> & values)76 TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
77   int64_t dims = values.size();
78   return Int32Tensor(&dims, 1, values.data());
79 }
80 
Int32Tensor(int32_t v)81 TF_Tensor* Int32Tensor(int32_t v) {
82   const int num_bytes = sizeof(int32_t);
83   int32_t* values = new int32_t[1];
84   values[0] = v;
85   return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
86                       &Int32Deallocator, nullptr);
87 }
88 
DoubleTensor(double v)89 TF_Tensor* DoubleTensor(double v) {
90   const int num_bytes = sizeof(double);
91   double* values = new double[1];
92   values[0] = v;
93   return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes,
94                       &DoubleDeallocator, nullptr);
95 }
96 
FloatTensor(float v)97 TF_Tensor* FloatTensor(float v) {
98   const int num_bytes = sizeof(float);
99   float* values = new float[1];
100   values[0] = v;
101   return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes,
102                       &FloatDeallocator, nullptr);
103 }
104 
105 // All the *Helper methods are used as a workaround for the restrictions that
106 // one cannot call ASSERT_* methods in non-void-returning functions (when
107 // exceptions are disabled during compilation)
PlaceholderHelper(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims,TF_Operation ** op)108 void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
109                        TF_DataType dtype, const std::vector<int64_t>& dims,
110                        TF_Operation** op) {
111   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
112   TF_SetAttrType(desc, "dtype", dtype);
113   if (!dims.empty()) {
114     TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
115   }
116   *op = TF_FinishOperation(desc, s);
117   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
118   ASSERT_NE(*op, nullptr);
119 }
120 
Placeholder(TF_Graph * graph,TF_Status * s,const char * name,TF_DataType dtype,const std::vector<int64_t> & dims)121 TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
122                           TF_DataType dtype, const std::vector<int64_t>& dims) {
123   TF_Operation* op;
124   PlaceholderHelper(graph, s, name, dtype, dims, &op);
125   return op;
126 }
127 
ConstHelper(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)128 void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
129                  TF_Operation** op) {
130   TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
131   TF_SetAttrTensor(desc, "value", t, s);
132   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
133   TF_SetAttrType(desc, "dtype", TF_TensorType(t));
134   *op = TF_FinishOperation(desc, s);
135   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
136   ASSERT_NE(*op, nullptr);
137 }
138 
Const(TF_Tensor * t,TF_Graph * graph,TF_Status * s,const char * name)139 TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
140                     const char* name) {
141   TF_Operation* op;
142   ConstHelper(t, graph, s, name, &op);
143   return op;
144 }
145 
ScalarConst(bool v,TF_Graph * graph,TF_Status * s,const char * name)146 TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
147                           const char* name) {
148   unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
149   return Const(tensor.get(), graph, s, name);
150 }
151 
ScalarConst(int32_t v,TF_Graph * graph,TF_Status * s,const char * name)152 TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
153                           const char* name) {
154   unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
155   return Const(tensor.get(), graph, s, name);
156 }
157 
ScalarConst(double v,TF_Graph * graph,TF_Status * s,const char * name)158 TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
159                           const char* name) {
160   unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor);
161   return Const(tensor.get(), graph, s, name);
162 }
163 
ScalarConst(float v,TF_Graph * graph,TF_Status * s,const char * name)164 TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
165                           const char* name) {
166   unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor);
167   return Const(tensor.get(), graph, s, name);
168 }
169 
AddOpHelper(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,bool check)170 void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
171                  TF_Status* s, const char* name, TF_Operation** op,
172                  bool check) {
173   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
174   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
175   TF_AddInputList(desc, add_inputs, 2);
176   *op = TF_FinishOperation(desc, s);
177   if (check) {
178     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
179     ASSERT_NE(*op, nullptr);
180   }
181 }
182 
Add(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)183 TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
184                   TF_Status* s, const char* name) {
185   TF_Operation* op;
186   AddOpHelper(l, r, graph, s, name, &op, true);
187   return op;
188 }
189 
AddNoCheck(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)190 TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
191                          TF_Status* s, const char* name) {
192   TF_Operation* op;
193   AddOpHelper(l, r, graph, s, name, &op, false);
194   return op;
195 }
196 
AddWithCtrlDependency(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Operation * ctrl_op,TF_Status * s,const char * name)197 TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
198                                     TF_Graph* graph, TF_Operation* ctrl_op,
199                                     TF_Status* s, const char* name) {
200   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
201   TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
202   TF_AddInputList(desc, add_inputs, 2);
203   TF_AddControlInput(desc, ctrl_op);
204   return TF_FinishOperation(desc, s);
205 }
206 
207 // If `op_device` is non-empty, set the created op on that device.
BinaryOpHelper(const char * op_name,TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op,const string & op_device,bool check)208 void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
209                     TF_Graph* graph, TF_Status* s, const char* name,
210                     TF_Operation** op, const string& op_device, bool check) {
211   TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
212   if (!op_device.empty()) {
213     TF_SetDevice(desc, op_device.c_str());
214   }
215   TF_AddInput(desc, {l, 0});
216   TF_AddInput(desc, {r, 0});
217   *op = TF_FinishOperation(desc, s);
218   if (check) {
219     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
220     ASSERT_NE(*op, nullptr);
221   }
222 }
223 
MinWithDevice(TF_Operation * l,TF_Operation * r,TF_Graph * graph,const string & op_device,TF_Status * s,const char * name)224 TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
225                             const string& op_device, TF_Status* s,
226                             const char* name) {
227   TF_Operation* op;
228   BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true);
229   return op;
230 }
231 
Min(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)232 TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
233                   TF_Status* s, const char* name) {
234   return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
235 }
236 
Mul(TF_Operation * l,TF_Operation * r,TF_Graph * graph,TF_Status * s,const char * name)237 TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
238                   TF_Status* s, const char* name) {
239   TF_Operation* op;
240   BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
241   return op;
242 }
243 
Add(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s,const char * name)244 TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
245                   const char* name) {
246   TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
247   TF_Output inputs[2] = {l, r};
248   TF_AddInputList(desc, inputs, 2);
249   return TF_FinishOperation(desc, s);
250 }
251 
NegHelper(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)252 void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name,
253                TF_Operation** op) {
254   TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", name);
255   TF_Output neg_input = {n, 0};
256   TF_AddInput(desc, neg_input);
257   *op = TF_FinishOperation(desc, s);
258   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
259   ASSERT_NE(*op, nullptr);
260 }
261 
Neg(TF_Operation * n,TF_Graph * graph,TF_Status * s,const char * name)262 TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
263                   const char* name) {
264   TF_Operation* op;
265   NegHelper(n, graph, s, name, &op);
266   return op;
267 }
268 
LessThan(TF_Output l,TF_Output r,TF_Graph * graph,TF_Status * s)269 TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
270                        TF_Status* s) {
271   TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
272   TF_AddInput(desc, l);
273   TF_AddInput(desc, r);
274   return TF_FinishOperation(desc, s);
275 }
276 
RandomUniform(TF_Operation * shape,TF_DataType dtype,TF_Graph * graph,TF_Status * s)277 TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
278                             TF_Graph* graph, TF_Status* s) {
279   TF_OperationDescription* desc =
280       TF_NewOperation(graph, "RandomUniform", "random_uniform");
281   TF_AddInput(desc, {shape, 0});
282   TF_SetAttrType(desc, "dtype", dtype);
283   return TF_FinishOperation(desc, s);
284 }
285 
Split3Helper(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name,TF_Operation ** op)286 void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
287                   const char* name, TF_Operation** op) {
288   TF_Operation* zero = ScalarConst(
289       0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
290   TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
291   TF_AddInput(desc, {zero, 0});
292   TF_AddInput(desc, {input, 0});
293   TF_SetAttrInt(desc, "num_split", 3);
294   TF_SetAttrType(desc, "T", TF_INT32);
295   // Set device to CPU since there is no version of split for int32 on GPU
296   // TODO(iga): Convert all these helpers and tests to use floats because
297   // they are usually available on GPUs. After doing this, remove TF_SetDevice
298   // call in c_api_function_test.cc
299   TF_SetDevice(desc, "/cpu:0");
300   *op = TF_FinishOperation(desc, s);
301   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
302   ASSERT_NE(*op, nullptr);
303 }
304 
Split3(TF_Operation * input,TF_Graph * graph,TF_Status * s,const char * name)305 TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
306                      const char* name) {
307   TF_Operation* op;
308   Split3Helper(input, graph, s, name, &op);
309   return op;
310 }
311 
IsPlaceholder(const tensorflow::NodeDef & node_def)312 bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
313   if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
314     return false;
315   }
316   bool found_dtype = false;
317   bool found_shape = false;
318   for (const auto& attr : node_def.attr()) {
319     if (attr.first == "dtype") {
320       if (attr.second.type() == tensorflow::DT_INT32) {
321         found_dtype = true;
322       } else {
323         return false;
324       }
325     } else if (attr.first == "shape") {
326       found_shape = true;
327     }
328   }
329   return found_dtype && found_shape;
330 }
331 
IsScalarConst(const tensorflow::NodeDef & node_def,int v)332 bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) {
333   if (node_def.op() != "Const" || node_def.name() != "scalar") {
334     return false;
335   }
336   bool found_dtype = false;
337   bool found_value = false;
338   for (const auto& attr : node_def.attr()) {
339     if (attr.first == "dtype") {
340       if (attr.second.type() == tensorflow::DT_INT32) {
341         found_dtype = true;
342       } else {
343         return false;
344       }
345     } else if (attr.first == "value") {
346       if (attr.second.has_tensor() &&
347           attr.second.tensor().int_val_size() == 1 &&
348           attr.second.tensor().int_val(0) == v) {
349         found_value = true;
350       } else {
351         return false;
352       }
353     }
354   }
355   return found_dtype && found_value;
356 }
357 
IsAddN(const tensorflow::NodeDef & node_def,int n)358 bool IsAddN(const tensorflow::NodeDef& node_def, int n) {
359   if (node_def.op() != "AddN" || node_def.name() != "add" ||
360       node_def.input_size() != n) {
361     return false;
362   }
363   bool found_t = false;
364   bool found_n = false;
365   for (const auto& attr : node_def.attr()) {
366     if (attr.first == "T") {
367       if (attr.second.type() == tensorflow::DT_INT32) {
368         found_t = true;
369       } else {
370         return false;
371       }
372     } else if (attr.first == "N") {
373       if (attr.second.i() == n) {
374         found_n = true;
375       } else {
376         return false;
377       }
378     }
379   }
380   return found_t && found_n;
381 }
382 
IsNeg(const tensorflow::NodeDef & node_def,const string & input)383 bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) {
384   return node_def.op() == "Neg" && node_def.name() == "neg" &&
385          node_def.input_size() == 1 && node_def.input(0) == input;
386 }
387 
GetGraphDef(TF_Graph * graph,tensorflow::GraphDef * graph_def)388 bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) {
389   TF_Status* s = TF_NewStatus();
390   TF_Buffer* buffer = TF_NewBuffer();
391   TF_GraphToGraphDef(graph, buffer, s);
392   bool ret = TF_GetCode(s) == TF_OK;
393   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
394   if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
395   TF_DeleteBuffer(buffer);
396   TF_DeleteStatus(s);
397   return ret;
398 }
399 
GetNodeDef(TF_Operation * oper,tensorflow::NodeDef * node_def)400 bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
401   TF_Status* s = TF_NewStatus();
402   TF_Buffer* buffer = TF_NewBuffer();
403   TF_OperationToNodeDef(oper, buffer, s);
404   bool ret = TF_GetCode(s) == TF_OK;
405   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
406   if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
407   TF_DeleteBuffer(buffer);
408   TF_DeleteStatus(s);
409   return ret;
410 }
411 
GetFunctionDef(TF_Function * func,tensorflow::FunctionDef * func_def)412 bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
413   TF_Status* s = TF_NewStatus();
414   TF_Buffer* buffer = TF_NewBuffer();
415   TF_FunctionToFunctionDef(func, buffer, s);
416   bool ret = TF_GetCode(s) == TF_OK;
417   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
418   if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
419   TF_DeleteBuffer(buffer);
420   TF_DeleteStatus(s);
421   return ret;
422 }
423 
GetAttrValue(TF_Operation * oper,const char * attr_name,tensorflow::AttrValue * attr_value,TF_Status * s)424 bool GetAttrValue(TF_Operation* oper, const char* attr_name,
425                   tensorflow::AttrValue* attr_value, TF_Status* s) {
426   TF_Buffer* buffer = TF_NewBuffer();
427   TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
428   bool ret = TF_GetCode(s) == TF_OK;
429   if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
430   TF_DeleteBuffer(buffer);
431   return ret;
432 }
433 
GetGradDefs(const tensorflow::GraphDef & graph_def)434 std::vector<std::pair<string, string>> GetGradDefs(
435     const tensorflow::GraphDef& graph_def) {
436   std::vector<std::pair<string, string>> grads;
437   for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) {
438     grads.emplace_back(grad.function_name(), grad.gradient_func());
439   }
440   std::sort(grads.begin(), grads.end());
441   return grads;
442 }
443 
GetFuncNames(const tensorflow::GraphDef & graph_def)444 std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
445   std::vector<string> names;
446   auto functions = graph_def.library().function();
447   names.reserve(functions.size());
448   for (const tensorflow::FunctionDef& func : functions) {
449     names.push_back(func.signature().name());
450   }
451   std::sort(names.begin(), names.end());
452   return names;
453 }
454 
CSession(TF_Graph * graph,TF_Status * s,bool use_XLA)455 CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) {
456   TF_SessionOptions* opts = TF_NewSessionOptions();
457   TF_EnableXLACompilation(opts, use_XLA);
458   session_ = TF_NewSession(graph, opts, s);
459   TF_DeleteSessionOptions(opts);
460 }
461 
CSession(TF_Session * session)462 CSession::CSession(TF_Session* session) : session_(session) {}
463 
~CSession()464 CSession::~CSession() {
465   TF_Status* s = TF_NewStatus();
466   CloseAndDelete(s);
467   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
468   TF_DeleteStatus(s);
469 }
470 
SetInputs(std::vector<std::pair<TF_Operation *,TF_Tensor * >> inputs)471 void CSession::SetInputs(
472     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
473   DeleteInputValues();
474   inputs_.clear();
475   for (const auto& p : inputs) {
476     inputs_.emplace_back(TF_Output{p.first, 0});
477     input_values_.emplace_back(p.second);
478   }
479 }
480 
SetOutputs(std::initializer_list<TF_Operation * > outputs)481 void CSession::SetOutputs(std::initializer_list<TF_Operation*> outputs) {
482   ResetOutputValues();
483   outputs_.clear();
484   for (TF_Operation* o : outputs) {
485     outputs_.emplace_back(TF_Output{o, 0});
486   }
487   output_values_.resize(outputs_.size());
488 }
489 
SetOutputs(const std::vector<TF_Output> & outputs)490 void CSession::SetOutputs(const std::vector<TF_Output>& outputs) {
491   ResetOutputValues();
492   outputs_ = outputs;
493   output_values_.resize(outputs_.size());
494 }
495 
SetTargets(std::initializer_list<TF_Operation * > targets)496 void CSession::SetTargets(std::initializer_list<TF_Operation*> targets) {
497   targets_.clear();
498   for (TF_Operation* t : targets) {
499     targets_.emplace_back(t);
500   }
501 }
502 
Run(TF_Status * s)503 void CSession::Run(TF_Status* s) {
504   if (inputs_.size() != input_values_.size()) {
505     ADD_FAILURE() << "Call SetInputs() before Run()";
506     return;
507   }
508   ResetOutputValues();
509   output_values_.resize(outputs_.size(), nullptr);
510 
511   const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
512   TF_Tensor* const* input_values_ptr =
513       input_values_.empty() ? nullptr : &input_values_[0];
514 
515   const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
516   TF_Tensor** output_values_ptr =
517       output_values_.empty() ? nullptr : &output_values_[0];
518 
519   TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
520 
521   TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(),
522                 outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr,
523                 targets_.size(), nullptr, s);
524 
525   DeleteInputValues();
526 }
527 
CloseAndDelete(TF_Status * s)528 void CSession::CloseAndDelete(TF_Status* s) {
529   DeleteInputValues();
530   ResetOutputValues();
531   if (session_ != nullptr) {
532     TF_CloseSession(session_, s);
533     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
534     TF_DeleteSession(session_, s);
535     session_ = nullptr;
536   }
537 }
538 
DeleteInputValues()539 void CSession::DeleteInputValues() {
540   for (size_t i = 0; i < input_values_.size(); ++i) {
541     TF_DeleteTensor(input_values_[i]);
542   }
543   input_values_.clear();
544 }
545 
ResetOutputValues()546 void CSession::ResetOutputValues() {
547   for (size_t i = 0; i < output_values_.size(); ++i) {
548     if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
549   }
550   output_values_.clear();
551 }
552