xref: /aosp_15_r20/external/tensorflow/tensorflow/c/c_api_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23 
24 #include "tensorflow/c/c_api_internal.h"
25 #include "tensorflow/c/c_test_util.h"
26 #include "tensorflow/c/tf_buffer_internal.h"
27 #include "tensorflow/c/tf_status.h"
28 #include "tensorflow/cc/saved_model/signature_constants.h"
29 #include "tensorflow/cc/saved_model/tag_constants.h"
30 #include "tensorflow/core/example/example.pb.h"
31 #include "tensorflow/core/example/feature.pb.h"
32 #include "tensorflow/core/framework/api_def.pb.h"
33 #include "tensorflow/core/framework/common_shape_fns.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/kernel_def.pb.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/node_def_util.h"
38 #include "tensorflow/core/framework/op.h"
39 #include "tensorflow/core/framework/op_def.pb.h"
40 #include "tensorflow/core/framework/op_kernel.h"
41 #include "tensorflow/core/framework/partial_tensor_shape.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/framework/tensor.pb.h"
44 #include "tensorflow/core/framework/tensor_shape.pb.h"
45 #include "tensorflow/core/framework/types.pb.h"
46 #include "tensorflow/core/graph/tensor_id.h"
47 #include "tensorflow/core/lib/core/status_test_util.h"
48 #include "tensorflow/core/lib/io/path.h"
49 #include "tensorflow/core/platform/path.h"
50 #include "tensorflow/core/platform/protobuf.h"
51 #include "tensorflow/core/platform/resource_loader.h"
52 #include "tensorflow/core/platform/str_util.h"
53 #include "tensorflow/core/platform/strcat.h"
54 #include "tensorflow/core/platform/test.h"
55 #include "tensorflow/core/protobuf/error_codes.pb.h"
56 #include "tensorflow/core/protobuf/meta_graph.pb.h"
57 #include "tensorflow/core/util/equal_graph_def.h"
58 
59 namespace tensorflow {
60 TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
61 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
62 
63 namespace {
64 
ExpectHasSubstr(StringPiece s,StringPiece expected)65 static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
66   EXPECT_TRUE(absl::StrContains(s, expected))
67       << "'" << s << "' does not contain '" << expected << "'";
68 }
69 
70 // Returns the GPU device name if there is one (with arbitrary tie breaking if
71 // there are more than one), or "" otherwise.
GPUDeviceName(TF_Session * session)72 string GPUDeviceName(TF_Session* session) {
73   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
74       TF_NewStatus(), TF_DeleteStatus);
75   TF_Status* s = status.get();
76   std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> list(
77       TF_SessionListDevices(session, s), TF_DeleteDeviceList);
78   TF_DeviceList* device_list = list.get();
79 
80   CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
81 
82   const int num_devices = TF_DeviceListCount(device_list);
83   LOG(INFO) << "There are " << num_devices << " devices.";
84   for (int i = 0; i < num_devices; ++i) {
85     const char* device_name = TF_DeviceListName(device_list, i, s);
86     CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
87     const char* device_type = TF_DeviceListType(device_list, i, s);
88     CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
89     LOG(INFO) << "Device " << i << " has name " << device_name << ", type "
90               << device_type;
91     if (string(device_type) == DEVICE_GPU) {
92       return device_name;
93     }
94   }
95   // No GPU device found.
96   return "";
97 }
98 
GPUDeviceName()99 string GPUDeviceName() {
100   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
101       TF_NewStatus(), TF_DeleteStatus);
102   TF_Status* s = status.get();
103   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph(TF_NewGraph(),
104                                                              TF_DeleteGraph);
105 
106   TF_SessionOptions* opts = TF_NewSessionOptions();
107   TF_Session* sess = TF_NewSession(graph.get(), opts, s);
108   TF_DeleteSessionOptions(opts);
109 
110   const string gpu_device_name = GPUDeviceName(sess);
111   TF_DeleteSession(sess, s);
112   CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
113   return gpu_device_name;
114 }
115 
// TF_Version() must report a non-empty version string.
TEST(CAPI, Version) {
  const char* version = TF_Version();
  EXPECT_STRNE("", version);
}
117 
// A freshly created TF_Status is OK with an empty message; TF_SetStatus
// replaces both the code and the message.
TEST(CAPI, Status) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  EXPECT_EQ(TF_OK, TF_GetCode(status.get()));
  EXPECT_EQ(string(), TF_Message(status.get()));

  TF_SetStatus(status.get(), TF_CANCELLED, "cancel");
  EXPECT_EQ(TF_CANCELLED, TF_GetCode(status.get()));
  EXPECT_EQ(string("cancel"), TF_Message(status.get()));
}
127 
Deallocator(void * data,size_t,void * arg)128 void Deallocator(void* data, size_t, void* arg) {
129   tensorflow::cpu_allocator()->DeallocateRaw(data);
130   *reinterpret_cast<bool*>(arg) = true;
131 }
132 
// TF_NewTensor adopts a caller-allocated buffer. It reports the requested
// dtype, shape, and byte size, and exposes the same data pointer. The
// deallocator runs only when the tensor is deleted.
TEST(CAPI, Tensor) {
  const int byte_count = 6 * sizeof(float);
  float* buffer = reinterpret_cast<float*>(
      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES,
                                               byte_count));
  int64_t shape[] = {2, 3};
  const int rank = 2;
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, rank, buffer, byte_count,
                                   &Deallocator, &freed);
  EXPECT_FALSE(freed);

  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(rank, TF_NumDims(tensor));
  for (int d = 0; d < rank; ++d) {
    EXPECT_EQ(shape[d], TF_Dim(tensor, d));
  }
  EXPECT_EQ(byte_count, TF_TensorByteSize(tensor));
  EXPECT_EQ(static_cast<void*>(buffer), TF_TensorData(tensor));

  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
152 
// Deallocator that intentionally does nothing, for tensors that must not free
// their backing memory.
void NoOpDeallocator(void* /*data*/, size_t /*len*/, void* /*arg*/) {}
154 
// Regression test for https://github.com/tensorflow/tensorflow/issues/7394:
// num_dims == 0 means a scalar, which needs at least 4 bytes of backing data,
// so a zero-byte buffer must be rejected with a null tensor.
TEST(CAPI, MalformedTensor) {
  TF_Tensor* const tensor =
      TF_NewTensor(TF_FLOAT, /*dims=*/nullptr, /*num_dims=*/0,
                   /*data=*/nullptr, /*len=*/0, &NoOpDeallocator,
                   /*deallocator_arg=*/nullptr);
  ASSERT_EQ(nullptr, tensor);
}
163 
// TF_AllocateTensor creates a tensor with the requested dtype, shape, and byte
// size, and its element count is the product of the dimensions.
TEST(CAPI, AllocateTensor) {
  const int byte_count = 6 * sizeof(float);
  int64_t shape[] = {2, 3};
  const int rank = 2;
  TF_Tensor* tensor = TF_AllocateTensor(TF_FLOAT, shape, rank, byte_count);

  EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
  EXPECT_EQ(rank, TF_NumDims(tensor));
  EXPECT_EQ(shape[0], TF_Dim(tensor, 0));
  EXPECT_EQ(shape[1], TF_Dim(tensor, 1));
  EXPECT_EQ(byte_count, TF_TensorByteSize(tensor));
  EXPECT_EQ(shape[0] * shape[1], TF_TensorElementCount(tensor));

  TF_DeleteTensor(tensor);
}
176 
// TF_TensorMaybeMove must return null for a tensor backed by caller-supplied
// memory, because it is unsafe to move memory TF might not own. Deleting the
// original tensor still runs its deallocator.
TEST(CAPI, MaybeMove) {
  const int byte_count = 6 * sizeof(float);
  float* buffer = reinterpret_cast<float*>(
      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES,
                                               byte_count));
  int64_t shape[] = {2, 3};
  bool freed = false;
  TF_Tensor* tensor = TF_NewTensor(TF_FLOAT, shape, 2, buffer, byte_count,
                                   &Deallocator, &freed);

  TF_Tensor* const moved = TF_TensorMaybeMove(tensor);
  ASSERT_EQ(nullptr, moved);

  TF_DeleteTensor(tensor);
  EXPECT_TRUE(freed);
}
192 
// Loads the test_op1.so custom-op library (when shared objects are supported)
// and checks that TF_GetOpList reports exactly the "TestCApi1" op. Then checks
// that the global op list from TF_GetAllOpList parses and contains
// "TestCApi".
TEST(CAPI, LibraryLoadFunctions) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    // Load the library.
    TF_Status* status = TF_NewStatus();
    string lib_path = tensorflow::GetDataDependencyFilepath(
        tensorflow::io::JoinPath("tensorflow", "c", "test_op1.so"));
    TF_Library* lib = TF_LoadLibrary(lib_path.c_str(), status);
    TF_Code code = TF_GetCode(status);
    string status_msg(TF_Message(status));
    TF_DeleteStatus(status);
    ASSERT_EQ(TF_OK, code) << status_msg;
    // Release the handle at scope exit, including when an ASSERT below returns
    // early. The op list is parsed before the handle is deleted, as before.
    std::unique_ptr<TF_Library, decltype(&TF_DeleteLibraryHandle)> lib_owner(
        lib, TF_DeleteLibraryHandle);

    // Test op list.
    TF_Buffer op_list_buf = TF_GetOpList(lib);
    tensorflow::OpList op_list;
    EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length));
    ASSERT_EQ(op_list.op_size(), 1);
    EXPECT_EQ("TestCApi1", op_list.op(0).name());
  }
