xref: /aosp_15_r20/external/tensorflow/tensorflow/c/kernels_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
16 #define EIGEN_USE_GPU
17 #endif
18 
19 #include "tensorflow/c/kernels.h"
20 
21 #include <stddef.h>
22 #include <stdint.h>
23 #include <string.h>
24 
25 #include <memory>
26 #include <string>
27 #include <utility>
28 
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_format.h"
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/c/c_api.h"
33 #include "tensorflow/c/tf_datatype.h"
34 #include "tensorflow/c/tf_status.h"
35 #include "tensorflow/c/tf_tensor.h"
36 #include "tensorflow/core/common_runtime/device.h"
37 #include "tensorflow/core/common_runtime/device_factory.h"
38 #include "tensorflow/core/framework/allocator.h"
39 #include "tensorflow/core/framework/attr_value.pb.h"
40 #include "tensorflow/core/framework/device_base.h"
41 #include "tensorflow/core/framework/kernel_def.pb.h"
42 #include "tensorflow/core/framework/node_def.pb.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/op.h"
45 #include "tensorflow/core/framework/op_kernel.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/tensor_types.h"
48 #include "tensorflow/core/framework/types.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/kernels/ops_testutil.h"
51 #include "tensorflow/core/lib/core/status_test_util.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/status.h"
54 #include "tensorflow/core/platform/test.h"
55 #include "tensorflow/core/platform/types.h"
56 
// Per-kernel state shared by the C create/compute/delete callbacks below;
// the flags record which lifecycle callbacks have run so the delete
// callback can assert the full create -> compute -> delete sequence.
struct MyCustomKernel {
  bool created;         // set true by the create callback
  bool compute_called;  // set true by the compute callback
};
61 
// File-global flag flipped by MyDeleteFunc so tests can assert that the
// delete callback actually ran after the kernel was destroyed.
static bool delete_called = false;
63 
MyCreateFunc(TF_OpKernelConstruction * ctx)64 static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
65   struct MyCustomKernel* s = new struct MyCustomKernel;
66   s->created = true;
67   s->compute_called = false;
68 
69   // Exercise attribute reads.
70   TF_DataType type;
71   TF_Status* status = TF_NewStatus();
72   TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status);
73   EXPECT_EQ(TF_OK, TF_GetCode(status));
74   EXPECT_EQ(TF_FLOAT, type);
75   TF_DeleteStatus(status);
76 
77   // Exercise kernel NodeDef name read
78   TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
79   std::string node_name = "SomeNodeName";
80   std::string candidate_node_name =
81       std::string(name_string_view.data, name_string_view.len);
82   EXPECT_EQ(node_name, candidate_node_name);
83   return s;
84 }
85 
MyComputeFunc(void * kernel,TF_OpKernelContext * ctx)86 static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
87   struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
88   s->compute_called = true;
89   if (ctx != nullptr) {
90     EXPECT_EQ(43, TF_StepId(ctx));
91   }
92 }
93 
MyDeleteFunc(void * kernel)94 static void MyDeleteFunc(void* kernel) {
95   struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
96   EXPECT_TRUE(s->created);
97   EXPECT_TRUE(s->compute_called);
98   delete_called = true;
99   delete s;
100 }
101 
102 namespace tensorflow {
103 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
104 
GetFakeKernel(const char * device_name,const char * op_name,const char * node_name,Status * status)105 static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
106                                                const char* op_name,
107                                                const char* node_name,
108                                                Status* status) {
109   NodeDef def;
110   def.set_op(op_name);
111   def.set_name(node_name);
112   def.set_device(device_name);
113   def.add_input("input1");
114   def.add_input("input2");
115 
116   AttrValue v;
117   v.set_type(DataType::DT_FLOAT);
118   (*def.mutable_attr())["SomeDataTypeAttr"] = v;
119 
120   return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
121                         status);
122 }
123 
GetFakeKernel2(const char * device_name,const char * op_name,const char * node_name,Status * status)124 static std::unique_ptr<OpKernel> GetFakeKernel2(const char* device_name,
125                                                 const char* op_name,
126                                                 const char* node_name,
127                                                 Status* status) {
128   NodeDef def;
129   def.set_op(op_name);
130   def.set_name(node_name);
131   def.set_device(device_name);
132   def.add_input("input1");
133   def.add_input("input2");
134   def.add_input("input3");
135   def.add_input("input3");
136   def.add_input("input3");
137 
138   AttrValue v0;
139   v0.set_type(DataType::DT_INT32);
140   v0.set_i(3);
141   (*def.mutable_attr())["NumInput3"] = v0;
142   AttrValue v1;
143   v1.set_type(DataType::DT_FLOAT);
144   (*def.mutable_attr())["SomeDataTypeAttr"] = v1;
145 
146   return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
147                         status);
148 }
149 
150 // Tests registration of a single C kernel and checks that calls through the
151 // C/C++ boundary are being made.
// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "FooOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  {
    // Register the kernel and confirm it shows up in the kernel registry.
    TF_Status* reg_status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, reg_status);
    EXPECT_EQ(TF_OK, TF_GetCode(reg_status));
    TF_Buffer* kernels_buf = TF_GetRegisteredKernelsForOp(op_name, reg_status);
    EXPECT_EQ(TF_OK, TF_GetCode(kernels_buf != nullptr ? reg_status
                                                       : reg_status));
    KernelList kernels;
    kernels.ParseFromArray(kernels_buf->data, kernels_buf->length);
    ASSERT_EQ(1, kernels.kernel_size());
    ASSERT_EQ(device_name, kernels.kernel(0).device_type());
    TF_DeleteBuffer(kernels_buf);
    TF_DeleteStatus(reg_status);
  }

  {
    // Build the kernel through the regular C++ factory and run it once;
    // this drives the create/compute callbacks above.
    Status create_status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &create_status);
    TF_EXPECT_OK(create_status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  // Destroying the kernel above must have invoked the delete callback.
  ASSERT_TRUE(delete_called);
}
191 
// Same round-trip as TestRegisterKernelBuilder, but registers the kernel
// through TF_RegisterKernelBuilderWithKernelDef using a serialized
// KernelDef instead of letting the builder synthesize one.
TEST(TestKernel, TF_RegisterKernelBuilderWithKernelDef) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "FooOp1";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  // Hand-build the KernelDef that the registration call will consume.
  KernelDef kernel_def;
  kernel_def.set_op(op_name);
  kernel_def.set_device_type(device_name);
  const std::string serialized_def = kernel_def.SerializePartialAsString();

  {
    TF_Status* reg_status = TF_NewStatus();
    TF_RegisterKernelBuilderWithKernelDef(serialized_def.data(), node_name,
                                          builder, reg_status);
    EXPECT_EQ(TF_OK, TF_GetCode(reg_status));
    TF_Buffer* kernels_buf = TF_GetRegisteredKernelsForOp(op_name, reg_status);
    EXPECT_EQ(TF_OK, TF_GetCode(reg_status));
    KernelList kernels;
    kernels.ParseFromArray(kernels_buf->data, kernels_buf->length);
    ASSERT_EQ(1, kernels.kernel_size());
    ASSERT_EQ(device_name, kernels.kernel(0).device_type());
    TF_DeleteBuffer(kernels_buf);
    TF_DeleteStatus(reg_status);
  }

  {
    // Instantiate and run the kernel to exercise the C callbacks.
    Status create_status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &create_status);
    TF_EXPECT_OK(create_status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  ASSERT_TRUE(delete_called);
}
237 
238 // REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
239 // Registers two ops, each with a single attribute called 'Attr'.
240 // The attribute in one op will have a type 'type', the other
241 // will have list(type).
// Registers "TestKernelAttr<name>" with attribute "Attr: <type>" and
// "TestKernelAttr<name>List" with attribute "Attr: list(<type>)".
#define ATTR_TEST_REGISTER_OP(name, type)                     \
  REGISTER_OP("TestKernelAttr" #name)                         \
      .Attr("Attr: " #type)                                   \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
  REGISTER_OP("TestKernelAttr" #name "List")                  \
      .Attr("Attr: list(" #type ")")                          \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
ATTR_TEST_REGISTER_OP(Tensor, tensor);
#undef ATTR_TEST_REGISTER_OP
256 
257 // Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
// Checks TF_OpKernelConstruction_GetAttrSize for `attr_name`. NOTE: the
// macro expects the identifiers `ctx` (a TF_OpKernelConstruction*) and
// `status` (a TF_Status*) to be in scope at the expansion site.
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
  do {                                                                     \
    int32_t list_size, total_size;                                         \
    TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,        \
                                        &total_size, status);              \
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);            \
    EXPECT_EQ(expected_list_size, list_size);                              \
    EXPECT_EQ(expected_total_size, total_size);                            \
  } while (0)
267 
268 typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
269 class TestKernelAttr : public ::testing::Test {
270  public:
TestKernelAttr()271   TestKernelAttr() {}
~TestKernelAttr()272   ~TestKernelAttr() override {}
273 
GetFakeKernelWithAttr(const char * op_name,AttrValue v,Status * status)274   std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
275                                                   AttrValue v, Status* status) {
276     NodeDef def;
277     def.set_op(op_name);
278     def.set_name("FakeNode");
279     def.set_device("FakeDevice");
280     (*def.mutable_attr())["Attr"] = v;
281     return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
282                           status);
283   }
284 
CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,const char * op_name,AttrValue & v)285   void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,
286                                    const char* op_name, AttrValue& v) {
287     TF_KernelBuilder* builder = TF_NewKernelBuilder(
288         op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
289     {
290       TF_Status* status = TF_NewStatus();
291       TF_RegisterKernelBuilder("FakeNode", builder, status);
292       EXPECT_EQ(TF_OK, TF_GetCode(status));
293       TF_DeleteStatus(status);
294     }
295     Status status;
296     std::unique_ptr<OpKernel> kernel =
297         GetFakeKernelWithAttr(op_name, v, &status);
298     TF_EXPECT_OK(status);
299     ASSERT_NE(nullptr, kernel.get());
300     kernel->Compute(nullptr);
301 
302     ASSERT_TRUE(delete_called);
303   }
304 };
305 
// Verifies TF_OpKernelConstruction_GetNodeDef returns a serialized NodeDef
// that round-trips to the op/name/device/attr the fixture registered.
TEST_F(TestKernelAttr, GetNodeDef) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    TF_Buffer* serialized = TF_OpKernelConstruction_GetNodeDef(ctx, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Parse the buffer back into a NodeDef and check every field the
    // fixture populated.
    NodeDef parsed;
    parsed.ParseFromArray(serialized->data, serialized->length);
    EXPECT_EQ(parsed.op(), "TestKernelAttrGetNodeDef");
    EXPECT_EQ(parsed.name(), "FakeNode");
    EXPECT_EQ(parsed.device(), "FakeDevice");
    EXPECT_EQ(parsed.attr_size(), 1);
    const ::tensorflow::AttrValue& attr = parsed.attr().at("Attr");
    EXPECT_TRUE(attr.value_case() == ::tensorflow::AttrValue::ValueCase::kI);
    EXPECT_EQ(attr.i(), 1234);

    TF_DeleteBuffer(serialized);
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  REGISTER_OP("TestKernelAttrGetNodeDef")
      .Attr("Attr: int")
      .SetShapeFn(tensorflow::shape_inference::UnknownShape);

  AttrValue v;
  v.set_i(1234);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrGetNodeDef", v);
}
337 
// Reads a scalar string attribute through the C API and checks both the
// reported size and the copied bytes.
TEST_F(TestKernelAttr, String) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    // Scalar attr: list size is -1, total size is the string length.
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ 5);
    std::unique_ptr<char[]> buf(new char[5]);
    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", buf.get(),
                                          /*max_length*/ 5, status);

    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ("bunny", string(static_cast<const char*>(buf.get()), 5));
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  v.set_s("bunny");
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrString", v);
}
361 
// Reads a list(string) attribute: the C API fills an array of pointers and
// lengths into caller-provided flat storage.
TEST_F(TestKernelAttr, StringList) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    const std::vector<string> expected = {"bugs", "bunny", "duck"};
    int total_chars = 0;
    for (const auto& item : expected) total_chars += item.size();

    TF_Status* status = TF_NewStatus();
    std::unique_ptr<char*[]> ptrs(new char*[expected.size()]);
    std::unique_ptr<size_t[]> lengths(new size_t[expected.size()]);
    std::unique_ptr<char[]> storage(new char[total_chars]);
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ expected.size(),
                   /*expected_total_size*/ total_chars);
    TF_OpKernelConstruction_GetAttrStringList(
        ctx, "Attr", ptrs.get(), lengths.get(), expected.size(), storage.get(),
        total_chars, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Each entry must match in both length and content.
    for (size_t i = 0; i < expected.size(); ++i) {
      EXPECT_EQ(expected[i].size(), lengths[i]) << i;
      EXPECT_EQ(expected[i],
                string(static_cast<const char*>(ptrs[i]), lengths[i]))
          << i;
    }
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  std::string attr_in[] = {"bugs", "bunny", "duck"};
  SetAttrValue(gtl::ArraySlice<std::string>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrStringList", v);
}
399 
// Reads a tensor-valued attribute and checks that the TF_Tensor the C API
// hands back matches the TensorProto the test registered.
TEST_F(TestKernelAttr, Tensor) {
  struct TensorProtoHelpers {
   public:
    // A 2x3 int32 tensor holding 1..6.
    static ::tensorflow::TensorProto GenerateTensorProto() {
      ::tensorflow::TensorProto proto;
      proto.mutable_tensor_shape()->add_dim()->set_size(2);
      proto.mutable_tensor_shape()->add_dim()->set_size(3);
      proto.set_dtype(DT_INT32);
      for (int value = 1; value <= 6; ++value) proto.add_int_val(value);
      return proto;
    }
  };

  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    // Tensor attrs report neither a list size nor a total size.
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_Tensor* attr_tensor;
    TF_OpKernelConstruction_GetAttrTensor(ctx, "Attr", &attr_tensor, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    ::tensorflow::Tensor expected;
    EXPECT_TRUE(expected.FromProto(TensorProtoHelpers::GenerateTensorProto()));

    ::tensorflow::Tensor actual;
    EXPECT_TRUE(TF_TensorToTensor(attr_tensor, &actual).ok());

    // Compare payload, shape, and dtype against the registered proto.
    EXPECT_EQ(actual.tensor_data(), expected.tensor_data());
    EXPECT_EQ(actual.shape(), expected.shape());
    EXPECT_EQ(actual.dtype(), expected.dtype());

    TF_DeleteStatus(status);
    TF_DeleteTensor(attr_tensor);
    return static_cast<void*>(state);
  };

  AttrValue v;
  *v.mutable_tensor() = TensorProtoHelpers::GenerateTensorProto();

  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrTensor", v);
}
452 
TEST_F(TestKernelAttr,TensorList)453 TEST_F(TestKernelAttr, TensorList) {
454   struct TensorProtoHelpers {
455    public:
456     static ::tensorflow::TensorProto GenerateTensorProto1() {
457       ::tensorflow::TensorProto tensor_proto;
458       tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
459       tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
460       tensor_proto.set_dtype(DT_INT32);
461       tensor_proto.add_int_val(1);
462       tensor_proto.add_int_val(2);
463       tensor_proto.add_int_val(3);
464       tensor_proto.add_int_val(4);
465       return tensor_proto;
466     }
467 
468     static ::tensorflow::TensorProto GenerateTensorProto2() {
469       ::tensorflow::TensorProto tensor_proto;
470       tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
471       tensor_proto.mutable_tensor_shape()->add_dim()->set_size(3);
472       tensor_proto.set_dtype(DT_FLOAT);
473       tensor_proto.add_float_val(5.0f);
474       tensor_proto.add_float_val(6.0f);
475       tensor_proto.add_float_val(7.0f);
476       tensor_proto.add_float_val(8.0f);
477       tensor_proto.add_float_val(9.0f);
478       tensor_proto.add_float_val(10.0f);
479       return tensor_proto;
480     }
481   };
482 
483   auto my_create_func = [](TF_OpKernelConstruction* ctx) {
484     struct MyCustomKernel* s = new struct MyCustomKernel;
485     s->created = true;
486     s->compute_called = false;
487 
488     const size_t list_size = 2;
489     TF_Tensor* values[list_size];
490 
491     TF_Status* status = TF_NewStatus();
492     EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
493                    /*expected_total_size*/ -1);
494     TF_OpKernelConstruction_GetAttrTensorList(ctx, "Attr", values, list_size,
495                                               status);
496     EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
497 
498     ::tensorflow::Tensor expected_tensor1;
499     EXPECT_TRUE(
500         expected_tensor1.FromProto(TensorProtoHelpers::GenerateTensorProto1()));
501 
502     ::tensorflow::Tensor actual_tensor1;
503     EXPECT_TRUE(TF_TensorToTensor(values[0], &actual_tensor1).ok());
504 
505     EXPECT_EQ(actual_tensor1.tensor_data(), expected_tensor1.tensor_data());
506     EXPECT_EQ(actual_tensor1.shape(), expected_tensor1.shape());
507     EXPECT_EQ(actual_tensor1.dtype(), expected_tensor1.dtype());
508 
509     ::tensorflow::Tensor expected_tensor2;
510     EXPECT_TRUE(
511         expected_tensor2.FromProto(TensorProtoHelpers::GenerateTensorProto2()));
512 
513     ::tensorflow::Tensor actual_tensor2;
514     EXPECT_TRUE(TF_TensorToTensor(values[1], &actual_tensor2).ok());
515 
516     EXPECT_EQ(actual_tensor2.tensor_data(), expected_tensor2.tensor_data());
517     EXPECT_EQ(actual_tensor2.shape(), expected_tensor2.shape());
518     EXPECT_EQ(actual_tensor2.dtype(), expected_tensor2.dtype());
519 
520     TF_DeleteStatus(status);
521     TF_DeleteTensor(values[0]);
522     TF_DeleteTensor(values[1]);
523     return static_cast<void*>(s);
524   };
525 
526   AttrValue v;
527   ::tensorflow::TensorProto* tensor_proto1 = v.mutable_list()->add_tensor();
528   *tensor_proto1 = TensorProtoHelpers::GenerateTensorProto1();
529 
530   ::tensorflow::TensorProto* tensor_proto2 = v.mutable_list()->add_tensor();
531   *tensor_proto2 = TensorProtoHelpers::GenerateTensorProto2();
532 
533   CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTensorList", v);
534 }
535 
// Reads a scalar int attribute through TF_OpKernelConstruction_GetAttrInt64.
TEST_F(TestKernelAttr, Int) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    int64_t attr_value;
    TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &attr_value, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1234, attr_value);
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  v.set_i(1234);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrInt", v);
}
557 
// Reads a list(int) attribute and compares it element-wise to the input.
TEST_F(TestKernelAttr, IntList) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    const int64_t expected[] = {1, 2, 3, 4};
    const size_t count = TF_ARRAYSIZE(expected);
    int64_t actual[count];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ count,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", actual, count,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(std::equal(std::begin(expected), std::end(expected),
                           std::begin(actual)));
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  int64_t attr_in[] = {1, 2, 3, 4};
  SetAttrValue(gtl::ArraySlice<int64_t>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrIntList", v);
}
585 
// Reads a scalar float attribute through TF_OpKernelConstruction_GetAttrFloat.
TEST_F(TestKernelAttr, Float) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    float attr_value;
    TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &attr_value, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_FLOAT_EQ(2.718, attr_value);
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  v.set_f(2.718);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrFloat", v);
}
607 
// Reads a list(float) attribute and compares it element-wise to the input.
TEST_F(TestKernelAttr, FloatList) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    const float expected[] = {1.414, 2.718, 3.1415};
    const size_t count = TF_ARRAYSIZE(expected);
    float actual[count];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ count,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", actual, count,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(std::equal(std::begin(expected), std::end(expected),
                           std::begin(actual)));
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  float attr_in[] = {1.414, 2.718, 3.1415};
  SetAttrValue(gtl::ArraySlice<float>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrFloatList", v);
}
635 
// Reads a scalar bool attribute; the C API reports it as an unsigned char.
TEST_F(TestKernelAttr, Bool) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    unsigned char attr_value;
    TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &attr_value, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1, attr_value);
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  v.set_b(true);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrBool", v);
}
657 
// Reads a list(bool) attribute and compares it element-wise to the input.
TEST_F(TestKernelAttr, BoolList) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    const unsigned char expected[] = {1, 0, 1, 0};
    const size_t count = TF_ARRAYSIZE(expected);
    unsigned char actual[count];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ count,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", actual, count,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(std::equal(std::begin(expected), std::end(expected),
                           std::begin(actual)));
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  bool attr_in[] = {true, false, true, false};
  SetAttrValue(gtl::ArraySlice<bool>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrBoolList", v);
}
685 
// Reads a scalar type attribute through TF_OpKernelConstruction_GetAttrType.
TEST_F(TestKernelAttr, Type) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_DataType attr_value;
    TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &attr_value, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(TF_FLOAT, attr_value);
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  v.set_type(DT_FLOAT);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrType", v);
}
707 
// Reads a list(type) attribute and compares it element-wise to the input.
TEST_F(TestKernelAttr, TypeList) {
  auto create_fn = [](TF_OpKernelConstruction* ctx) {
    auto* state = new struct MyCustomKernel;
    state->created = true;
    state->compute_called = false;

    const TF_DataType expected[] = {TF_FLOAT, TF_DOUBLE, TF_HALF,
                                    TF_COMPLEX128};
    const size_t count = TF_ARRAYSIZE(expected);
    TF_DataType actual[count];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ count,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", actual, count,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(std::equal(std::begin(expected), std::end(expected),
                           std::begin(actual)));
    TF_DeleteStatus(status);
    return static_cast<void*>(state);
  };

  AttrValue v;
  DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128};
  SetAttrValue(gtl::ArraySlice<DataType>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(create_fn, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
736 
// Minimal DeviceBase implementation whose only job is to hand back the
// process-wide CPU allocator so output tensors can be allocated in tests.
class DummyDevice : public DeviceBase {
 public:
  explicit DummyDevice(Env* env) : DeviceBase(env) {}
  Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
    return cpu_allocator();
  }
};
744 
// Exercises the TF_OpKernelContext input/output accessors (TF_NumInputs,
// TF_GetInput, TF_SetOutput, TF_ExpectedOutputDataType) including their
// out-of-range and death behavior. Left byte-for-byte as written apart
// from comments: the EXPECT_DEATH calls and the static counters make the
// exact statement order load-bearing.
TEST(TestKernel, TestInputAndOutputCount) {
  const char* node_name = "InputOutputCounterKernel";
  const char* op_name = "BarOp";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  // static so the capture-less lambda below can write to them; read back
  // at the end of the test.
  static int num_inputs = 0;
  static int num_outputs = 0;

  // A kernel whose Compute function has a side-effect of updating num_inputs
  // and num_outputs. Various functions on TF_OpKernelContext are also
  // exercised.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    num_inputs = TF_NumInputs(ctx);
    num_outputs = TF_NumOutputs(ctx);

    // Valid input index 0 must yield the uint8 scalar 123 set up below.
    TF_Tensor* input = nullptr;
    TF_Status* s = TF_NewStatus();
    TF_GetInput(ctx, 0, &input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
    EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
    // Out-of-range indices (negative / past the end) must fail cleanly.
    TF_GetInput(ctx, -1, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
    TF_GetInput(ctx, 3, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    // Copy the input tensor to output.
    TF_SetOutput(ctx, 0, input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));

    // Out-of-range output index must also fail cleanly.
    TF_SetOutput(ctx, 24, input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0));

    // Out-of-range indices on TF_ExpectedOutputDataType are CHECK failures,
    // not status errors, so they must be probed with EXPECT_DEATH.
    EXPECT_DEATH({ TF_ExpectedOutputDataType(ctx, 1); },
                 "Check failed: i < cc_ctx->num_outputs");

    EXPECT_DEATH({ TF_ExpectedOutputDataType(ctx, -1); },
                 "Check failed: i >= 0");

    TF_DeleteStatus(s);
    if (input != nullptr) {
      TF_DeleteTensor(input);
    }
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    // Build a real OpKernelContext with one concrete input tensor and one
    // empty slot, then run the kernel through the C++ entry point.
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    p.step_id = 43;  // checked via TF_StepId elsewhere in this file

    Tensor t(tensorflow::uint8(123));

    gtl::InlinedVector<TensorValue, 4> inputs;
    // Simulate 2 inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);

    ASSERT_EQ(2, num_inputs);
    ASSERT_EQ(1, num_outputs);
    // The kernel copied its first input to output 0.
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<tensorflow::uint8>()());
  }
}
836 
// TF_DeleteKernelBuilder must tolerate a nullptr argument (no-op, no crash).
TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
  TF_DeleteKernelBuilder(nullptr);
}
840 
// Returns the KernelList textproto expected for a "TypeOp<dtype>" kernel
// registered on FakeDeviceName1 whose "T" attr is constrained to `type`.
// `type` is the DT_* enum name and is spliced into both the op name and the
// allowed_values list.
std::string ExpectedString(const char* type) {
  const std::string dtype(type);
  std::string proto;
  proto.reserve(160);
  proto += "kernel {\n";
  proto += "  op: \"TypeOp" + dtype + "\"\n";
  proto += "  device_type: \"FakeDeviceName1\"\n";
  proto += "  constraint {\n";
  proto += "    name: \"T\"\n";
  proto += "    allowed_values {\n";
  proto += "      list {\n";
  proto += "        type: " + dtype + "\n";
  proto += "      }\n";
  proto += "    }\n";
  proto += "  }\n";
  proto += "}\n";
  return proto;
}
857 
// Registers a kernel for op "TypeOp<dtype>" on FakeDeviceName1 with a type
// constraint binding attr "T" to `tf_type`, then checks that
// TF_GetRegisteredKernelsForOp returns a KernelList matching
// ExpectedString(#dtype). Relies on MyCreateFunc/MyComputeFunc/MyDeleteFunc
// and the `delete_called` flag defined earlier in this file.
#define TEST_KERNEL_TYPE_CONSTRAINT(tf_type, dtype)                          \
  TEST(TestKernel, TestTypeConstraint##tf_type) {                            \
    const char* node_name = "SomeNodeName";                                  \
    const char* op_name = "TypeOp" #dtype;                                   \
    const char* device_name = "FakeDeviceName1";                             \
                                                                             \
    REGISTER_OP(op_name)                                                     \
        .Input("input1: double")                                             \
        .Input("input2: uint8")                                              \
        .Output("output1: uint8")                                            \
        .Attr("T: type");                                                    \
                                                                             \
    TF_KernelBuilder* builder = TF_NewKernelBuilder(                         \
        op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); \
    TF_Status* status = TF_NewStatus();                                      \
    TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::tf_type,      \
                                    status);                                 \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
    TF_RegisterKernelBuilder(node_name, builder, status);                    \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
                                                                             \
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);          \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
    KernelList list;                                                         \
    list.ParseFromArray(buf->data, buf->length);                             \
    KernelList expected_proto;                                               \
    protobuf::TextFormat::ParseFromString(ExpectedString(#dtype),            \
                                          &expected_proto);                  \
    ASSERT_EQ(expected_proto.DebugString(), list.DebugString());             \
                                                                             \
    TF_DeleteBuffer(buf);                                                    \
    TF_DeleteStatus(status);                                                 \
    TF_DeleteKernelBuilder(builder);                                         \
    ASSERT_TRUE(delete_called);                                              \
  }

// Instantiate the constraint test once per C-API dtype with a DT_* analog.
TEST_KERNEL_TYPE_CONSTRAINT(TF_HALF, DT_HALF);
TEST_KERNEL_TYPE_CONSTRAINT(TF_BFLOAT16, DT_BFLOAT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_FLOAT, DT_FLOAT);
TEST_KERNEL_TYPE_CONSTRAINT(TF_DOUBLE, DT_DOUBLE);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT64, DT_UINT64);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT32, DT_UINT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT16, DT_UINT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT8, DT_UINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_INT8, DT_INT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_INT32, DT_INT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX64, DT_COMPLEX64);
TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX128, DT_COMPLEX128);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT8, DT_QINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT8, DT_QUINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT32, DT_QINT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT16, DT_QINT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT16, DT_QUINT16);
911 
// Registers a kernel whose "input2" and "output1" are pinned to host memory
// and verifies TF_IsHostMemoryInput/TF_IsHostMemoryOutput (including
// out-of-range index handling), plus the host_memory_arg entries recorded in
// the registered KernelDef.
TEST(TestKernel, TestHostMemory) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "HostMemoryOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Output("output2: uint8")
      .Attr("T: type");

  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    MyComputeFunc(kernel, ctx);

    TF_Status* status = TF_NewStatus();

    // Only input2 (index 1) and output1 (index 0) were marked host memory
    // via TF_KernelBuilder_HostMemory below.
    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(false, TF_IsHostMemoryInput(ctx, 0, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(true, TF_IsHostMemoryInput(ctx, 1, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(true, TF_IsHostMemoryOutput(ctx, 0, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(false, TF_IsHostMemoryOutput(ctx, 1, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    // Out-of-range indices must set TF_OUT_OF_RANGE rather than crash.
    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryInput(ctx, -1, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryInput(ctx, 2, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryOutput(ctx, -1, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryOutput(ctx, 2, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_DeleteStatus(status);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, my_compute_func, &MyDeleteFunc);
  TF_KernelBuilder_HostMemory(builder, "input2");
  TF_KernelBuilder_HostMemory(builder, "output1");
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder(node_name, builder, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  // The registered KernelDef must record both host_memory_arg entries.
  TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  KernelList list;
  list.ParseFromArray(buf->data, buf->length);
  KernelList expected_proto;
  protobuf::TextFormat::ParseFromString(
      R"str(kernel {
  op: "HostMemoryOp"
  device_type: "FakeDeviceName1"
  host_memory_arg: "input2"
  host_memory_arg: "output1"
}
)str",
      &expected_proto);
  ASSERT_EQ(list.DebugString(), expected_proto.DebugString());

  TF_DeleteBuffer(buf);
  TF_DeleteStatus(status);
  TF_DeleteKernelBuilder(builder);
  ASSERT_TRUE(delete_called);
}
993 
// Fixture that registers a C-API kernel and initializes it through
// OpsTestBase so RunOpKernel() can execute it. Targets GPU when built with
// CUDA/ROCm, otherwise CPU.
class DeviceKernelOpTest : public OpsTestBase {
 protected:
  // Registers `compute_func` under `node_name` as the kernel for `op_name`
  // on device_name_, then finalizes a NodeDef and initializes the op.
  void SetupOp(const char* op_name, const char* node_name,
               void (*compute_func)(void*, TF_OpKernelContext*)) {
    TF_KernelBuilder* builder = TF_NewKernelBuilder(
        op_name, device_name_, nullptr, compute_func, nullptr);
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    // GPU builds need a real device instance for OpsTestBase to run on.
    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
#endif
    TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  // Device type all tests in this fixture register against.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  const char* device_name_ = tensorflow::DEVICE_GPU;
#else
  const char* device_name_ = tensorflow::DEVICE_CPU;
#endif
};
1020 
// Validates that the tensor has shape and type corresponding to
// dims and dtype. (Defined at the bottom of this file.)
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                     TF_DataType dtype);

// Copies data of length tensor_size_bytes from values to tensor: a
// host-to-device transfer on GPU builds, a plain memcpy on CPU builds.
// (Defined at the bottom of this file.)
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx);

REGISTER_OP("StreamOp").Output("output1: float");
1032 
TEST_F(DeviceKernelOpTest, TestStream) {
  auto compute_fn = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* status = TF_NewStatus();
    SP_Stream stream = TF_GetStream(ctx, status);
    // TF_GetStream only yields a stream for pluggable devices; for the
    // devices used here it reports an error and a null stream. More test
    // cases will be added when the pluggable device mechanism is supported.
    EXPECT_NE(TF_OK, TF_GetCode(status));
    EXPECT_EQ(stream, nullptr);
    TF_DeleteStatus(status);
  };

  SetupOp("StreamOp", "StreamOp", compute_fn);
  TF_ASSERT_OK(RunOpKernel());
}
1047 
1048 REGISTER_OP("AllocateOutputOp1").Output("output1: float");
1049 
TEST_F(DeviceKernelOpTest,TestAllocateOutputSizeOne)1050 TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
1051   auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1052     // Allocate output
1053     TF_Status* s = TF_NewStatus();
1054     int64_t dim = 1;
1055     size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
1056     TF_Tensor* output = TF_AllocateOutput(
1057         /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
1058         /*num_dims=*/1, /*len=*/tensor_size_bytes, s);
1059     validate_tensor(output, &dim, 1, TF_FLOAT);
1060 
1061     // Set output to 3
1062     float values[1] = {3.0f};
1063     set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
1064     TF_DeleteStatus(s);
1065     TF_DeleteTensor(output);
1066   };
1067 
1068   SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func);
1069 
1070   TF_ASSERT_OK(RunOpKernel());
1071   Tensor* output = GetOutput(0);
1072   EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
1073             output->DebugString(100));
1074 }
1075 
1076 REGISTER_OP("AllocateOutputOp0").Output("output1: float");
1077 
TEST_F(DeviceKernelOpTest,TestAllocateEmptyOutput)1078 TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
1079   auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1080     TF_Status* s = TF_NewStatus();
1081     // Allocate empty output
1082     int64_t dim = 0;
1083     TF_Tensor* output = TF_AllocateOutput(
1084         /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
1085         /*num_dims=*/1, /*len=*/0, s);
1086     EXPECT_EQ(TF_OK, TF_GetCode(s));
1087     validate_tensor(output, &dim, 1, TF_FLOAT);
1088     TF_DeleteStatus(s);
1089     TF_DeleteTensor(output);
1090   };
1091 
1092   SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func);
1093 
1094   TF_ASSERT_OK(RunOpKernel());
1095   Tensor* output = GetOutput(0);
1096   EXPECT_EQ("Tensor<type: float shape: [0] values: >",
1097             output->DebugString(100));
1098 }
1099 
REGISTER_OP("AllocateOutputOp2x3").Output("output1: float");