#endif  // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
  {
    std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> op_list_buffer(
        TF_GetAllOpList(), TF_DeleteBuffer);
    tensorflow::OpList op_list;
    // Check the parse result rather than silently continuing with a partially
    // populated OpList.
    EXPECT_TRUE(
        op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length));
    ASSERT_GE(op_list.op_size(), 1);
    const bool found =
        std::any_of(op_list.op().begin(), op_list.op().end(),
                    [](const tensorflow::OpDef& op_def) {
                      return op_def.name() == "TestCApi";
                    });
    EXPECT_TRUE(found);
  }
}
234 
TestEncodeDecode(int line,const std::vector<string> & data)235 void TestEncodeDecode(int line, const std::vector<string>& data) {
236   const int64_t n = data.size();
237   Status status;
238   for (const std::vector<int64_t>& dims :
239        std::vector<std::vector<int64_t>>{{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
240     // Create C++ Tensor
241     Tensor src(tensorflow::DT_STRING, TensorShape(dims));
242     for (int64_t i = 0; i < src.NumElements(); ++i) {
243       src.flat<tstring>()(i) = data[i];
244     }
245     TF_Tensor* dst = TF_TensorFromTensor(src, &status);
246     ASSERT_TRUE(status.ok()) << status.error_message();
247 
248     // Convert back to a C++ Tensor and ensure we get expected output.
249     Tensor output;
250     ASSERT_EQ(OkStatus(), TF_TensorToTensor(dst, &output)) << line;
251     ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
252     for (int64_t i = 0; i < src.NumElements(); ++i) {
253       ASSERT_EQ(data[i], output.flat<tstring>()(i)) << line;
254     }
255 
256     TF_DeleteTensor(dst);
257   }
258 }
259 
// String tensors survive the C API round trip for four inputs: empty, a single
// element, several elements, and a mix that includes a much longer element.
TEST(CAPI, TensorEncodeDecodeStrings) {
  TestEncodeDecode(__LINE__, {});
  TestEncodeDecode(__LINE__, {"hello"});
  TestEncodeDecode(__LINE__,
                   {"the", "quick", "brown", "fox", "jumped", "over"});

  const string long_string(1000, 'a');
  TestEncodeDecode(__LINE__, {"small", long_string, "small2"});
}
269 
// Smoke test: session options can be created and destroyed.
TEST(CAPI, SessionOptions) {
  std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>
      options(TF_NewSessionOptions(), TF_DeleteSessionOptions);
}
274 
// Running a deprecated session that was never given a graph must fail with
// INVALID_ARGUMENT and a specific message. Deleting the session afterwards
// must still succeed.
TEST(CAPI, DeprecatedSession) {
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* options = TF_NewSessionOptions();
  TF_DeprecatedSession* session = TF_NewDeprecatedSession(options, status);
  TF_DeleteSessionOptions(options);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Buffer* run_opts = TF_NewBufferFromString("", 0);
  TF_Buffer* run_meta = TF_NewBuffer();
  // Run with no inputs, no outputs and no targets.
  TF_Run(session, run_opts, /*input_names=*/nullptr, /*inputs=*/nullptr,
         /*ninputs=*/0, /*output_names=*/nullptr, /*outputs=*/nullptr,
         /*noutputs=*/0, /*target_oper_names=*/nullptr, /*ntargets=*/0,
         run_meta, status);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status)) << TF_Message(status);
  EXPECT_EQ("Session was not created with a graph before Run()!",
            string(TF_Message(status)));
  TF_DeleteBuffer(run_meta);
  TF_DeleteBuffer(run_opts);

  TF_DeleteDeprecatedSession(session, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);
}
297 
// The C API's TF_DataType enumerators must match tensorflow::DataType
// numerically, and TF_DataTypeSize must agree with tensorflow::DataTypeSize.
TEST(CAPI, DataTypeEnum) {
  struct TypePair {
    TF_DataType c_api;
    tensorflow::DataType core;
  };
  const TypePair kPairs[] = {
      {TF_FLOAT, tensorflow::DT_FLOAT},
      {TF_DOUBLE, tensorflow::DT_DOUBLE},
      {TF_INT32, tensorflow::DT_INT32},
      {TF_UINT8, tensorflow::DT_UINT8},
      {TF_INT16, tensorflow::DT_INT16},
      {TF_INT8, tensorflow::DT_INT8},
      {TF_STRING, tensorflow::DT_STRING},
      {TF_COMPLEX64, tensorflow::DT_COMPLEX64},
      {TF_INT64, tensorflow::DT_INT64},
      {TF_BOOL, tensorflow::DT_BOOL},
      {TF_QINT8, tensorflow::DT_QINT8},
      {TF_QUINT8, tensorflow::DT_QUINT8},
      {TF_QINT32, tensorflow::DT_QINT32},
      {TF_BFLOAT16, tensorflow::DT_BFLOAT16},
      {TF_QINT16, tensorflow::DT_QINT16},
      {TF_QUINT16, tensorflow::DT_QUINT16},
      {TF_UINT16, tensorflow::DT_UINT16},
      {TF_COMPLEX128, tensorflow::DT_COMPLEX128},
      {TF_HALF, tensorflow::DT_HALF},
  };
  for (const TypePair& pair : kPairs) {
    EXPECT_EQ(pair.c_api, static_cast<TF_DataType>(pair.core));
  }
  // TF_COMPLEX must equal TF_COMPLEX64.
  EXPECT_EQ(TF_COMPLEX, TF_COMPLEX64);

  EXPECT_EQ(TF_DataTypeSize(TF_DOUBLE),
            tensorflow::DataTypeSize(tensorflow::DT_DOUBLE));
  EXPECT_EQ(TF_DataTypeSize(TF_STRING),
            tensorflow::DataTypeSize(tensorflow::DT_STRING));
  // Test with invalid type; should always return 0 as documented
  EXPECT_EQ(TF_DataTypeSize(static_cast<TF_DataType>(0)), 0);
}
326 
// Each TF_Code enumerator must share its numeric value with the corresponding
// tensorflow::error code.
TEST(CAPI, StatusEnum) {
  using CoreCode = decltype(tensorflow::error::OK);
  struct CodePair {
    TF_Code c_api;
    CoreCode core;
  };
  const CodePair kPairs[] = {
      {TF_OK, tensorflow::error::OK},
      {TF_CANCELLED, tensorflow::error::CANCELLED},
      {TF_UNKNOWN, tensorflow::error::UNKNOWN},
      {TF_INVALID_ARGUMENT, tensorflow::error::INVALID_ARGUMENT},
      {TF_DEADLINE_EXCEEDED, tensorflow::error::DEADLINE_EXCEEDED},
      {TF_NOT_FOUND, tensorflow::error::NOT_FOUND},
      {TF_ALREADY_EXISTS, tensorflow::error::ALREADY_EXISTS},
      {TF_PERMISSION_DENIED, tensorflow::error::PERMISSION_DENIED},
      {TF_UNAUTHENTICATED, tensorflow::error::UNAUTHENTICATED},
      {TF_RESOURCE_EXHAUSTED, tensorflow::error::RESOURCE_EXHAUSTED},
      {TF_FAILED_PRECONDITION, tensorflow::error::FAILED_PRECONDITION},
      {TF_ABORTED, tensorflow::error::ABORTED},
      {TF_OUT_OF_RANGE, tensorflow::error::OUT_OF_RANGE},
      {TF_UNIMPLEMENTED, tensorflow::error::UNIMPLEMENTED},
      {TF_INTERNAL, tensorflow::error::INTERNAL},
      {TF_UNAVAILABLE, tensorflow::error::UNAVAILABLE},
      {TF_DATA_LOSS, tensorflow::error::DATA_LOSS},
  };
  for (const CodePair& pair : kPairs) {
    EXPECT_EQ(pair.c_api, static_cast<TF_Code>(pair.core));
  }
}
356 
// TF_GetAllOpList returns a buffer that parses as a non-empty OpList.
TEST(CAPI, GetAllOpList) {
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_GetAllOpList(), TF_DeleteBuffer);
  tensorflow::OpList ops;
  EXPECT_TRUE(ops.ParseFromArray(buffer->data, buffer->length));
  EXPECT_GT(ops.op_size(), 0);
}
364 
TEST(CAPI,SetShape)365 TEST(CAPI, SetShape) {
366   TF_Status* s = TF_NewStatus();
367   TF_Graph* graph = TF_NewGraph();
368 
369   TF_Operation* feed = Placeholder(graph, s);
370   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
371   TF_Output feed_out_0 = TF_Output{feed, 0};
372   int num_dims;
373 
374   // Fetch the shape, it should be completely unknown.
375   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
376   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
377   EXPECT_EQ(-1, num_dims);
378 
379   // Set the shape to be unknown, expect no change.
380   TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
381   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
382   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
383   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
384   EXPECT_EQ(-1, num_dims);
385 
386   // Set the shape to be 2 x Unknown
387   int64_t dims[] = {2, -1};
388   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
389   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
390 
391   // Fetch the shape and validate it is 2 by -1.
392   num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s);
393   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
394   EXPECT_EQ(2, num_dims);
395 
396   // Resize the dimension vector appropriately.
397   int64_t returned_dims[2];
398   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
399   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
400   EXPECT_EQ(dims[0], returned_dims[0]);
401   EXPECT_EQ(dims[1], returned_dims[1]);
402 
403   // Set to a new valid shape: [2, 3]
404   dims[1] = 3;
405   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
406   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
407 
408   // Fetch and see that the new value is returned.
409   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
410   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
411   EXPECT_EQ(dims[0], returned_dims[0]);
412   EXPECT_EQ(dims[1], returned_dims[1]);
413 
414   // Try to set 'unknown' with unknown rank on the shape and see that
415   // it doesn't change.
416   TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s);
417   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
418   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
419   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
420   EXPECT_EQ(2, num_dims);
421   EXPECT_EQ(2, returned_dims[0]);
422   EXPECT_EQ(3, returned_dims[1]);
423 
424   // Try to set 'unknown' with same rank on the shape and see that
425   // it doesn't change.
426   dims[0] = -1;
427   dims[1] = -1;
428   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
429   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
430   // Fetch and see that the new value is returned.
431   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
432   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
433   EXPECT_EQ(2, num_dims);
434   EXPECT_EQ(2, returned_dims[0]);
435   EXPECT_EQ(3, returned_dims[1]);
436 
437   // Try to fetch a shape with the wrong num_dims
438   TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
439   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
440 
441   // Try to set an invalid shape (cannot change 2x3 to a 2x5).
442   dims[1] = 5;
443   TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
444   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s);
445 
446   // Test for a scalar.
447   TF_Operation* three = ScalarConst(3, graph, s);
448   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
449   TF_Output three_out_0 = TF_Output{three, 0};
450 
451   num_dims = TF_GraphGetTensorNumDims(graph, three_out_0, s);
452   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
453   EXPECT_EQ(0, num_dims);
454   TF_GraphGetTensorShape(graph, three_out_0, returned_dims, num_dims, s);
455   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
456 
457   // Clean up
458   TF_DeleteGraph(graph);
459   TF_DeleteStatus(s);
460 }
461 
TEST(CAPI,Graph)462 TEST(CAPI, Graph) {
463   TF_Status* s = TF_NewStatus();
464   TF_Graph* graph = TF_NewGraph();
465 
466   // Make a placeholder operation.
467   TF_Operation* feed = Placeholder(graph, s);
468   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
469 
470   // Test TF_Operation*() query functions.
471   EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
472   EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
473   EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
474   EXPECT_EQ(1, TF_OperationNumOutputs(feed));
475   EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{feed, 0}));
476   EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
477   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
478   EXPECT_EQ(0, TF_OperationNumInputs(feed));
479   EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
480   EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
481   EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
482 
483   tensorflow::AttrValue attr_value;
484   ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
485   EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
486 
487   // Test not found errors in TF_Operation*() query functions.
488   EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
489   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
490 
491   ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
492   EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
493             string(TF_Message(s)));
494 
495   // Make a constant oper with the scalar "3".
496   TF_Operation* three = ScalarConst(3, graph, s);
497   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
498 
499   // Add oper.
500   TF_Operation* add = Add(feed, three, graph, s);
501   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
502 
503   // Test TF_Operation*() query functions.
504   EXPECT_EQ(string("add"), string(TF_OperationName(add)));
505   EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
506   EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
507   EXPECT_EQ(1, TF_OperationNumOutputs(add));
508   EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Output{add, 0}));
509   EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
510   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
511   EXPECT_EQ(2, TF_OperationNumInputs(add));
512   EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
513   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
514   EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 0}));
515   EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Input{add, 1}));
516   TF_Output add_in_0 = TF_OperationInput(TF_Input{add, 0});
517   EXPECT_EQ(feed, add_in_0.oper);
518   EXPECT_EQ(0, add_in_0.index);
519   TF_Output add_in_1 = TF_OperationInput(TF_Input{add, 1});
520   EXPECT_EQ(three, add_in_1.oper);
521   EXPECT_EQ(0, add_in_1.index);
522   EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Output{add, 0}));
523   EXPECT_EQ(0, TF_OperationNumControlInputs(add));
524   EXPECT_EQ(0, TF_OperationNumControlOutputs(add));
525 
526   ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
527   EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
528   ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
529   EXPECT_EQ(attr_value.i(), 2);
530 
531   // Placeholder oper now has a consumer.
532   ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{feed, 0}));
533   TF_Input feed_port;
534   EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Output{feed, 0}, &feed_port, 1));
535   EXPECT_EQ(add, feed_port.oper);
536   EXPECT_EQ(0, feed_port.index);
537 
538   // The scalar const oper also has a consumer.
539   ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Output{three, 0}));
540   TF_Input three_port;
541   EXPECT_EQ(1,
542             TF_OperationOutputConsumers(TF_Output{three, 0}, &three_port, 1));
543   EXPECT_EQ(add, three_port.oper);
544   EXPECT_EQ(1, three_port.index);
545 
546   // Serialize to GraphDef.
547   GraphDef graph_def;
548   ASSERT_TRUE(GetGraphDef(graph, &graph_def));
549 
550   // Validate GraphDef is what we expect.
551   bool found_placeholder = false;
552   bool found_scalar_const = false;
553   bool found_add = false;
554   for (const auto& n : graph_def.node()) {
555     if (IsPlaceholder(n)) {
556       EXPECT_FALSE(found_placeholder);
557       found_placeholder = true;
558     } else if (IsScalarConst(n, 3)) {
559       EXPECT_FALSE(found_scalar_const);
560       found_scalar_const = true;
561     } else if (IsAddN(n, 2)) {
562       EXPECT_FALSE(found_add);
563       found_add = true;
564     } else {
565       ADD_FAILURE() << "Unexpected NodeDef: " << n.DebugString();
566     }
567   }
568   EXPECT_TRUE(found_placeholder);
569   EXPECT_TRUE(found_scalar_const);
570   EXPECT_TRUE(found_add);
571 
572   // Add another oper to the graph.
573   TF_Operation* neg = Neg(add, graph, s);
574   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
575 
576   // Serialize to NodeDef.
577   NodeDef node_def;
578   ASSERT_TRUE(GetNodeDef(neg, &node_def));
579 
580   // Validate NodeDef is what we expect.
581   EXPECT_TRUE(IsNeg(node_def, "add"));
582 
583   // Serialize to GraphDef.
584   GraphDef graph_def2;
585   ASSERT_TRUE(GetGraphDef(graph, &graph_def2));
586 
587   // Compare with first GraphDef + added NodeDef.
588   NodeDef* added_node = graph_def.add_node();
589   *added_node = node_def;
590   EXPECT_EQ(graph_def.DebugString(), graph_def2.DebugString());
591 
592   // Look up some nodes by name.
593   TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
594   EXPECT_TRUE(neg == neg2);
595   NodeDef node_def2;
596   ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
597   EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
598 
599   TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
600   EXPECT_TRUE(feed == feed2);
601   ASSERT_TRUE(GetNodeDef(feed, &node_def));
602   ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
603   EXPECT_EQ(node_def.DebugString(), node_def2.DebugString());
604 
605   // Test iterating through the nodes of a graph.
606   found_placeholder = false;
607   found_scalar_const = false;
608   found_add = false;
609   bool found_neg = false;
610   size_t pos = 0;
611   TF_Operation* oper;
612   while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
613     if (oper == feed) {
614       EXPECT_FALSE(found_placeholder);
615       found_placeholder = true;
616     } else if (oper == three) {
617       EXPECT_FALSE(found_scalar_const);
618       found_scalar_const = true;
619     } else if (oper == add) {
620       EXPECT_FALSE(found_add);
621       found_add = true;
622     } else if (oper == neg) {
623       EXPECT_FALSE(found_neg);
624       found_neg = true;
625     } else {
626       ASSERT_TRUE(GetNodeDef(oper, &node_def));
627       ADD_FAILURE() << "Unexpected Node: " << node_def.DebugString();
628     }
629   }
630   EXPECT_TRUE(found_placeholder);
631   EXPECT_TRUE(found_scalar_const);
632   EXPECT_TRUE(found_add);
633   EXPECT_TRUE(found_neg);
634 
635   // Clean up
636   TF_DeleteGraph(graph);
637   TF_DeleteStatus(s);
638 }
639 
// TF_UpdateEdge rewires an existing input. neg's input 0 is moved from "add"
// to "one", and the change must be visible in neg's NodeDef.
TEST(CAPI, UpdateEdge) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Make two scalar constants.
  TF_Operation* one = ScalarConst(1, graph, s, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* two = ScalarConst(2, graph, s, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add oper.
  TF_Operation* add = Add(one, two, graph, s, "add");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Add another oper to the graph.
  TF_Operation* neg = Neg(add, graph, s, "neg");
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  NodeDef node_def_neg;
  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("add"), node_def_neg.input(0));

  // Update the edge feeding neg. Check the status so that a failure here is
  // reported directly rather than as a confusing NodeDef mismatch below.
  TF_UpdateEdge(graph, TF_Output{one, 0}, TF_Input{neg, 0}, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  ASSERT_TRUE(GetNodeDef(neg, &node_def_neg));
  EXPECT_EQ(string("one:0"), node_def_neg.input(0));

  // Clean up
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
673 
674 /*
675 TODO(skyewm): this test currently DCHECKs, change to bad status
676 
677 TEST(CAPI, InputFromDifferentGraphError) {
678   TF_Status* s = TF_NewStatus();
679   TF_Graph* g1 = TF_NewGraph();
680   TF_Graph* g2 = TF_NewGraph();
681 
682   TF_Operation* feed = Placeholder(g1, s);
683   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
684 
685   // Attempt to create node in g2 with input from g1
686   Neg(feed, g2, s);
687   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
688   EXPECT_STREQ("foo", TF_Message(s));
689 
690   TF_DeleteGraph(g1);
691   TF_DeleteGraph(g2);
692   TF_DeleteStatus(s);
693 }
694 */
695 
TEST(CAPI,ImportGraphDef)696 TEST(CAPI, ImportGraphDef) {
697   TF_Status* s = TF_NewStatus();
698   TF_Graph* graph = TF_NewGraph();
699 
700   // Create a simple graph.
701   Placeholder(graph, s);
702   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
703   ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
704   TF_Operation* oper = ScalarConst(3, graph, s);
705   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
706   ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
707   Neg(oper, graph, s);
708   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
709   ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);
710 
711   // Export to a GraphDef.
712   TF_Buffer* graph_def = TF_NewBuffer();
713   TF_GraphToGraphDef(graph, graph_def, s);
714   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
715 
716   // Import it, with a prefix, in a fresh graph.
717   TF_DeleteGraph(graph);
718   graph = TF_NewGraph();
719   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
720   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
721   TF_GraphImportGraphDef(graph, graph_def, opts, s);
722   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
723 
724   TF_Operation* scalar = TF_GraphOperationByName(graph, "imported/scalar");
725   TF_Operation* feed = TF_GraphOperationByName(graph, "imported/feed");
726   TF_Operation* neg = TF_GraphOperationByName(graph, "imported/neg");
727   ASSERT_TRUE(scalar != nullptr);
728   ASSERT_TRUE(feed != nullptr);
729   ASSERT_TRUE(neg != nullptr);
730 
731   // Test basic structure of the imported graph.
732   EXPECT_EQ(0, TF_OperationNumInputs(scalar));
733   EXPECT_EQ(0, TF_OperationNumInputs(feed));
734   ASSERT_EQ(1, TF_OperationNumInputs(neg));
735   TF_Output neg_input = TF_OperationInput({neg, 0});
736   EXPECT_EQ(scalar, neg_input.oper);
737   EXPECT_EQ(0, neg_input.index);
738 
739   // Test that we can't see control edges involving the source and sink nodes.
740   TF_Operation* control_ops[100];
741   EXPECT_EQ(0, TF_OperationNumControlInputs(scalar));
742   EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100));
743   EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar));
744   EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100));
745 
746   EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
747   EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100));
748   EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
749   EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100));
750 
751   EXPECT_EQ(0, TF_OperationNumControlInputs(neg));
752   EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100));
753   EXPECT_EQ(0, TF_OperationNumControlOutputs(neg));
754   EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100));
755 
756   // Import it again, with an input mapping, return outputs, and a return
757   // operation, into the same graph.
758   TF_DeleteImportGraphDefOptions(opts);
759   opts = TF_NewImportGraphDefOptions();
760   TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
761   TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
762   TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
763   TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
764   EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
765   TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
766   EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts));
767   TF_ImportGraphDefResults* results =
768       TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
769   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
770 
771   TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar");
772   TF_Operation* feed2 = TF_GraphOperationByName(graph, "imported2/feed");
773   TF_Operation* neg2 = TF_GraphOperationByName(graph, "imported2/neg");
774   ASSERT_TRUE(scalar2 != nullptr);
775   ASSERT_TRUE(feed2 != nullptr);
776   ASSERT_TRUE(neg2 != nullptr);
777 
778   // Check input mapping
779   neg_input = TF_OperationInput({neg, 0});
780   EXPECT_EQ(scalar, neg_input.oper);
781   EXPECT_EQ(0, neg_input.index);
782 
783   // Check return outputs
784   TF_Output* return_outputs;
785   int num_return_outputs;
786   TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs,
787                                         &return_outputs);
788   ASSERT_EQ(2, num_return_outputs);
789   EXPECT_EQ(feed2, return_outputs[0].oper);
790   EXPECT_EQ(0, return_outputs[0].index);
791   EXPECT_EQ(scalar, return_outputs[1].oper);  // remapped
792   EXPECT_EQ(0, return_outputs[1].index);
793 
794   // Check return operation
795   TF_Operation** return_opers;
796   int num_return_opers;
797   TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers,
798                                            &return_opers);
799   ASSERT_EQ(1, num_return_opers);
800   EXPECT_EQ(scalar2, return_opers[0]);  // not remapped
801 
802   TF_DeleteImportGraphDefResults(results);
803 
804   // Import again, with control dependencies, into the same graph.
805   TF_DeleteImportGraphDefOptions(opts);
806   opts = TF_NewImportGraphDefOptions();
807   TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
808   TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
809   TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
810   TF_GraphImportGraphDef(graph, graph_def, opts, s);
811   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
812 
813   TF_Operation* scalar3 = TF_GraphOperationByName(graph, "imported3/scalar");
814   TF_Operation* feed3 = TF_GraphOperationByName(graph, "imported3/feed");
815   TF_Operation* neg3 = TF_GraphOperationByName(graph, "imported3/neg");
816   ASSERT_TRUE(scalar3 != nullptr);
817   ASSERT_TRUE(feed3 != nullptr);
818   ASSERT_TRUE(neg3 != nullptr);
819 
820   // Check that newly-imported scalar and feed have control deps (neg3 will
821   // inherit them from input)
822   TF_Operation* control_inputs[100];
823   int num_control_inputs = TF_OperationGetControlInputs(
824       scalar3, control_inputs, TF_OperationNumControlInputs(scalar3));
825   ASSERT_EQ(2, num_control_inputs);
826   EXPECT_EQ(feed, control_inputs[0]);
827   EXPECT_EQ(feed2, control_inputs[1]);
828 
829   num_control_inputs = TF_OperationGetControlInputs(
830       feed3, control_inputs, TF_OperationNumControlInputs(feed3));
831   ASSERT_EQ(2, num_control_inputs);
832   EXPECT_EQ(feed, control_inputs[0]);
833   EXPECT_EQ(feed2, control_inputs[1]);
834 
835   // Export to a graph def so we can import a graph with control dependencies
836   TF_DeleteBuffer(graph_def);
837   graph_def = TF_NewBuffer();
838   TF_GraphToGraphDef(graph, graph_def, s);
839   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
840 
841   // Import again, with remapped control dependency, into the same graph
842   TF_DeleteImportGraphDefOptions(opts);
843   opts = TF_NewImportGraphDefOptions();
844   TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
845   TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
846   TF_GraphImportGraphDef(graph, graph_def, opts, s);
847   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
848 
849   TF_Operation* scalar4 =
850       TF_GraphOperationByName(graph, "imported4/imported3/scalar");
851   TF_Operation* feed4 =
852       TF_GraphOperationByName(graph, "imported4/imported2/feed");
853 
854   // Check that imported `imported3/scalar` has remapped control dep from
855   // original graph and imported control dep
856   num_control_inputs = TF_OperationGetControlInputs(
857       scalar4, control_inputs, TF_OperationNumControlInputs(scalar4));
858   ASSERT_EQ(2, num_control_inputs);
859   EXPECT_EQ(feed, control_inputs[0]);
860   EXPECT_EQ(feed4, control_inputs[1]);
861 
862   TF_DeleteImportGraphDefOptions(opts);
863   TF_DeleteBuffer(graph_def);
864 
865   // Can add nodes to the imported graph without trouble.
866   Add(feed, scalar, graph, s);
867   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
868 
869   TF_DeleteGraph(graph);
870   TF_DeleteStatus(s);
871 }
872 
// Round-trips a small graph through a GraphDef and checks that
// TF_GraphImportGraphDefWithReturnOutputs fills the caller's array with the
// requested outputs, in request order.
TEST(CAPI, ImportGraphDef_WithReturnOutputs) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Source graph: "feed" (placeholder), "scalar" (constant 3) and "neg"
  // (negation of scalar).
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* three = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(three, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Serialize the source graph, then discard it.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);

  // Re-import into an empty graph (no prefix), requesting feed:0 and
  // scalar:0 as return outputs.
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
  TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
  EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts));
  TF_Output requested[2];
  TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, requested,
                                          2, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* imported_scalar = TF_GraphOperationByName(graph, "scalar");
  TF_Operation* imported_feed = TF_GraphOperationByName(graph, "feed");
  TF_Operation* imported_neg = TF_GraphOperationByName(graph, "neg");
  ASSERT_TRUE(imported_scalar != nullptr);
  ASSERT_TRUE(imported_feed != nullptr);
  ASSERT_TRUE(imported_neg != nullptr);

  // The returned outputs must follow the order in which they were requested.
  EXPECT_EQ(imported_feed, requested[0].oper);
  EXPECT_EQ(0, requested[0].index);
  EXPECT_EQ(imported_scalar, requested[1].oper);
  EXPECT_EQ(0, requested[1].index);

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
923 
// Checks that TF_ImportGraphDefResultsMissingUnusedInputMappings reports
// exactly the input mappings whose source tensor does not exist in the
// imported GraphDef.
TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Create a graph with three nodes: "feed" (placeholder), "scalar" (the
  // constant 3) and "neg" (negation of scalar).
  Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr);
  TF_Operation* oper = ScalarConst(3, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr);
  Neg(oper, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr);

  // Export to a GraphDef.
  TF_Buffer* graph_def = TF_NewBuffer();
  TF_GraphToGraphDef(graph, graph_def, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Import it (no prefix) into a fresh graph. Its "scalar" op becomes the
  // destination tensor of the input mappings below.
  TF_DeleteGraph(graph);
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar");
  ASSERT_TRUE(scalar != nullptr);

  // Import it again into the *same* graph under the "imported" prefix. One
  // mapping has a source that exists in the GraphDef ("scalar"); the other
  // does not ("fake").
  TF_DeleteImportGraphDefOptions(opts);
  opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0});
  TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0});
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Only the mapping whose source ("fake") is absent must be reported.
  int num_unused_input_mappings;
  const char** src_names;
  int* src_indexes;
  TF_ImportGraphDefResultsMissingUnusedInputMappings(
      results, &num_unused_input_mappings, &src_names, &src_indexes);
  ASSERT_EQ(1, num_unused_input_mappings);
  EXPECT_EQ(string("fake"), string(src_names[0]));
  EXPECT_EQ(0, src_indexes[0]);

  TF_DeleteImportGraphDefResults(results);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(graph_def);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