// TF_AllocateOutput with a rank-2 (2x3) shape: allocate output 0, fill it
// with 1..6, and verify the resulting tensor.
TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate 2x3 output
    int64_t dim[2] = {2, 3};
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
    TF_Tensor* output = TF_AllocateOutput(
        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*len=*/tensor_size_bytes, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set output to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}
1128 
REGISTER_OP("AllocateTempOp1").Output("output1: float");

// TF_AllocateTemp with a one-element shape: allocate a temp tensor, fill it
// with 3.0f, and publish it as output 0 via TF_SetOutput.
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    // Allocate scalar TF_Tensor
    TF_Status* s = TF_NewStatus();
    int64_t dim = 1;
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
    // Allocate in device memory on GPU builds (set_tensor_data then copies
    // host->device); request host memory on CPU builds.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);

    // Set TF_Tensor value to 3
    float values[1] = {3.0f};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
            output->DebugString(100));
}
1165 
REGISTER_OP("AllocateTempOp0").Output("output1: float");

// TF_AllocateTemp with a shape-[0] tensor: allocation succeeds and the
// empty temp tensor can be published as output 0.
TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate empty TF_Tensor
    int64_t dim = 0;
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
    // Device memory on GPU builds, host memory otherwise.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [0] values: >",
            output->DebugString(100));
}
1197 
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");

// TF_AllocateTemp with a rank-2 (2x3) shape: allocate a temp tensor, fill
// it with 1..6, and publish it as output 0 via TF_SetOutput.
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
    // Allocate 2x3 TF_Tensor
    int64_t dim[2] = {2, 3};
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
    // Device memory on GPU builds, host memory otherwise.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set TF_Tensor values to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}
1234 
1235 REGISTER_OP("DoNothingOp")
1236     .Input("input1: float")
1237     .Input("input2: float")
1238     .Attr("NumInput3: int >= 0")
1239     .Input("input3: NumInput3 * float")
1240     .Output("output1: float")
1241     .Attr("SomeDataTypeAttr: type");
1242 
TEST_F(DeviceKernelOpTest,TestGetKernelInfo)1243 TEST_F(DeviceKernelOpTest, TestGetKernelInfo) {
1244   auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1245     TF_Status* s = TF_NewStatus();
1246     int64_t dim[1] = {1};
1247     TF_AllocatorAttributes alloc_attrs;
1248     alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
1249 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1250     alloc_attrs.on_host = 0;
1251 #else
1252     alloc_attrs.on_host = 1;
1253 #endif
1254 
1255     // Test if the C API returns expected strings.
1256     TF_StringView sv = TF_GetOpKernelName(ctx);
1257     EXPECT_STREQ(sv.data, "TestGetKernelInfoNode");
1258 
1259     sv = TF_GetOpKernelRequestedInput(ctx, 0);
1260     EXPECT_STREQ(sv.data, "input1");
1261 
1262     sv = TF_GetOpKernelRequestedInput(ctx, 1);
1263     EXPECT_STREQ(sv.data, "input2");
1264 
1265     TF_InputRange_Args args;
1266     args.status = s;
1267     TF_InputRange(ctx, "input3", &args);
1268     EXPECT_EQ(TF_OK, TF_GetCode(s));
1269     EXPECT_EQ(args.start, 2);
1270     EXPECT_EQ(args.stop, 5);
1271 
1272     TF_Tensor* output = TF_AllocateTemp(
1273         /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
1274         /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
1275     TF_SetOutput(ctx, 0, output, s);
1276     TF_DeleteStatus(s);
1277     TF_DeleteTensor(output);
1278   };
1279 
1280   const char* node_name = "TestGetKernelInfoNode";
1281   const char* op_name = "DoNothingOp";
1282   const char* device_name = "FakeDeviceName";
1283   TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
1284                                                   my_compute_func, nullptr);
1285 
1286   TF_Status* status = TF_NewStatus();
1287   TF_RegisterKernelBuilder(node_name, builder, status);
1288   EXPECT_EQ(TF_OK, TF_GetCode(status));
1289   TF_DeleteStatus(status);
1290 
1291   {
1292     OpKernelContext::Params p;
1293     DummyDevice dummy_device(nullptr);
1294     p.device = &dummy_device;
1295     AllocatorAttributes alloc_attrs;
1296     p.output_attr_array = &alloc_attrs;
1297 
1298     gtl::InlinedVector<TensorValue, 4> inputs;
1299     Tensor t0(1.0f);
1300     Tensor t1(2.0f);
1301     Tensor t2_0(2.0f);
1302     Tensor t2_1(2.1f);
1303     Tensor t2_2(2.2f);
1304     inputs.emplace_back(&t0);
1305     inputs.emplace_back(&t1);
1306     inputs.emplace_back(&t2_0);
1307     inputs.emplace_back(&t2_1);
1308     inputs.emplace_back(&t2_2);
1309 
1310     Status status;
1311     std::unique_ptr<OpKernel> kernel =
1312         GetFakeKernel2(device_name, op_name, node_name, &status);
1313     TF_EXPECT_OK(status);
1314     ASSERT_NE(nullptr, kernel.get());
1315 
1316     p.op_kernel = kernel.get();
1317     p.inputs = inputs;
1318     OpKernelContext ctx(&p);
1319     kernel->Compute(&ctx);
1320   }
1321 }
1322 
// TF_ForwardInputOrAllocateOutput should reuse input 0's buffer for output 0
// when the input is forwardable, and report which input was forwarded.
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
  const char* node_name = "TestForwardInputOrAllocateOutputKernel";
  const char* op_name = "BazOp";
  const char* device_name = "FakeDeviceName";

  REGISTER_OP(op_name)
      .Input("input1: float")
      .Input("input2: float")
      .Output("output1: float")
      .Attr("SomeDataTypeAttr: type");

  // A kernel whose Compute function that forwards a scalar input to output
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    int candidate_input_indices[1] = {0};
    int forwarded_input;
    int64_t output_dims[1] = {};
    TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
        /*context=*/ctx, candidate_input_indices,
        /*num_candidate_input_indices=*/1,
        /*output_index=*/0, output_dims, /*output_num_dims=*/0,
        &forwarded_input, /*status=*/s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    // Input 0 (the only candidate) should have been forwarded, producing a
    // scalar float output.
    EXPECT_EQ(forwarded_input, 0);
    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
    EXPECT_EQ(0, TF_NumDims(output));
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    // Build an OpKernelContext by hand and run the registered kernel on it.
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    AllocatorAttributes alloc_attrs;
    p.output_attr_array = &alloc_attrs;

    Tensor t(123.0f);

    gtl::InlinedVector<TensorValue, 4> inputs;
    // GetFakeKernel requires a NodeDef with two inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);
    // The forwarded output must carry input 0's value.
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
  }
}
1390 
validate_tensor(TF_Tensor * tensor,int64_t * dims,int64_t num_dims,TF_DataType dtype)1391 void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
1392                      TF_DataType dtype) {
1393   EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
1394   EXPECT_EQ(num_dims, TF_NumDims(tensor));
1395   for (int i = 0; i < num_dims; ++i) {
1396     EXPECT_EQ(dims[i], TF_Dim(tensor, i));
1397   }
1398 }
1399 
// Copies `tensor_size_bytes` bytes from `values` into `tensor`'s buffer.
// On GPU builds the copy goes through the context's Eigen GPU device as a
// host-to-device transfer; on CPU builds it is a plain memcpy.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx) {
  T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
  cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
                                                tensor_size_bytes);
#else
  memcpy(data, values, tensor_size_bytes);
#endif
}
1412 }  // namespace tensorflow
1413