979 
// Runs feed + 2 through a CSession. It then grows the graph with a Neg op
// while the session is still open and runs the new op.
TEST(CAPI, Session) {
  TF_Status* s = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Graph under test: add = feed + 2.
  TF_Operation* feed = Placeholder(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* two = ScalarConst(2, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Operation* add = Add(feed, two, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  CSession csession(graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // First run: feed = 3, so add must be the int32 scalar 5.
  csession.SetInputs({{feed, Int32Tensor(3)}});
  csession.SetOutputs({add});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_Tensor* result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(3 + 2, *static_cast<int32*>(TF_TensorData(result)));

  // Grow the graph after the session was created: neg = -(feed + 2).
  TF_Operation* neg = Neg(add, graph, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

  // Second run: feed = 7, so neg must be the int32 scalar -9.
  csession.SetInputs({{feed, Int32Tensor(7)}});
  csession.SetOutputs({neg});
  csession.Run(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  result = csession.output_tensor(0);
  ASSERT_TRUE(result != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(result));
  EXPECT_EQ(0, TF_NumDims(result));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(result));
  EXPECT_EQ(-(7 + 2), *static_cast<int32*>(TF_TensorData(result)));

  csession.CloseAndDelete(s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1036 
1037 // If `device` is non-empty, run Min op on that device.
1038 // Otherwise run it on the default device (CPU).
RunMinTest(const string & device,bool use_XLA)1039 void RunMinTest(const string& device, bool use_XLA) {
1040   TF_Status* s = TF_NewStatus();
1041   TF_Graph* graph = TF_NewGraph();
1042 
1043   // Make a placeholder operation.
1044   TF_Operation* feed = Placeholder(graph, s);
1045   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1046 
1047   // Make a constant operation with the scalar "0", for axis.
1048   TF_Operation* one = ScalarConst(0, graph, s);
1049   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1050 
1051   // Create a session for this graph.
1052   CSession csession(graph, s, use_XLA);
1053   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1054 
1055   if (!device.empty()) {
1056     LOG(INFO) << "Setting op Min on device " << device;
1057   }
1058   TF_Operation* min = MinWithDevice(feed, one, graph, device, s);
1059   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1060 
1061   // Run the graph.
1062   csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}});
1063   csession.SetOutputs({min});
1064   csession.Run(s);
1065   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1066   TF_Tensor* out = csession.output_tensor(0);
1067   ASSERT_TRUE(out != nullptr);
1068   EXPECT_EQ(TF_INT32, TF_TensorType(out));
1069   EXPECT_EQ(0, TF_NumDims(out));  // scalar
1070   ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
1071   int32* output_contents = static_cast<int32*>(TF_TensorData(out));
1072   EXPECT_EQ(2, *output_contents);
1073 
1074   // Clean up
1075   csession.CloseAndDelete(s);
1076   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1077   TF_DeleteGraph(graph);
1078   TF_DeleteStatus(s);
1079 }
1080 
TEST(CAPI, Session_Min_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/false); }

TEST(CAPI, Session_Min_XLA_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/true); }

TEST(CAPI, Session_Min_GPU) {
  const string gpu_device = GPUDeviceName();
  // Without a GPU there is nothing to run. Report the test as skipped rather
  // than letting it silently pass.
  if (gpu_device.empty()) {
    GTEST_SKIP() << "No GPU device available";
  }

  RunMinTest(gpu_device, /*use_XLA=*/false);
}

TEST(CAPI, Session_Min_XLA_GPU) {
  const string gpu_device = GPUDeviceName();
  // Without a GPU there is nothing to run. Report the test as skipped rather
  // than letting it silently pass.
  if (gpu_device.empty()) {
    GTEST_SKIP() << "No GPU device available";
  }

  RunMinTest(gpu_device, /*use_XLA=*/true);
}
1100 
TEST(CAPI,SessionPRun)1101 TEST(CAPI, SessionPRun) {
1102   TF_Status* s = TF_NewStatus();
1103   TF_Graph* graph = TF_NewGraph();
1104 
1105   // Construct the graph: A + 2 + B
1106   TF_Operation* a = Placeholder(graph, s, "A");
1107   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1108 
1109   TF_Operation* b = Placeholder(graph, s, "B");
1110   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1111 
1112   TF_Operation* two = ScalarConst(2, graph, s);
1113   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1114 
1115   TF_Operation* plus2 = Add(a, two, graph, s, "plus2");
1116   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1117 
1118   TF_Operation* plusB = Add(plus2, b, graph, s, "plusB");
1119   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1120 
1121   // Setup a session and a partial run handle.  The partial run will allow
1122   // computation of A + 2 + B in two phases (calls to TF_SessionPRun):
1123   // 1. Feed A and get (A+2)
1124   // 2. Feed B and get (A+2)+B
1125   TF_SessionOptions* opts = TF_NewSessionOptions();
1126   TF_Session* sess = TF_NewSession(graph, opts, s);
1127   TF_DeleteSessionOptions(opts);
1128 
1129   TF_Output feeds[] = {TF_Output{a, 0}, TF_Output{b, 0}};
1130   TF_Output fetches[] = {TF_Output{plus2, 0}, TF_Output{plusB, 0}};
1131 
1132   const char* handle = nullptr;
1133   TF_SessionPRunSetup(sess, feeds, TF_ARRAYSIZE(feeds), fetches,
1134                       TF_ARRAYSIZE(fetches), nullptr, 0, &handle, s);
1135   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1136 
1137   // Feed A and fetch A + 2.
1138   TF_Output feeds1[] = {TF_Output{a, 0}};
1139   TF_Output fetches1[] = {TF_Output{plus2, 0}};
1140   TF_Tensor* feedValues1[] = {Int32Tensor(1)};
1141   TF_Tensor* fetchValues1[1];
1142   TF_SessionPRun(sess, handle, feeds1, feedValues1, 1, fetches1, fetchValues1,
1143                  1, nullptr, 0, s);
1144   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1145   EXPECT_EQ(3, *(static_cast<int32*>(TF_TensorData(fetchValues1[0]))));
1146   TF_DeleteTensor(feedValues1[0]);
1147   TF_DeleteTensor(fetchValues1[0]);
1148 
1149   // Feed B and fetch (A + 2) + B.
1150   TF_Output feeds2[] = {TF_Output{b, 0}};
1151   TF_Output fetches2[] = {TF_Output{plusB, 0}};
1152   TF_Tensor* feedValues2[] = {Int32Tensor(4)};
1153   TF_Tensor* fetchValues2[1];
1154   TF_SessionPRun(sess, handle, feeds2, feedValues2, 1, fetches2, fetchValues2,
1155                  1, nullptr, 0, s);
1156   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1157   EXPECT_EQ(7, *(static_cast<int32*>(TF_TensorData(fetchValues2[0]))));
1158   TF_DeleteTensor(feedValues2[0]);
1159   TF_DeleteTensor(fetchValues2[0]);
1160 
1161   // Clean up.
1162   TF_DeletePRunHandle(handle);
1163   TF_DeleteSession(sess, s);
1164   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1165   TF_DeleteGraph(graph);
1166   TF_DeleteStatus(s);
1167 }
1168 
// Adding a shape-[2] tensor to a shape-[3] tensor cannot be shape-inferred.
// Finishing such an op must therefore fail and yield no operation.
TEST(CAPI, ShapeInferenceError) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  // Both constants are built from this one byte array. Only the declared
  // dimensions differ (presumably Int8Tensor reads product(dims) bytes).
  const char data[] = {1, 2, 3};

  const int64_t two_dims[] = {2};
  unique_tensor_ptr two_elems(
      Int8Tensor(two_dims, TF_ARRAYSIZE(two_dims), data), TF_DeleteTensor);
  TF_Operation* vec2 = Const(two_elems.get(), graph, status, "vec2");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  const int64_t three_dims[] = {3};
  unique_tensor_ptr three_elems(
      Int8Tensor(three_dims, TF_ARRAYSIZE(three_dims), data), TF_DeleteTensor);
  TF_Operation* vec3 = Const(three_elems.get(), graph, status, "vec3");
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // The incompatible Add must be rejected with a non-OK status and no op.
  TF_Operation* sum = AddNoCheck(vec2, vec3, graph, status);
  ASSERT_NE(TF_OK, TF_GetCode(status));
  ASSERT_TRUE(sum == nullptr);

  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1197 
// Checks TF_GraphGetOpDef in two cases:
// - a registered op returns the registry's serialized OpDef;
// - an unregistered op returns NOT_FOUND with a descriptive message.
TEST(CAPI, GetOpDef) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();
  TF_Buffer* buffer = TF_NewBuffer();

  // Registered op: the buffer must hold exactly the serialized OpDef from the
  // global op registry.
  TF_GraphGetOpDef(graph, "Add", buffer, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  const OpDef* expected_op_def;
  TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
  string expected_serialized;
  expected_op_def->SerializeToString(&expected_serialized);
  string actual_string(reinterpret_cast<const char*>(buffer->data),
                       buffer->length);
  EXPECT_EQ(expected_serialized, actual_string);

  // Unregistered op: NOT_FOUND with a descriptive message.
  // NOTE(review): `buffer` still holds the "Add" OpDef here. This appears to
  // rely on the lookup failing before the buffer is written; confirm.
  TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status);
  EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status));
  ExpectHasSubstr(TF_Message(status),
                  "Op type not registered 'MyFakeOp' in binary");

  TF_DeleteBuffer(buffer);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
}
1222 
// Exposes each string in `v` as parallel (data pointer, byte length) arrays.
// This is the layout TF_SetAttrStringList takes.
// The pointers alias the storage of `v`'s elements, so `v` must stay alive
// and unmodified for as long as *ptrs is used.
void StringVectorToArrays(const std::vector<string>& v,
                          std::unique_ptr<const void*[]>* ptrs,
                          std::unique_ptr<size_t[]>* lens) {
  const size_t count = v.size();
  *ptrs = std::make_unique<const void*[]>(count);
  *lens = std::make_unique<size_t[]>(count);
  size_t slot = 0;
  for (const string& item : v) {
    (*ptrs)[slot] = item.data();
    (*lens)[slot] = item.size();
    ++slot;
  }
}
1233 
1234 class CApiColocationTest : public ::testing::Test {
1235  protected:
CApiColocationTest()1236   CApiColocationTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
1237 
SetUp()1238   void SetUp() override {
1239     feed1_ = Placeholder(graph_, s_, "feed1");
1240     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1241 
1242     feed2_ = Placeholder(graph_, s_, "feed2");
1243     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1244 
1245     constant_ = ScalarConst(10, graph_, s_);
1246     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1247 
1248     desc_ = TF_NewOperation(graph_, "AddN", "add");
1249     TF_Output inputs[] = {{feed1_, 0}, {constant_, 0}};
1250     TF_AddInputList(desc_, inputs, TF_ARRAYSIZE(inputs));
1251   }
1252 
~CApiColocationTest()1253   ~CApiColocationTest() override {
1254     TF_DeleteGraph(graph_);
1255     TF_DeleteStatus(s_);
1256   }
1257 
SetViaStringList(TF_OperationDescription * desc,const std::vector<string> & list)1258   void SetViaStringList(TF_OperationDescription* desc,
1259                         const std::vector<string>& list) {
1260     std::unique_ptr<const void*[]> list_ptrs;
1261     std::unique_ptr<size_t[]> list_lens;
1262     StringVectorToArrays(list, &list_ptrs, &list_lens);
1263     TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
1264                          list_lens.get(), list.size());
1265   }
1266 
SetViaProto(TF_OperationDescription * desc,const std::vector<string> & list)1267   void SetViaProto(TF_OperationDescription* desc,
1268                    const std::vector<string>& list) {
1269     tensorflow::AttrValue attr;
1270     for (const string& v : list) {
1271       attr.mutable_list()->add_s(v);
1272     }
1273     string bytes;
1274     attr.SerializeToString(&bytes);
1275     TF_SetAttrValueProto(desc, tensorflow::kColocationAttrName, bytes.data(),
1276                          bytes.size(), s_);
1277     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1278   }
1279 
VerifyCollocation(TF_Operation * op,const std::vector<string> & expected)1280   void VerifyCollocation(TF_Operation* op,
1281                          const std::vector<string>& expected) {
1282     TF_AttrMetadata m =
1283         TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
1284     if (expected.empty()) {
1285       ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
1286       EXPECT_EQ("Operation 'add' has no attr named '_class'.",
1287                 string(TF_Message(s_)));
1288       return;
1289     }
1290     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1291     EXPECT_EQ(1, m.is_list);
1292     EXPECT_EQ(expected.size(), m.list_size);
1293     EXPECT_EQ(TF_ATTR_STRING, m.type);
1294     std::vector<void*> values(expected.size());
1295     std::vector<size_t> lens(expected.size());
1296     std::unique_ptr<char[]> storage(new char[m.total_size]);
1297     TF_OperationGetAttrStringList(op, tensorflow::kColocationAttrName,
1298                                   values.data(), lens.data(), expected.size(),
1299                                   storage.get(), m.total_size, s_);
1300     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1301     for (int i = 0; i < expected.size(); ++i) {
1302       EXPECT_EQ(expected[i],
1303                 string(static_cast<const char*>(values[i]), lens[i]));
1304     }
1305   }
1306 
FinishAndVerify(TF_OperationDescription * desc,const std::vector<string> & expected)1307   void FinishAndVerify(TF_OperationDescription* desc,
1308                        const std::vector<string>& expected) {
1309     TF_Operation* op = TF_FinishOperation(desc_, s_);
1310     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1311     VerifyCollocation(op, expected);
1312   }
1313 
1314   TF_Status* s_;
1315   TF_Graph* graph_;
1316   TF_Operation* feed1_;
1317   TF_Operation* feed2_;
1318   TF_Operation* constant_;
1319   TF_OperationDescription* desc_;
1320 };
1321 
TEST_F(CApiColocationTest,ColocateWith)1322 TEST_F(CApiColocationTest, ColocateWith) {
1323   TF_ColocateWith(desc_, feed1_);
1324   FinishAndVerify(desc_, {"loc:@feed1"});
1325 }
1326 
TEST_F(CApiColocationTest,StringList)1327 TEST_F(CApiColocationTest, StringList) {
1328   SetViaStringList(desc_, {"loc:@feed1"});
1329   FinishAndVerify(desc_, {"loc:@feed1"});
1330 }
1331 
TEST_F(CApiColocationTest,Proto)1332 TEST_F(CApiColocationTest, Proto) {
1333   SetViaProto(desc_, {"loc:@feed1"});
1334   FinishAndVerify(desc_, {"loc:@feed1"});
1335 }
1336 
TEST_F(CApiColocationTest,ColocateWith_StringList)1337 TEST_F(CApiColocationTest, ColocateWith_StringList) {
1338   TF_ColocateWith(desc_, feed1_);
1339   SetViaStringList(desc_, {"loc:@feed2"});
1340   FinishAndVerify(desc_, {"loc:@feed2"});
1341 }
1342 
TEST_F(CApiColocationTest,ColocateWith_Proto)1343 TEST_F(CApiColocationTest, ColocateWith_Proto) {
1344   TF_ColocateWith(desc_, feed1_);
1345   SetViaProto(desc_, {"loc:@feed2"});
1346   FinishAndVerify(desc_, {"loc:@feed2"});
1347 }
1348 
TEST_F(CApiColocationTest,StringList_ColocateWith)1349 TEST_F(CApiColocationTest, StringList_ColocateWith) {
1350   SetViaStringList(desc_, {"loc:@feed2"});
1351   TF_ColocateWith(desc_, feed1_);
1352   FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1353 }
1354 
TEST_F(CApiColocationTest,Proto_ColocateWith)1355 TEST_F(CApiColocationTest, Proto_ColocateWith) {
1356   SetViaProto(desc_, {"loc:@feed2"});
1357   TF_ColocateWith(desc_, feed1_);
1358   FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1359 }
1360 
TEST_F(CApiColocationTest,ColocateWith_ColocateWith)1361 TEST_F(CApiColocationTest, ColocateWith_ColocateWith) {
1362   TF_ColocateWith(desc_, feed1_);
1363   TF_ColocateWith(desc_, feed2_);
1364   FinishAndVerify(desc_, {"loc:@feed1", "loc:@feed2"});
1365 }
1366 
TEST_F(CApiColocationTest,Proto_StringList)1367 TEST_F(CApiColocationTest, Proto_StringList) {
1368   SetViaProto(desc_, {"loc:@feed1"});
1369   SetViaStringList(desc_, {"loc:@feed2"});
1370   FinishAndVerify(desc_, {"loc:@feed2"});
1371 }
1372 
TEST_F(CApiColocationTest,StringList_Proto)1373 TEST_F(CApiColocationTest, StringList_Proto) {
1374   SetViaStringList(desc_, {"loc:@feed1"});
1375   SetViaProto(desc_, {"loc:@feed2"});
1376   FinishAndVerify(desc_, {"loc:@feed2"});
1377 }
1378 
TEST_F(CApiColocationTest,ClearViaStringList)1379 TEST_F(CApiColocationTest, ClearViaStringList) {
1380   TF_ColocateWith(desc_, feed1_);
1381   SetViaStringList(desc_, {});
1382   FinishAndVerify(desc_, {});
1383 }
1384 
TEST_F(CApiColocationTest,ClearViaProto)1385 TEST_F(CApiColocationTest, ClearViaProto) {
1386   TF_ColocateWith(desc_, feed1_);
1387   SetViaProto(desc_, {});
1388   FinishAndVerify(desc_, {});
1389 }
1390 
TEST(CAPI,SavedModel)1391 TEST(CAPI, SavedModel) {
1392   // Load the saved model.
1393   const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
1394       tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
1395                                "half_plus_two", "00000123"));
1396   TF_SessionOptions* opt = TF_NewSessionOptions();
1397   TF_Buffer* run_options = TF_NewBufferFromString("", 0);
1398   TF_Buffer* metagraph = TF_NewBuffer();
1399   TF_Status* s = TF_NewStatus();
1400   const char* tags[] = {tensorflow::kSavedModelTagServe};
1401   TF_Graph* graph = TF_NewGraph();
1402   TF_Session* session = TF_LoadSessionFromSavedModel(
1403       opt, run_options, saved_model_dir.c_str(), tags, 1, graph, metagraph, s);
1404   TF_DeleteBuffer(run_options);
1405   TF_DeleteSessionOptions(opt);
1406   tensorflow::MetaGraphDef metagraph_def;
1407   metagraph_def.ParseFromArray(metagraph->data, metagraph->length);
1408   TF_DeleteBuffer(metagraph);
1409 
1410   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1411   CSession csession(session);
1412 
1413   // Retrieve the regression signature from meta graph def.
1414   const auto signature_def_map = metagraph_def.signature_def();
1415   const auto signature_def = signature_def_map.at("regress_x_to_y");
1416 
1417   const string input_name =
1418       signature_def.inputs().at(tensorflow::kRegressInputs).name();
1419   const string output_name =
1420       signature_def.outputs().at(tensorflow::kRegressOutputs).name();
1421 
1422   // Write {0, 1, 2, 3} as tensorflow::Example inputs.
1423   Tensor input(tensorflow::DT_STRING, TensorShape({4}));
1424   for (int64_t i = 0; i < input.NumElements(); ++i) {
1425     tensorflow::Example example;
1426     auto* feature_map = example.mutable_features()->mutable_feature();
1427     (*feature_map)["x"].mutable_float_list()->add_value(i);
1428     input.flat<tstring>()(i) = example.SerializeAsString();
1429   }
1430 
1431   const tensorflow::string input_op_name(
1432       tensorflow::ParseTensorName(input_name).first);
1433   TF_Operation* input_op =
1434       TF_GraphOperationByName(graph, input_op_name.c_str());
1435   ASSERT_TRUE(input_op != nullptr);
1436   Status status;
1437   csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
1438   ASSERT_TRUE(status.ok()) << status.error_message();
1439 
1440   const tensorflow::string output_op_name(
1441       tensorflow::ParseTensorName(output_name).first);
1442   TF_Operation* output_op =
1443       TF_GraphOperationByName(graph, output_op_name.c_str());
1444   ASSERT_TRUE(output_op != nullptr);
1445   csession.SetOutputs({output_op});
1446   csession.Run(s);
1447   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1448 
1449   TF_Tensor* out = csession.output_tensor(0);
1450   ASSERT_TRUE(out != nullptr);
1451   EXPECT_EQ(TF_FLOAT, TF_TensorType(out));
1452   EXPECT_EQ(2, TF_NumDims(out));
1453   EXPECT_EQ(4, TF_Dim(out, 0));
1454   EXPECT_EQ(1, TF_Dim(out, 1));
1455   float* values = static_cast<float*>(TF_TensorData(out));
1456   // These values are defined to be (input / 2) + 2.
1457   EXPECT_EQ(2, values[0]);
1458   EXPECT_EQ(2.5, values[1]);
1459   EXPECT_EQ(3, values[2]);
1460   EXPECT_EQ(3.5, values[3]);
1461 
1462   csession.CloseAndDelete(s);
1463   EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1464   TF_DeleteGraph(graph);
1465   TF_DeleteStatus(s);
1466 }
1467 
// TF_LoadSessionFromSavedModel must accept null run_options and
// meta_graph_def. The resulting session must then close and delete cleanly.
TEST(CAPI, SavedModelNullArgsAreValid) {
  const string saved_model_dir = tensorflow::GetDataDependencyFilepath(
      tensorflow::io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
                               "half_plus_two", "00000123"));
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Status* s = TF_NewStatus();
  const char* tags[] = {tensorflow::kSavedModelTagServe};
  TF_Graph* graph = TF_NewGraph();
  // NULL run_options and meta_graph_def should work.
  TF_Session* session = TF_LoadSessionFromSavedModel(
      opt, nullptr, saved_model_dir.c_str(), tags, 1, graph, nullptr, s);
  TF_DeleteSessionOptions(opt);
  // Close/Delete below operate on `session`, so a failed load must end the
  // test rather than be passed a possibly-null session.
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_TRUE(session != nullptr);
  TF_CloseSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_DeleteGraph(graph);
  TF_DeleteStatus(s);
}
1488 
// Every TF_Delete* entry point must treat nullptr as a no-op. The two session
// deleters also take a status, and that status must stay TF_OK.
TEST(CAPI, DeletingNullPointerIsSafe) {
  TF_Status* const st = TF_NewStatus();
  const auto expect_ok = [st] {
    EXPECT_EQ(TF_OK, TF_GetCode(st)) << TF_Message(st);
  };

  // Deleters that take only the object.
  TF_DeleteStatus(nullptr);
  TF_DeleteBuffer(nullptr);
  TF_DeleteTensor(nullptr);
  TF_DeleteSessionOptions(nullptr);
  TF_DeleteGraph(nullptr);
  TF_DeleteImportGraphDefOptions(nullptr);
  TF_DeleteImportGraphDefResults(nullptr);
  TF_DeleteFunction(nullptr);

  // Session deleters report through `st`.
  TF_DeleteSession(nullptr, st);
  expect_ok();
  TF_DeletePRunHandle(nullptr);
  TF_DeleteDeprecatedSession(nullptr, st);
  expect_ok();

  TF_DeleteDeviceList(nullptr);
  TF_DeleteLibraryHandle(nullptr);
  TF_DeleteApiDefMap(nullptr);

  TF_DeleteStatus(st);
}
1511 
// TF_TensorBitcastFrom(a, ..., b, new_dims, ...) makes `b` view `a`'s buffer
// with shape `new_dims`. Afterwards both tensors report the same element count
// and byte size, and a write through either one is visible through the other.
TEST(CAPI, TestBitcastFrom_Reshape) {
  const size_t elem_size = TF_DataTypeSize(TF_UINT64);
  int64_t dims[] = {2, 3};
  TF_Tensor* a = TF_AllocateTensor(TF_UINT64, dims, 2, 6 * elem_size);
  TF_Tensor* b = TF_AllocateTensor(TF_UINT64, nullptr, 0, elem_size);
  // Both tensors are dereferenced below, so a failed allocation is fatal.
  ASSERT_NE(a, nullptr);
  ASSERT_NE(b, nullptr);

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(1, TF_TensorElementCount(b));
  EXPECT_EQ(6 * elem_size, TF_TensorByteSize(a));
  EXPECT_EQ(elem_size, TF_TensorByteSize(b));

  int64_t new_dims[] = {3, 2};
  TF_Status* status = TF_NewStatus();
  TF_TensorBitcastFrom(a, TF_UINT64, b, new_dims, 2, status);
  // Capture the result before freeing the status so a failure neither leaks
  // it nor loses the error message.
  const TF_Code code = TF_GetCode(status);
  const string message = TF_Message(status);
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, code) << message;

  EXPECT_EQ(6, TF_TensorElementCount(a));
  EXPECT_EQ(6, TF_TensorElementCount(b));
  EXPECT_EQ(6 * elem_size, TF_TensorByteSize(a));
  EXPECT_EQ(6 * elem_size, TF_TensorByteSize(b));

  // Check that a write to one tensor shows up in the other. The element type
  // is TF_UINT64, so access the data as uint64_t.
  uint64_t* a_data = static_cast<uint64_t*>(TF_TensorData(a));
  uint64_t* b_data = static_cast<uint64_t*>(TF_TensorData(b));
  *a_data = 4;
  EXPECT_EQ(uint64_t{4}, *b_data);
  *b_data = 6;
  EXPECT_EQ(uint64_t{6}, *a_data);

  TF_DeleteTensor(a);
  TF_DeleteTensor(b);
}
1546 
// TF_TensorFromProto must overwrite an existing TF_Tensor with the dtype,
// shape and contents of a serialized TensorProto. Here the target tensor
// initially wraps caller-allocated memory released through `Deallocator`.
TEST(CAPI, TestFromProto) {
  Tensor t_cc(DT_FLOAT, TensorShape({2, 3}));
  t_cc.flat<float>().setConstant(1.0);
  tensorflow::TensorProto t_proto;
  t_cc.AsProtoField(&t_proto);

  TF_Buffer* t_buffer = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(t_proto, t_buffer));

  const int num_bytes = 6 * sizeof(float);
  float* values =
      reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
          EIGEN_MAX_ALIGN_BYTES, num_bytes));
  int64_t dims[] = {2, 3};
  bool deallocator_called = false;
  TF_Tensor* t_c = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes,
                                &Deallocator, &deallocator_called);
  ASSERT_NE(t_c, nullptr);

  TF_Status* status = TF_NewStatus();
  TF_TensorFromProto(t_buffer, t_c, status);
  // On failure the tensor still holds the uninitialized buffer, so the data
  // checks below would be meaningless.
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  // The proto is a 2x3 float tensor filled with 1.0; verify dtype, shape and
  // every element, not just the first.
  EXPECT_EQ(TF_FLOAT, TF_TensorType(t_c));
  ASSERT_EQ(2, TF_NumDims(t_c));
  EXPECT_EQ(2, TF_Dim(t_c, 0));
  EXPECT_EQ(3, TF_Dim(t_c, 1));
  ASSERT_EQ(6 * sizeof(float), TF_TensorByteSize(t_c));
  const float* data = static_cast<const float*>(TF_TensorData(t_c));
  for (int i = 0; i < 6; ++i) {
    EXPECT_EQ(1.0f, data[i]) << "element " << i;
  }

  TF_DeleteTensor(t_c);
  // Whether the caller-provided buffer was released when its contents were
  // replaced or when the tensor was deleted, it must have been released now.
  EXPECT_TRUE(deallocator_called);
  TF_DeleteBuffer(t_buffer);
}
1574 
1575 REGISTER_OP("TestOpWithNoGradient")
1576     .Input("x: T")
1577     .Output("y: T")
1578     .Attr("T: {float, double}")
1579     .Doc(R"doc(
1580 Test op with no grad registered.
1581 
1582 x: input
1583 y: output
1584 )doc")
1585     .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1586 
1587 class CApiGradientsTest : public ::testing::Test {
1588  protected:
CApiGradientsTest()1589   CApiGradientsTest()
1590       : s_(TF_NewStatus()),
1591         graph_(TF_NewGraph()),
1592         expected_graph_(TF_NewGraph()) {}
1593 
~CApiGradientsTest()1594   ~CApiGradientsTest() override {
1595     TF_DeleteGraph(graph_);
1596     TF_DeleteGraph(expected_graph_);
1597     TF_DeleteStatus(s_);
1598   }
1599 
TestGradientsSuccess(bool grad_inputs_provided)1600   void TestGradientsSuccess(bool grad_inputs_provided) {
1601     TF_Output inputs[2];
1602     TF_Output outputs[1];
1603     TF_Output grad_outputs[2];
1604     TF_Output expected_grad_outputs[2];
1605 
1606     BuildSuccessGraph(inputs, outputs);
1607     BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
1608 
1609     AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
1610                  grad_outputs);
1611     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1612 
1613     // Compare that the graphs match.
1614     GraphDef expected_gdef;
1615     GraphDef gdef;
1616     EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
1617     EXPECT_TRUE(GetGraphDef(graph_, &gdef));
1618     TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
1619 
1620     // Compare that the output of the gradients of both graphs match.
1621     RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
1622   }
1623 
TestGradientsError(bool grad_inputs_provided)1624   void TestGradientsError(bool grad_inputs_provided) {
1625     TF_Output inputs[1];
1626     TF_Output outputs[1];
1627     TF_Output grad_outputs[1];
1628 
1629     BuildErrorGraph(inputs, outputs);
1630 
1631     AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
1632                  grad_outputs);
1633 
1634     string expected_msg =
1635         "No gradient defined for op: TestOpWithNoGradient. Please see "
1636         "https://www.tensorflow.org/code/"
1637         "tensorflow/cc/gradients/README.md"
1638         " for instructions on how to add C++ gradients.";
1639     EXPECT_EQ(expected_msg, TF_Message(s_));
1640   }
1641 
1642   // Run the graph and ensure that the gradient values are as expected.
RunGraphsAndCompareOutputs(TF_Output * grad_outputs,TF_Output * expected_grad_outputs)1643   void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
1644                                   TF_Output* expected_grad_outputs) {
1645     std::unique_ptr<CSession> csession(new CSession(graph_, s_));
1646     std::unique_ptr<CSession> expected_csession(
1647         new CSession(expected_graph_, s_));
1648 
1649     std::vector<TF_Output> grad_outputs_vec;
1650     grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
1651     csession->SetOutputs(grad_outputs_vec);
1652     csession->Run(s_);
1653     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1654     TF_Tensor* out0 = csession->output_tensor(0);
1655     TF_Tensor* out1 = csession->output_tensor(1);
1656 
1657     std::vector<TF_Output> expected_grad_outputs_vec;
1658     expected_grad_outputs_vec.assign(expected_grad_outputs,
1659                                      expected_grad_outputs + 2);
1660     expected_csession->SetOutputs(expected_grad_outputs_vec);
1661     expected_csession->Run(s_);
1662     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1663     TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
1664     TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
1665 
1666     CompareTensors(out0, expected_out0);
1667     CompareTensors(out1, expected_out1);
1668   }
1669 
CompareTensors(TF_Tensor * a,TF_Tensor * b)1670   void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
1671     float* a_data = static_cast<float*>(TF_TensorData(a));
1672     float* b_data = static_cast<float*>(TF_TensorData(b));
1673     EXPECT_EQ(*a_data, *b_data);
1674   }
1675 
AddGradients(bool grad_inputs_provided,const char * prefix,TF_Output * inputs,int ninputs,TF_Output * outputs,int noutputs,TF_Output * grad_outputs)1676   void AddGradients(bool grad_inputs_provided, const char* prefix,
1677                     TF_Output* inputs, int ninputs, TF_Output* outputs,
1678                     int noutputs, TF_Output* grad_outputs) {
1679     if (grad_inputs_provided) {
1680       TF_Output grad_inputs[1];
1681       const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
1682       TF_Operation* grad_inputs_op =
1683           FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
1684       grad_inputs[0] = TF_Output{grad_inputs_op, 0};
1685       TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1686                                 ninputs, grad_inputs, s_, grad_outputs);
1687     } else {
1688       TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
1689                                 ninputs, nullptr, s_, grad_outputs);
1690     }
1691   }
1692 
BuildErrorGraph(TF_Output * inputs,TF_Output * outputs)1693   void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
1694     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1695     TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1696     TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
1697     inputs[0] = TF_Output{const0, 0};
1698     outputs[0] = TF_Output{nograd, 0};
1699     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1700   }
1701 
BuildSuccessGraph(TF_Output * inputs,TF_Output * outputs)1702   void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
1703     // Construct the following graph:
1704     //            |
1705     //           z|
1706     //            |
1707     //          MatMul
1708     //         /       \
1709     //        ^         ^
1710     //        |         |
1711     //       x|        y|
1712     //        |         |
1713     //        |         |
1714     //      Const_0    Const_1
1715     //
1716     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1717     const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1718     TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
1719     TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
1720     TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
1721     inputs[0] = TF_Output{const0, 0};
1722     inputs[1] = TF_Output{const1, 0};
1723     outputs[0] = TF_Output{matmul, 0};
1724     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
1725   }
1726 
BuildExpectedGraph(bool grad_inputs_provided,TF_Output * expected_grad_outputs)1727   void BuildExpectedGraph(bool grad_inputs_provided,
1728                           TF_Output* expected_grad_outputs) {
1729     // The expected graph looks like this if grad_inputs_provided.
1730     // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
1731     //      ^             ^
1732     //    dy|           dx|        // MatMul Gradient Graph
1733     //      |             |
1734     //   MatMul_2      MatMul_1
1735     //   ^   ^          ^    ^
1736     //   |   |----------|    |
1737     //   |        ^          |
1738     //   |      dz|          |
1739     //   |        |          |
1740     //   |     Const_3       |
1741     //   |                   |
1742     //   |        ^          |
1743     //   |       z|          |     // MatMul Forward Graph
1744     //   |        |          |
1745     //   |      MatMul       |
1746     //   |     /       \     |
1747     //   |    ^         ^    |
1748     //   |    |         |    |
1749     //   |---x|        y|----|
1750     //        |         |
1751     //        |         |
1752     //      Const_0   Const_1
1753     //
1754     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
1755     const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
1756     TF_Operation* const0 =
1757         FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
1758     TF_Operation* const1 =
1759         FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
1760     TF_Operation* matmul =
1761         MatMul(expected_graph_, s_, const0, const1, "MatMul");
1762 
1763     TF_Operation* const3;
1764     if (grad_inputs_provided) {
1765       const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
1766       const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
1767     } else {
1768       const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
1769     }
1770 
1771     TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
1772                                    "gradients/MatMul", false, true);
1773     TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
1774                                    "gradients/MatMul_1", true, false);
1775     expected_grad_outputs[0] = {matmul1, 0};
1776     expected_grad_outputs[1] = {matmul2, 0};
1777   }
1778 
FloatTensor2x2(const float * values)1779   TF_Tensor* FloatTensor2x2(const float* values) {
1780     const int64_t dims[2] = {2, 2};
1781     TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
1782     memcpy(TF_TensorData(t), values, sizeof(float) * 4);
1783     return t;
1784   }
1785 
FloatConst2x2(TF_Graph * graph,TF_Status * s,const float * values,const char * name)1786   TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
1787                               const float* values, const char* name) {
1788     unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
1789     TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
1790     TF_SetAttrTensor(desc, "value", tensor.get(), s);
1791     if (TF_GetCode(s) != TF_OK) return nullptr;
1792     TF_SetAttrType(desc, "dtype", TF_FLOAT);
1793     TF_Operation* op = TF_FinishOperation(desc, s);
1794     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1795     return op;
1796   }
1797 
MatMul(TF_Graph * graph,TF_Status * s,TF_Operation * l,TF_Operation * r,const char * name,bool transpose_a=false,bool transpose_b=false)1798   TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
1799                        TF_Operation* r, const char* name,
1800                        bool transpose_a = false, bool transpose_b = false) {
1801     TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
1802     if (transpose_a) {
1803       TF_SetAttrBool(desc, "transpose_a", 1);
1804     }
1805     if (transpose_b) {
1806       TF_SetAttrBool(desc, "transpose_b", 1);
1807     }
1808     TF_AddInput(desc, {l, 0});
1809     TF_AddInput(desc, {r, 0});
1810     TF_Operation* op = TF_FinishOperation(desc, s);
1811     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1812     return op;
1813   }
1814 
OnesLike(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1815   TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1816                          const char* name) {
1817     TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
1818     TF_AddInput(desc, {in, 0});
1819     TF_Operation* op = TF_FinishOperation(desc, s);
1820     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1821     return op;
1822   }
1823 
NoGradientOp(TF_Graph * graph,TF_Status * s,TF_Operation * in,const char * name)1824   TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
1825                              const char* name) {
1826     TF_OperationDescription* desc =
1827         TF_NewOperation(graph, "TestOpWithNoGradient", name);
1828     TF_AddInput(desc, {in, 0});
1829     TF_Operation* op = TF_FinishOperation(desc, s);
1830     EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1831     return op;
1832   }
1833 
BuildGraphAndAddGradientsWithPrefixes(const char * prefix1,const char * prefix2=nullptr)1834   void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
1835                                              const char* prefix2 = nullptr) {
1836     TF_Output inputs[2];
1837     TF_Output outputs[1];
1838     TF_Output grad_outputs[2];
1839 
1840     BuildSuccessGraph(inputs, outputs);
1841 
1842     AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
1843     if (prefix2 != nullptr) {
1844       AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
1845     }
1846   }
1847 
1848   TF_Status* s_;
1849   TF_Graph* graph_;
1850   TF_Graph* expected_graph_;
1851 };
1852 
// Gradient construction for MatMul, with and without an explicit upstream
// gradient.
TEST_F(CApiGradientsTest, Gradients_GradInputs) {
  TestGradientsSuccess(/*grad_inputs_provided=*/true);
}

TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
  TestGradientsSuccess(/*grad_inputs_provided=*/false);
}

// An op without a registered gradient must be reported as an error, with and
// without an explicit upstream gradient.
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
  TestGradientsError(/*grad_inputs_provided=*/true);
}

TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
  TestGradientsError(/*grad_inputs_provided=*/false);
}
1866 
// Gradient name prefixes that collide neither with an existing node nor with
// each other must be accepted.
TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients",
                                        /*prefix2=*/"gradients_1");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"scope/gradients",
                                        /*prefix2=*/"scope/gradients_1");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"scope/gradients",
                                        /*prefix2=*/"scope_1/gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients",
                                        /*prefix2=*/"gradients/sub");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_OK, code) << TF_Message(s_);
}
1891 
// Gradient name prefixes that collide with an existing node name, or with the
// nodes/scope of another gradient call, must be rejected as invalid arguments.
TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"Const_0");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients",
                                        /*prefix2=*/"gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients",
                                        /*prefix2=*/"gradients/MatMul");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients/MatMul",
                                        /*prefix2=*/"gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}

TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
  BuildGraphAndAddGradientsWithPrefixes(/*prefix1=*/"gradients/sub",
                                        /*prefix2=*/"gradients");
  const TF_Code code = TF_GetCode(s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, code) << TF_Message(s_);
}
1916 
ScalarFloatFromTensor(const TF_Tensor * t,float * f)1917 void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
1918   ASSERT_TRUE(t != nullptr);
1919   ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
1920   ASSERT_EQ(0, TF_NumDims(t));
1921   ASSERT_EQ(4, TF_TensorByteSize(t));
1922   float* p = static_cast<float*>(TF_TensorData(t));
1923   *f = *p;
1924 }
1925 
// Gradients of xy = x * y requested through two separate TF_AddGradients
// calls on the same graph must both evaluate correctly: d(xy)/dx == y and
// d(xy)/dy == x.
TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
  const float kX = 3.0f;
  const float kY = 7.0f;
  TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT);
  TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT);
  TF_Operation* xy = Mul(x, y, graph_, s_, "xy");

  TF_Output grad_x;
  TF_Output grad_y;
  TF_Output ys[1] = {{xy, 0}};

  // First call: gradient with respect to x only.
  TF_Output wrt_x[1] = {{x, 0}};
  TF_AddGradients(graph_, ys, 1, wrt_x, 1, nullptr, s_, &grad_x);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Second call on the same graph: gradient with respect to y only.
  TF_Output wrt_y[1] = {{y, 0}};
  TF_AddGradients(graph_, ys, 1, wrt_y, 1, nullptr, s_, &grad_y);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_SessionOptions* session_opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(graph_, session_opts, s_);
  TF_DeleteSessionOptions(session_opts);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Output feeds[2] = {{x, 0}, {y, 0}};
  TF_Tensor* feed_tensors[2] = {FloatTensor(kX), FloatTensor(kY)};
  TF_Output fetches[2] = {grad_x, grad_y};
  TF_Tensor* fetch_tensors[2] = {nullptr, nullptr};

  TF_SessionRun(session, /*run_options=*/nullptr, feeds, feed_tensors, 2,
                fetches, fetch_tensors, 2, /*target_opers=*/nullptr, 0,
                /*run_metadata=*/nullptr, s_);
  for (TF_Tensor* feed : feed_tensors) TF_DeleteTensor(feed);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_DeleteSession(session, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  float grad_x_value = 0.0f;
  float grad_y_value = 0.0f;
  ScalarFloatFromTensor(fetch_tensors[0], &grad_x_value);
  EXPECT_EQ(kY, grad_x_value);

  ScalarFloatFromTensor(fetch_tensors[1], &grad_y_value);
  EXPECT_EQ(kX, grad_y_value);

  for (TF_Tensor* fetched : fetch_tensors) TF_DeleteTensor(fetched);
}
1971 
1972 // REGISTER_OP for CApiAttributesTest test cases.
1973 // Registers two ops, each with a single attribute called 'v'.
1974 // The attribute in one op will have a type 'type', the other
1975 // will have list(type).
1976 #define ATTR_TEST_REGISTER_OP(type)                           \
1977   REGISTER_OP("CApiAttributesTestOp" #type)                   \
1978       .Attr("v: " #type)                                      \
1979       .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
1980   REGISTER_OP("CApiAttributesTestOpList" #type)               \
1981       .Attr("v: list(" #type ")")                             \
1982       .SetShapeFn(tensorflow::shape_inference::UnknownShape)
1983 ATTR_TEST_REGISTER_OP(string);
1984 ATTR_TEST_REGISTER_OP(int);
1985 ATTR_TEST_REGISTER_OP(float);
1986 ATTR_TEST_REGISTER_OP(bool);
1987 ATTR_TEST_REGISTER_OP(type);
1988 ATTR_TEST_REGISTER_OP(shape);
1989 ATTR_TEST_REGISTER_OP(tensor);
1990 #undef ATTR_TEST_REGISTER_OP
1991 
1992 class CApiAttributesTest : public ::testing::Test {
1993  protected:
CApiAttributesTest()1994   CApiAttributesTest()
1995       : s_(TF_NewStatus()), graph_(TF_NewGraph()), counter_(0) {}
1996 
~CApiAttributesTest()1997   ~CApiAttributesTest() override {
1998     TF_DeleteGraph(graph_);
1999     TF_DeleteStatus(s_);
2000   }
2001 
init(string type)2002   TF_OperationDescription* init(string type) {
2003     // Construct op_name to match the name used by REGISTER_OP in the
2004     // ATTR_TEST_REGISTER calls above.
2005     string op_name = "CApiAttributesTestOp";
2006     if (type.find("list(") == 0) {
2007       op_name += "List";
2008       type = type.replace(0, 5, "");
2009       type = type.replace(type.size() - 1, 1, "");
2010     }
2011     op_name += type;
2012     return TF_NewOperation(
2013         graph_, op_name.c_str(),
2014         ::tensorflow::strings::StrCat("name", counter_++).c_str());
2015   }
2016 
2017   TF_Status* s_;
2018 
2019  private:
2020   TF_Graph* graph_;
2021   int counter_;
2022 };
2023 
// Helper macros for the TF_OperationGetAttr* tests.
// TODO(ashankar): Use gmock matchers instead?
// (https://github.com/google/googletest/blob/master/googlemock/docs/CookBook.md#writing-new-parameterized-matchers-quickly)
// That will require setting up the tensorflow build with gmock.
//
// EXPECT_TF_META checks TF_OperationGetAttrMetadata for `attr_name` on the
// variable `oper` in the calling scope, reporting through the fixture status
// `s_`. A negative `expected_list_size` means "not a list". The arguments are
// parenthesized, and the sizes are converted to int64_t (the type of the
// metadata fields). This lets callers pass size_t expressions without
// precedence surprises or signed/unsigned comparisons.
#define EXPECT_TF_META(attr_name, expected_list_size, expected_type,       \
                       expected_total_size)                                \
  do {                                                                     \
    auto m = TF_OperationGetAttrMetadata(oper, (attr_name), s_);           \
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);                    \
    const int64_t meta_list_size = static_cast<int64_t>(expected_list_size); \
    const unsigned char e = meta_list_size >= 0 ? 1 : 0;                   \
    EXPECT_EQ(e, m.is_list);                                               \
    EXPECT_EQ(meta_list_size, m.list_size);                                \
    EXPECT_EQ((expected_type), m.type);                                    \
    EXPECT_EQ(static_cast<int64_t>(expected_total_size), m.total_size);    \
  } while (0)
2039 
// A scalar string attribute round-trips through TF_OperationGetAttrString.
TEST_F(CApiAttributesTest, String) {
  constexpr int kLen = 5;
  auto desc = init("string");
  TF_SetAttrString(desc, "v", "bunny", kLen);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_STRING, kLen);

  const std::unique_ptr<char[]> buf(new char[kLen]);
  TF_OperationGetAttrString(oper, "v", buf.get(), kLen, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  const string got(static_cast<const char*>(buf.get()), kLen);
  EXPECT_EQ("bunny", got);
}

// A list(string) attribute round-trips through TF_OperationGetAttrStringList,
// which writes all strings into one caller-provided storage buffer.
TEST_F(CApiAttributesTest, StringList) {
  std::vector<string> list = {"bugs", "bunny", "duck"};
  std::unique_ptr<const void*[]> list_ptrs;
  std::unique_ptr<size_t[]> list_lens;
  StringVectorToArrays(list, &list_ptrs, &list_lens);

  // Total number of bytes across all strings (no terminators).
  int list_total_size = 0;
  for (size_t idx = 0; idx < list.size(); ++idx) {
    list_total_size += list[idx].size();
  }

  auto desc = init("list(string)");
  TF_SetAttrStringList(desc, "v", list_ptrs.get(), list_lens.get(),
                       list.size());

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
  std::unique_ptr<void*[]> out_ptrs(new void*[list.size()]);
  std::unique_ptr<size_t[]> out_lens(new size_t[list.size()]);
  std::unique_ptr<char[]> storage(new char[list_total_size]);
  TF_OperationGetAttrStringList(oper, "v", out_ptrs.get(), out_lens.get(),
                                list.size(), storage.get(), list_total_size,
                                s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  for (size_t idx = 0; idx < list.size(); ++idx) {
    const string got(static_cast<const char*>(out_ptrs[idx]), out_lens[idx]);
    EXPECT_EQ(list[idx].size(), out_lens[idx]) << idx;
    EXPECT_EQ(list[idx], got) << idx;
  }
}
2085 
// A scalar int attribute round-trips through TF_OperationGetAttrInt.
TEST_F(CApiAttributesTest, Int) {
  auto desc = init("int");
  TF_SetAttrInt(desc, "v", 31415);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_INT, -1);

  // Initialized so that a failing getter leads to a well-defined mismatch
  // below rather than a read of an indeterminate value.
  int64_t value = 0;
  TF_OperationGetAttrInt(oper, "v", &value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(31415, value);
}

// A list(int) attribute round-trips through TF_OperationGetAttrIntList.
TEST_F(CApiAttributesTest, IntList) {
  const int64_t list[] = {1, 2, 3, 4};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(int)");
  TF_SetAttrIntList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Zero-initialized for the same reason as in the scalar test.
  int64_t values[list_size] = {};
  EXPECT_TF_META("v", list_size, TF_ATTR_INT, -1);
  TF_OperationGetAttrIntList(oper, "v", values, list_size, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2116 
// A scalar float attribute round-trips through TF_OperationGetAttrFloat.
TEST_F(CApiAttributesTest, Float) {
  auto desc = init("float");
  TF_SetAttrFloat(desc, "v", 2.718);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_FLOAT, -1);

  // Initialized so that a failing getter leads to a well-defined mismatch
  // below rather than a read of an indeterminate value.
  float value = 0.0f;
  TF_OperationGetAttrFloat(oper, "v", &value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_FLOAT_EQ(2.718, value);
}

// A list(float) attribute round-trips through TF_OperationGetAttrFloatList.
TEST_F(CApiAttributesTest, FloatList) {
  const float list[] = {1.414, 2.718, 3.1415};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(float)");
  TF_SetAttrFloatList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Zero-initialized for the same reason as in the scalar test.
  float values[list_size] = {};
  EXPECT_TF_META("v", list_size, TF_ATTR_FLOAT, -1);
  TF_OperationGetAttrFloatList(oper, "v", values, list_size, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2147 
// A scalar bool attribute round-trips through TF_OperationGetAttrBool.
TEST_F(CApiAttributesTest, Bool) {
  auto desc = init("bool");
  TF_SetAttrBool(desc, "v", 1);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_BOOL, -1);

  // Initialized so that a failing getter leads to a well-defined mismatch
  // below rather than a read of an indeterminate value.
  unsigned char value = 0;
  TF_OperationGetAttrBool(oper, "v", &value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(1, value);
}

// A list(bool) attribute round-trips through TF_OperationGetAttrBoolList.
TEST_F(CApiAttributesTest, BoolList) {
  const unsigned char list[] = {0, 1, 1, 0, 0, 1, 1};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(bool)");
  TF_SetAttrBoolList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Zero-initialized for the same reason as in the scalar test.
  unsigned char values[list_size] = {};
  EXPECT_TF_META("v", list_size, TF_ATTR_BOOL, -1);
  TF_OperationGetAttrBoolList(oper, "v", values, list_size, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2178 
// A scalar type attribute round-trips through TF_OperationGetAttrType.
TEST_F(CApiAttributesTest, Type) {
  auto desc = init("type");
  TF_SetAttrType(desc, "v", TF_COMPLEX128);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_TYPE, -1);

  // Value-initialized so that a failing getter leads to a well-defined
  // mismatch below rather than a read of an indeterminate value.
  TF_DataType value{};
  TF_OperationGetAttrType(oper, "v", &value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_COMPLEX128, value);
}

// A list(type) attribute round-trips through TF_OperationGetAttrTypeList.
TEST_F(CApiAttributesTest, TypeList) {
  const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
  const size_t list_size = TF_ARRAYSIZE(list);

  auto desc = init("list(type)");
  TF_SetAttrTypeList(desc, "v", list, list_size);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Value-initialized for the same reason as in the scalar test.
  TF_DataType values[list_size] = {};
  EXPECT_TF_META("v", list_size, TF_ATTR_TYPE, -1);
  TF_OperationGetAttrTypeList(oper, "v", values, list_size, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(list), std::end(list), std::begin(values)));
}
2209 
TEST_F(CApiAttributesTest, Shape) {
  // Case 1: unknown rank (nullptr dims, num_dims == -1). The test expects
  // reading it back to succeed even with a nullptr output buffer (presumably
  // because an unknown-rank shape has no dims to copy out).
  auto desc = init("shape");
  TF_SetAttrShape(desc, "v", nullptr, -1);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, -1);
  TF_OperationGetAttrShape(oper, "v", nullptr, 10, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Case 2: a partially specified shape; -1 marks the unknown dimension.
  const int64_t partial_shape[] = {17, -1};
  const size_t rank = TF_ARRAYSIZE(partial_shape);
  desc = init("shape");
  TF_SetAttrShape(desc, "v", partial_shape, rank);
  oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, rank);

  int64_t dims_out[rank];
  TF_OperationGetAttrShape(oper, "v", dims_out, rank, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TRUE(std::equal(std::begin(dims_out), std::end(dims_out),
                         std::begin(partial_shape)));
}
2234 
TEST_F(CApiAttributesTest, ShapeList) {
  // Round-trips a list(shape) attribute holding two fully known shapes of
  // different ranks.
  const int64_t shape_1[] = {1, 3};
  const int64_t shape_2[] = {2, 4, 6};
  const int64_t* list[] = {&shape_1[0], &shape_2[0]};
  const size_t list_size = TF_ARRAYSIZE(list);
  const int ndims[] = {TF_ARRAYSIZE(shape_1), TF_ARRAYSIZE(shape_2)};
  // Derived from the shapes above (not hard-coded) so the storage buffer and
  // the metadata check stay in sync if the shapes are edited.
  const int total_ndims =
      static_cast<int>(TF_ARRAYSIZE(shape_1) + TF_ARRAYSIZE(shape_2));

  auto desc = init("list(shape)");
  TF_SetAttrShapeList(desc, "v", list, ndims, list_size);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", list_size, TF_ATTR_SHAPE, total_ndims);
  int64_t* values[list_size];
  int values_ndims[list_size];
  int64_t storage[total_ndims];
  TF_OperationGetAttrShapeList(oper, "v", values, values_ndims, list_size,
                               storage, total_ndims, s_);
  // ASSERT: on failure `values` / `values_ndims` may be uninitialized and
  // must not be read below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  for (size_t i = 0; i < list_size; ++i) {
    // ASSERT (not EXPECT): the inner loop indexes list[i] up to
    // values_ndims[i], which would read past shape_1/shape_2 on a mismatch.
    ASSERT_EQ(ndims[i], values_ndims[i]) << i;
    for (int j = 0; j < values_ndims[i]; ++j) {
      EXPECT_EQ(list[i][j], values[i][j]) << "(" << i << ", " << j << ")";
    }
  }
}
2262 
TEST_F(CApiAttributesTest, TensorShapeProto) {
  // Sets a shape attribute from a serialized TensorShapeProto and checks that
  // the serialized bytes read back are identical.
  const int64_t pts[] = {2, 4, -1, 8};
  tensorflow::TensorShapeProto proto;
  tensorflow::PartialTensorShape(pts).AsProto(&proto);
  string bytes;
  proto.SerializeToString(&bytes);

  auto desc = init("shape");
  TF_SetAttrTensorShapeProto(desc, "v", bytes.data(), bytes.length(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", -1, TF_ATTR_SHAPE, 4);
  TF_Buffer* value = TF_NewBuffer();
  TF_OperationGetAttrTensorShapeProto(oper, "v", value, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(bytes.length(), value->length);
  // Compare only the common prefix: if the lengths differ (already reported
  // above), using value->length alone could read past the end of `bytes`.
  const size_t common = std::min(bytes.length(), value->length);
  if (common > 0) {
    EXPECT_EQ(0, memcmp(bytes.data(), value->data, common));
  }
  TF_DeleteBuffer(value);
}
2284 
TEST_F(CApiAttributesTest, TensorShapeProtoList) {
  // Sets a list(shape) attribute from two serialized TensorShapeProtos and
  // checks that each proto reads back byte-for-byte.
  string bytes1, bytes2;
  tensorflow::TensorShapeProto proto;

  const int64_t pts1[] = {2, 4, -1, 8};
  tensorflow::PartialTensorShape(pts1).AsProto(&proto);
  proto.SerializeToString(&bytes1);

  const int64_t pts2[] = {1, 3, 5, 7};
  tensorflow::PartialTensorShape(pts2).AsProto(&proto);
  proto.SerializeToString(&bytes2);

  std::unique_ptr<const void*[]> list_ptrs;
  std::unique_ptr<size_t[]> list_lens;
  const std::vector<string> list = {bytes1, bytes2};
  StringVectorToArrays(list, &list_ptrs, &list_lens);

  auto desc = init("list(shape)");
  TF_SetAttrTensorShapeProtoList(desc, "v", list_ptrs.get(), list_lens.get(),
                                 list.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Two shapes of 4 dims each: list size 2, total size 8.
  EXPECT_TF_META("v", 2, TF_ATTR_SHAPE, 8);
  TF_Buffer* values[2] = {nullptr, nullptr};
  TF_OperationGetAttrTensorShapeProtoList(oper, "v", values, 2, s_);
  // ASSERT: on failure the output slots may not have been populated, so they
  // must not be dereferenced or freed below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Own every returned buffer up front so a failing ASSERT in the loop below
  // cannot leak the remaining ones.
  using BufferPtr = std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)>;
  std::vector<BufferPtr> owned;
  owned.reserve(2);
  for (TF_Buffer* buf : values) {
    owned.emplace_back(buf, TF_DeleteBuffer);
  }

  for (size_t i = 0; i < owned.size(); ++i) {
    ASSERT_NE(nullptr, owned[i].get()) << i;
    const size_t le = list_lens[i];
    const size_t la = owned[i]->length;
    EXPECT_EQ(le, la) << i;
    // Compare only the common prefix so a length mismatch cannot overread.
    EXPECT_EQ(0, memcmp(list_ptrs[i], owned[i]->data, std::min(le, la))) << i;
  }
}
2323 
TEST_F(CApiAttributesTest, Tensor) {
  // Round-trips a `tensor` attribute holding a small 1x2 int8 tensor and
  // checks dtype, shape and raw bytes of the value read back.
  const char tensor[] = {5, 7};
  const int64_t dims[] = {1, 2};
  const size_t ndims = TF_ARRAYSIZE(dims);

  auto desc = init("tensor");
  unique_tensor_ptr v(Int8Tensor(dims, ndims, tensor), TF_DeleteTensor);
  TF_SetAttrTensor(desc, "v", v.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
  TF_Tensor* raw_value = nullptr;
  TF_OperationGetAttrTensor(oper, "v", &raw_value, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(nullptr, raw_value);
  // Owned from here on so the ASSERTs below cannot leak it.
  unique_tensor_ptr value(raw_value, TF_DeleteTensor);

  EXPECT_EQ(TF_INT8, TF_TensorType(value.get()));
  // ASSERT: `dims` is indexed up to TF_NumDims below, so a rank mismatch
  // would read past its end.
  ASSERT_EQ(static_cast<int>(ndims), TF_NumDims(value.get()));
  for (int i = 0; i < TF_NumDims(value.get()); ++i) {
    EXPECT_EQ(dims[i], TF_Dim(value.get(), i)) << i;
  }
  // ASSERT: memcmp below reads TF_TensorByteSize bytes of `tensor`.
  ASSERT_EQ(sizeof(char) * TF_ARRAYSIZE(tensor),
            TF_TensorByteSize(value.get()));
  EXPECT_EQ(0, memcmp(tensor, TF_TensorData(value.get()),
                      TF_TensorByteSize(value.get())));
}
2351 
TEST_F(CApiAttributesTest, StringTensor) {
  // Create the string-Tensor "attribute" value.
  const char test_string[] =
      "borkborkborkborkborkborkborkbork";  // >24bytes to force heap alloc
  TF_TString tstr[1];
  TF_TString_Init(&tstr[0]);
  TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
  // Release the TF_TString's storage on every exit path, including an early
  // return from a failing ASSERT. Declared before `t_in` so that it is
  // destroyed after `t_in`, which wraps this storage.
  auto dealloc_tstr = [](TF_TString* s) { TF_TString_Dealloc(s); };
  std::unique_ptr<TF_TString, decltype(dealloc_tstr)> tstr_cleanup(
      &tstr[0], dealloc_tstr);

  // `t_in` borrows `tstr`; its no-op deallocator leaves freeing the string to
  // `tstr_cleanup`.
  auto deallocator = [](void* data, size_t len, void* arg) {};
  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
                                      sizeof(tstr), deallocator, nullptr),
                         TF_DeleteTensor);

  // Create a TF_Operation with the attribute t_in
  auto desc = init("tensor");
  TF_SetAttrTensor(desc, "v", t_in.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Fetch the attribute back and take ownership before any further ASSERT.
  EXPECT_TF_META("v", -1, TF_ATTR_TENSOR, -1);
  TF_Tensor* raw_out = nullptr;
  TF_OperationGetAttrTensor(oper, "v", &raw_out, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(nullptr, raw_out);
  unique_tensor_ptr t_out(raw_out, TF_DeleteTensor);

  EXPECT_EQ(TF_STRING, TF_TensorType(t_out.get()));
  EXPECT_EQ(0, TF_NumDims(t_out.get()));
  ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out.get()));
  TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
  TF_TString* t_out_tstr =
      static_cast<TF_TString*>(TF_TensorData(t_out.get()));
  EXPECT_EQ(absl::string_view(test_string),
            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
                              TF_TString_GetSize(t_out_tstr)));
  EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
                              TF_TString_GetSize(t_in_tstr)),
            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
                              TF_TString_GetSize(t_out_tstr)));
}
2393 
TEST_F(CApiAttributesTest, TensorList) {
  // Round-trips a list(tensor) attribute holding two int8 tensors of
  // different shapes, checking dtype, shape and raw bytes of each.
  const char tensor1[] = {5, 7};
  const int64_t dims1[] = {1, 2};
  const size_t ndims1 = TF_ARRAYSIZE(dims1);

  const char tensor2[] = {2, 4, 6, 8};
  const int64_t dims2[] = {2, 2};
  const size_t ndims2 = TF_ARRAYSIZE(dims2);

  auto desc = init("list(tensor)");
  TF_Tensor* tmp[] = {
      Int8Tensor(dims1, ndims1, tensor1),
      Int8Tensor(dims2, ndims2, tensor2),
  };
  TF_SetAttrTensorList(desc, "v", tmp, TF_ARRAYSIZE(tmp), s_);
  // The inputs are released right away; the values checked below are read
  // back from the attribute itself.
  for (TF_Tensor* t : tmp) {
    TF_DeleteTensor(t);
  }
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  EXPECT_TF_META("v", 2, TF_ATTR_TENSOR, -1);
  TF_Tensor* values[2] = {nullptr, nullptr};
  TF_OperationGetAttrTensorList(oper, "v", &values[0], TF_ARRAYSIZE(values),
                                s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  // Own every output up front so a failing ASSERT in the loop below cannot
  // leak the tensors not yet visited.
  std::vector<unique_tensor_ptr> outputs;
  outputs.reserve(TF_ARRAYSIZE(values));
  for (TF_Tensor* t : values) {
    outputs.emplace_back(t, TF_DeleteTensor);
  }

  const char* tensor_data[] = {&tensor1[0], &tensor2[0]};
  const size_t tensor_size[] = {TF_ARRAYSIZE(tensor1), TF_ARRAYSIZE(tensor2)};
  const int64_t* tensor_dims[] = {&dims1[0], &dims2[0]};
  const size_t tensor_ndims[] = {ndims1, ndims2};
  for (size_t i = 0; i < outputs.size(); ++i) {
    TF_Tensor* v = outputs[i].get();
    ASSERT_NE(nullptr, v) << i;
    EXPECT_EQ(TF_INT8, TF_TensorType(v)) << i;
    // ASSERT: tensor_dims[i] is indexed up to TF_NumDims(v) below.
    ASSERT_EQ(static_cast<int>(tensor_ndims[i]), TF_NumDims(v)) << i;
    for (int j = 0; j < TF_NumDims(v); ++j) {
      EXPECT_EQ(tensor_dims[i][j], TF_Dim(v, j))
          << "Tensor #" << i << ", dimension #" << j;
    }
    // ASSERT: memcmp below reads TF_TensorByteSize(v) bytes of tensor_data[i].
    ASSERT_EQ(sizeof(char) * tensor_size[i], TF_TensorByteSize(v)) << i;
    EXPECT_EQ(0,
              memcmp(tensor_data[i], TF_TensorData(v), TF_TensorByteSize(v)));
  }
}
2441 
TEST_F(CApiAttributesTest, EmptyList) {
  // A list(int) attribute set with zero elements (and a nullptr array) must
  // finish cleanly and report list metadata of size 0.
  auto desc = init("list(int)");
  TF_SetAttrIntList(desc, "v", /*values=*/nullptr, /*num_values=*/0);
  TF_Operation* oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", 0, TF_ATTR_INT, -1);
}
2449 
TEST_F(CApiAttributesTest, Names) {
  // An op with a single attribute exposes that attribute's name ("v")
  // through the index-based name accessors.
  auto desc = init("string");
  TF_SetAttrString(desc, "v", "bunny", 5);

  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_TF_META("v", -1, TF_ATTR_STRING, 5);

  ASSERT_EQ(1, TF_OperationGetNumAttrs(oper));
  const int name_length = TF_OperationGetAttrNameLength(oper, 0);
  ASSERT_EQ(1, name_length);

  // The output buffer is sized to exactly the reported name length.
  std::vector<char> name(name_length);
  TF_OperationGetAttrName(oper, 0, name.data(), s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ("v", string(name.data(), name.size()));
}
2467 
TEST_F(CApiAttributesTest, Errors) {
  // Reading an `int` attribute through the string accessor is a type
  // mismatch, which must surface as TF_INVALID_ARGUMENT.
  auto desc = init("int");
  TF_SetAttrInt(desc, "v", 3);
  auto oper = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_OperationGetAttrString(oper, "v", /*value=*/nullptr, /*max_length=*/0,
                            s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}
TEST(TestApiDef, TestCreateApiDef) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // Every C handle is owned by a smart pointer so that a failing ASSERT
  // below cannot leak it.
  using BufferPtr = std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)>;
  using StatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
  using ApiDefMapPtr =
      std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)>;

  BufferPtr op_list_buf(TF_GetAllOpList(), TF_DeleteBuffer);
  StatusPtr status(TF_NewStatus(), TF_DeleteStatus);
  ApiDefMapPtr api_def_map(TF_NewApiDefMap(op_list_buf.get(), status.get()),
                           TF_DeleteApiDefMap);
  // ASSERT: the map is passed to TF_ApiDefMapGet below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  const string op_name = "TestCApi";
  BufferPtr api_def_buf(TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(),
                                        op_name.size(), status.get()),
                        TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // The buffer is dereferenced below, so a null result must stop the test.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary());
}
2504 
TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
  // TODO(b/73318067): Fix linking for the GPU test generated by the
  // tf_cuda_cc_test() bazel rule and remove the next line.
  if (!GPUDeviceName().empty()) return;

  // Every C handle is owned by a smart pointer so that a failing ASSERT
  // below cannot leak it.
  using BufferPtr = std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)>;
  using StatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
  using ApiDefMapPtr =
      std::unique_ptr<TF_ApiDefMap, decltype(&TF_DeleteApiDefMap)>;

  BufferPtr op_list_buf(TF_GetAllOpList(), TF_DeleteBuffer);
  StatusPtr status(TF_NewStatus(), TF_DeleteStatus);
  ApiDefMapPtr api_def_map(TF_NewApiDefMap(op_list_buf.get(), status.get()),
                           TF_DeleteApiDefMap);
  // ASSERT: the map is used by TF_ApiDefMapPut / TF_ApiDefMapGet below.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  // Text-format ApiDefs that replace the summary of TestCApi.
  const string api_def_overwrites = R"(op: <
  graph_op_name: "TestCApi"
  summary: "New summary"
>
)";
  TF_ApiDefMapPut(api_def_map.get(), api_def_overwrites.c_str(),
                  api_def_overwrites.size(), status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

  const string op_name = "TestCApi";
  BufferPtr api_def_buf(TF_ApiDefMapGet(api_def_map.get(), op_name.c_str(),
                                        op_name.size(), status.get()),
                        TF_DeleteBuffer);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  // The buffer is dereferenced below, so a null result must stop the test.
  ASSERT_NE(nullptr, api_def_buf.get());

  tensorflow::ApiDef api_def;
  EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length));
  EXPECT_EQ(op_name, api_def.graph_op_name());
  EXPECT_EQ("New summary", api_def.summary());
}
2543 
2544 class DummyKernel : public tensorflow::OpKernel {
2545  public:
DummyKernel(tensorflow::OpKernelConstruction * context)2546   explicit DummyKernel(tensorflow::OpKernelConstruction* context)
2547       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)2548   void Compute(tensorflow::OpKernelContext* context) override {}
2549 };
2550 
2551 // Test we can query kernels
2552 REGISTER_OP("TestOpWithSingleKernel")
2553     .Input("a: float")
2554     .Input("b: float")
2555     .Output("o: float");
2556 REGISTER_KERNEL_BUILDER(
2557     Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel);
2558 
TEST(TestKernel, TestGetAllRegisteredKernels) {
  // Both handles are owned so that a failing ASSERT below cannot leak them.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetAllRegisteredKernels(status.get()), TF_DeleteBuffer);
  // ASSERT: the buffer is dereferenced below, so stop on failure.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_GT(kernel_list.kernel_size(), 0);
}
2569 
TEST(TestKernel, TestGetRegisteredKernelsForOp) {
  // Both handles are owned so that a failing ASSERT below cannot leak them.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status.get()),
      TF_DeleteBuffer);
  // ASSERT: the buffer is dereferenced below, so stop on failure.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 1);
  EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel");
  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
}
2583 
TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
  // Querying an op name with no registered kernels must succeed and yield an
  // empty list. Handles are owned so a failing ASSERT cannot leak them.
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> kernel_list_buf(
      TF_GetRegisteredKernelsForOp("Unknown", status.get()), TF_DeleteBuffer);
  // ASSERT: the buffer is dereferenced below, so stop on failure.
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  ASSERT_NE(nullptr, kernel_list_buf.get());

  KernelList kernel_list;
  ASSERT_TRUE(kernel_list.ParseFromArray(kernel_list_buf->data,
                                         kernel_list_buf->length));
  ASSERT_EQ(kernel_list.kernel_size(), 0);
}
2594 
2595 #undef EXPECT_TF_META
2596 
TEST(CAPI, TestTensorAligned) {
  // A tensor freshly allocated with TF_AllocateTensor is expected to report
  // itself as aligned whenever Eigen requires any alignment at all.
  const int64_t dim = 7;
  const size_t num_bytes = dim * TF_DataTypeSize(TF_FLOAT);
  TF_Tensor* tensor = TF_AllocateTensor(
      /*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1,
      /*len=*/num_bytes);

  float* elements = static_cast<float*>(TF_TensorData(tensor));
  std::fill(elements, elements + dim, 0.0f);

  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_TRUE(TF_TensorIsAligned(tensor));
  }
  TF_DeleteTensor(tensor);
}
2612 
TEST(CAPI, TestTensorIsNotAligned) {
  // Test unaligned access via a Slice.
  Tensor x(DT_FLOAT, TensorShape({30}));
  x.flat<float>().setConstant(0.0);

  // Take an unaligned slice: it starts one element into `x`.
  Tensor y = x.Slice(1, 13);
  Status status;
  // Owned so that the ASSERTs below cannot leak it.
  unique_tensor_ptr a(TF_TensorFromTensor(y, &status), TF_DeleteTensor);
  // ASSERT: `a` is passed to TF_TensorIsAligned below, which needs a valid
  // tensor, so stop if the conversion failed.
  ASSERT_TRUE(status.ok()) << status.ToString();
  ASSERT_NE(nullptr, a.get());
  if (EIGEN_MAX_ALIGN_BYTES > 0) {
    EXPECT_FALSE(TF_TensorIsAligned(a.get()));
  }
}
2627 
TEST(CAPI, MessageBufferConversion) {
  // A NodeDef serialized into a TF_Buffer and parsed back must be unchanged.
  NodeDef node_in, node_out;
  node_in.set_name("Test name");
  node_in.set_op("Test op");

  TF_Buffer* buffer = TF_NewBuffer();
  // Checked with EXPECT instead of TF_CHECK_OK: a conversion failure should
  // fail this test, not abort the whole test binary.
  const auto to_buffer = MessageToBuffer(node_in, buffer);
  EXPECT_TRUE(to_buffer.ok()) << to_buffer.ToString();
  const auto from_buffer = BufferToMessage(buffer, &node_out);
  EXPECT_TRUE(from_buffer.ok()) << from_buffer.ToString();
  TF_DeleteBuffer(buffer);

  protobuf::util::MessageDifferencer differencer;
  EXPECT_TRUE(differencer.Compare(node_in, node_out));
}
2641 
TEST(CAPI, TestTensorNonScalarBytesAllocateDelete) {
  // Allocates a [batch_size, 1] TF_STRING tensor, fills every element with a
  // heap-backed ("large" mode) TF_TString, then deletes the tensor. Nothing
  // is asserted here; presumably a leak/memory checker flags any string
  // storage that TF_DeleteTensor fails to release.
  const int batch_size = 4;
  const int64_t dims[] = {batch_size, 1};
  const int num_dims = TF_ARRAYSIZE(dims);
  int64_t num_elements = 1;
  for (const int64_t d : dims) {
    num_elements *= d;
  }
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
                                   sizeof(TF_TString) * num_elements);

  TF_TString* data = static_cast<TF_TString*>(TF_TensorData(t));
  // Initialize every element of the tensor (not just the first dimension)
  // so the whole buffer holds valid TF_TStrings when it is deleted.
  for (int64_t i = 0; i < num_elements; ++i) {
    TF_TString_Init(&data[i]);
    // The following input string is long enough to force the copy into
    // large (heap-allocated) tstring mode.
    std::string source =
        "This is the " + std::to_string(i + 1) + "th. data element\n";
    TF_TString_Copy(&data[i], source.c_str(), source.length());
  }

  TF_DeleteTensor(t);
}
2668 
2669 }  // namespace
2670 }  // namespace tensorflow
2671 
2672 // TODO(josh11b): Test:
2673 // * TF_SetDevice(desc, "/job:worker");
2674 // * control inputs / outputs
2675 // * targets
2676 // * TF_DeleteGraph() before TF_DeleteSession()
2677