1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
16 #define EIGEN_USE_GPU
17 #endif
18
19 #include "tensorflow/c/kernels.h"
20
21 #include <stddef.h>
22 #include <stdint.h>
23 #include <string.h>
24
25 #include <memory>
26 #include <string>
27 #include <utility>
28
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_format.h"
31 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32 #include "tensorflow/c/c_api.h"
33 #include "tensorflow/c/tf_datatype.h"
34 #include "tensorflow/c/tf_status.h"
35 #include "tensorflow/c/tf_tensor.h"
36 #include "tensorflow/core/common_runtime/device.h"
37 #include "tensorflow/core/common_runtime/device_factory.h"
38 #include "tensorflow/core/framework/allocator.h"
39 #include "tensorflow/core/framework/attr_value.pb.h"
40 #include "tensorflow/core/framework/device_base.h"
41 #include "tensorflow/core/framework/kernel_def.pb.h"
42 #include "tensorflow/core/framework/node_def.pb.h"
43 #include "tensorflow/core/framework/node_def_builder.h"
44 #include "tensorflow/core/framework/op.h"
45 #include "tensorflow/core/framework/op_kernel.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/tensor_types.h"
48 #include "tensorflow/core/framework/types.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/kernels/ops_testutil.h"
51 #include "tensorflow/core/lib/core/status_test_util.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/status.h"
54 #include "tensorflow/core/platform/test.h"
55 #include "tensorflow/core/platform/types.h"
56
// Per-kernel state shared between the C callbacks below. Each callback flips
// a flag so the tests can verify that calls actually crossed the C/C++
// boundary (create -> compute -> delete).
struct MyCustomKernel {
  bool created;
  bool compute_called;
};

// Set by MyDeleteFunc; tests assert on it to confirm the delete callback ran.
// NOTE(review): shared across TESTs and never reset, so later tests observe
// the value left by earlier ones — confirm ordering assumptions if reused.
static bool delete_called = false;
63
MyCreateFunc(TF_OpKernelConstruction * ctx)64 static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
65 struct MyCustomKernel* s = new struct MyCustomKernel;
66 s->created = true;
67 s->compute_called = false;
68
69 // Exercise attribute reads.
70 TF_DataType type;
71 TF_Status* status = TF_NewStatus();
72 TF_OpKernelConstruction_GetAttrType(ctx, "SomeDataTypeAttr", &type, status);
73 EXPECT_EQ(TF_OK, TF_GetCode(status));
74 EXPECT_EQ(TF_FLOAT, type);
75 TF_DeleteStatus(status);
76
77 // Exercise kernel NodeDef name read
78 TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
79 std::string node_name = "SomeNodeName";
80 std::string candidate_node_name =
81 std::string(name_string_view.data, name_string_view.len);
82 EXPECT_EQ(node_name, candidate_node_name);
83 return s;
84 }
85
MyComputeFunc(void * kernel,TF_OpKernelContext * ctx)86 static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
87 struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
88 s->compute_called = true;
89 if (ctx != nullptr) {
90 EXPECT_EQ(43, TF_StepId(ctx));
91 }
92 }
93
MyDeleteFunc(void * kernel)94 static void MyDeleteFunc(void* kernel) {
95 struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
96 EXPECT_TRUE(s->created);
97 EXPECT_TRUE(s->compute_called);
98 delete_called = true;
99 delete s;
100 }
101
102 namespace tensorflow {
103 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
104
GetFakeKernel(const char * device_name,const char * op_name,const char * node_name,Status * status)105 static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
106 const char* op_name,
107 const char* node_name,
108 Status* status) {
109 NodeDef def;
110 def.set_op(op_name);
111 def.set_name(node_name);
112 def.set_device(device_name);
113 def.add_input("input1");
114 def.add_input("input2");
115
116 AttrValue v;
117 v.set_type(DataType::DT_FLOAT);
118 (*def.mutable_attr())["SomeDataTypeAttr"] = v;
119
120 return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
121 status);
122 }
123
GetFakeKernel2(const char * device_name,const char * op_name,const char * node_name,Status * status)124 static std::unique_ptr<OpKernel> GetFakeKernel2(const char* device_name,
125 const char* op_name,
126 const char* node_name,
127 Status* status) {
128 NodeDef def;
129 def.set_op(op_name);
130 def.set_name(node_name);
131 def.set_device(device_name);
132 def.add_input("input1");
133 def.add_input("input2");
134 def.add_input("input3");
135 def.add_input("input3");
136 def.add_input("input3");
137
138 AttrValue v0;
139 v0.set_type(DataType::DT_INT32);
140 v0.set_i(3);
141 (*def.mutable_attr())["NumInput3"] = v0;
142 AttrValue v1;
143 v1.set_type(DataType::DT_FLOAT);
144 (*def.mutable_attr())["SomeDataTypeAttr"] = v1;
145
146 return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
147 status);
148 }
149
150 // Tests registration of a single C kernel and checks that calls through the
151 // C/C++ boundary are being made.
// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made: the kernel shows up in the registry, can be
// instantiated, and its create/compute/delete callbacks all fire.
TEST(TestKernel, TestRegisterKernelBuilder) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "FooOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  {
    // Register and confirm the kernel is listed for the op.
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    KernelList list;
    list.ParseFromArray(buf->data, buf->length);
    ASSERT_EQ(1, list.kernel_size());
    ASSERT_EQ(device_name, list.kernel(0).device_type());
    TF_DeleteBuffer(buf);
    TF_DeleteStatus(status);
  }

  {
    // Instantiate and run the kernel; destroying it triggers MyDeleteFunc.
    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  ASSERT_TRUE(delete_called);
}
191
// Same lifecycle test as above, but registers through
// TF_RegisterKernelBuilderWithKernelDef using an explicitly serialized
// KernelDef instead of letting the builder synthesize one.
TEST(TestKernel, TF_RegisterKernelBuilderWithKernelDef) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "FooOp1";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  KernelDef kernel_def;
  kernel_def.set_op(op_name);
  kernel_def.set_device_type(device_name);
  std::string kernel_def_str = kernel_def.SerializePartialAsString();

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilderWithKernelDef(kernel_def_str.data(), node_name,
                                          builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    KernelList list;
    list.ParseFromArray(buf->data, buf->length);
    ASSERT_EQ(1, list.kernel_size());
    ASSERT_EQ(device_name, list.kernel(0).device_type());
    TF_DeleteBuffer(buf);
    TF_DeleteStatus(status);
  }

  {
    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  ASSERT_TRUE(delete_called);
}
237
// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases.
// Registers two ops, each with a single attribute called 'Attr'.
// The attribute in one op will have a type 'type', the other
// will have list(type).
#define ATTR_TEST_REGISTER_OP(name, type)                     \
  REGISTER_OP("TestKernelAttr" #name)                         \
      .Attr("Attr: " #type)                                   \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape); \
  REGISTER_OP("TestKernelAttr" #name "List")                  \
      .Attr("Attr: list(" #type ")")                          \
      .SetShapeFn(tensorflow::shape_inference::UnknownShape)
// Instantiate scalar + list op pairs for every attribute kind exercised by
// the TestKernelAttr fixtures below (e.g. TestKernelAttrString and
// TestKernelAttrStringList).
ATTR_TEST_REGISTER_OP(String, string);
ATTR_TEST_REGISTER_OP(Int, int);
ATTR_TEST_REGISTER_OP(Float, float);
ATTR_TEST_REGISTER_OP(Bool, bool);
ATTR_TEST_REGISTER_OP(Type, type);
ATTR_TEST_REGISTER_OP(Tensor, tensor);
#undef ATTR_TEST_REGISTER_OP
256
// Helper macros for the TF_OpKernelConstruction_GetAttr* tests.
// Expects `ctx` (TF_OpKernelConstruction*) and `status` (TF_Status*) to be in
// scope at the expansion site. A size of -1 means "not applicable" (e.g. the
// list size of a scalar attribute).
#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \
  do {                                                                     \
    int32_t list_size, total_size;                                         \
    TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size,        \
                                        &total_size, status);              \
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);            \
    EXPECT_EQ(expected_list_size, list_size);                              \
    EXPECT_EQ(expected_total_size, total_size);                            \
  } while (0)
267
268 typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*);
269 class TestKernelAttr : public ::testing::Test {
270 public:
TestKernelAttr()271 TestKernelAttr() {}
~TestKernelAttr()272 ~TestKernelAttr() override {}
273
GetFakeKernelWithAttr(const char * op_name,AttrValue v,Status * status)274 std::unique_ptr<OpKernel> GetFakeKernelWithAttr(const char* op_name,
275 AttrValue v, Status* status) {
276 NodeDef def;
277 def.set_op(op_name);
278 def.set_name("FakeNode");
279 def.set_device("FakeDevice");
280 (*def.mutable_attr())["Attr"] = v;
281 return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1,
282 status);
283 }
284
CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,const char * op_name,AttrValue & v)285 void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr,
286 const char* op_name, AttrValue& v) {
287 TF_KernelBuilder* builder = TF_NewKernelBuilder(
288 op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc);
289 {
290 TF_Status* status = TF_NewStatus();
291 TF_RegisterKernelBuilder("FakeNode", builder, status);
292 EXPECT_EQ(TF_OK, TF_GetCode(status));
293 TF_DeleteStatus(status);
294 }
295 Status status;
296 std::unique_ptr<OpKernel> kernel =
297 GetFakeKernelWithAttr(op_name, v, &status);
298 TF_EXPECT_OK(status);
299 ASSERT_NE(nullptr, kernel.get());
300 kernel->Compute(nullptr);
301
302 ASSERT_TRUE(delete_called);
303 }
304 };
305
// Verifies TF_OpKernelConstruction_GetNodeDef: the serialized NodeDef handed
// back through the C API round-trips to the op/name/device/attr values the
// test supplied.
TEST_F(TestKernelAttr, GetNodeDef) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    TF_Status* status = TF_NewStatus();
    TF_Buffer* node_def_buf = TF_OpKernelConstruction_GetNodeDef(ctx, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    NodeDef node_def;
    node_def.ParseFromArray(node_def_buf->data, node_def_buf->length);
    EXPECT_EQ(node_def.op(), "TestKernelAttrGetNodeDef");
    EXPECT_EQ(node_def.name(), "FakeNode");
    EXPECT_EQ(node_def.device(), "FakeDevice");
    EXPECT_EQ(node_def.attr_size(), 1);
    const ::tensorflow::AttrValue& value = node_def.attr().at("Attr");
    EXPECT_TRUE(value.value_case() == ::tensorflow::AttrValue::ValueCase::kI);
    EXPECT_EQ(value.i(), 1234);
    TF_DeleteBuffer(node_def_buf);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  REGISTER_OP("TestKernelAttrGetNodeDef")
      .Attr("Attr: int")
      .SetShapeFn(tensorflow::shape_inference::UnknownShape);

  AttrValue v;
  v.set_i(1234);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrGetNodeDef", v);
}
337
// Reads a scalar string attribute via TF_OpKernelConstruction_GetAttrString.
// The string is returned unterminated into a caller-provided buffer, so the
// comparison uses the explicit length (5).
TEST_F(TestKernelAttr, String) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    std::unique_ptr<char[]> val(new char[5]);
    TF_Status* status = TF_NewStatus();
    // Scalar attribute: list size is -1, total size is the string length.
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ 5);
    TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(),
                                          /*max_length*/ 5, status);

    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ("bunny", string(static_cast<const char*>(val.get()), 5));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_s("bunny");
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrString", v);
}
361
// Reads a list(string) attribute: the C API fills a caller-provided pointer
// array, a per-entry length array, and one contiguous storage buffer.
TEST_F(TestKernelAttr, StringList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    std::vector<string> list = {"bugs", "bunny", "duck"};
    int list_total_size = 0;
    for (const auto& s : list) {
      list_total_size += s.size();
    }

    TF_Status* status = TF_NewStatus();
    std::unique_ptr<char*[]> values(new char*[list.size()]);
    std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
    std::unique_ptr<char[]> storage(new char[list_total_size]);
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(),
                   /*expected_total_size*/ list_total_size);
    TF_OpKernelConstruction_GetAttrStringList(
        ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(),
        list_total_size, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    // Each values[i] points into `storage`; strings are not NUL-terminated,
    // so compare with the reported lengths.
    for (size_t i = 0; i < list.size(); ++i) {
      EXPECT_EQ(list[i].size(), lens[i]) << i;
      EXPECT_EQ(list[i], string(static_cast<const char*>(values[i]), lens[i]))
          << i;
    }
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  std::string attr_in[] = {"bugs", "bunny", "duck"};
  SetAttrValue(gtl::ArraySlice<std::string>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrStringList", v);
}
399
// Reads a scalar tensor attribute and checks that the TF_Tensor returned by
// the C API converts back to a Tensor equal (data, shape, dtype) to the
// proto the test supplied.
TEST_F(TestKernelAttr, Tensor) {
  struct TensorProtoHelpers {
   public:
    // 2x3 int32 tensor [1..6], used both as the attribute value and as the
    // expected result inside the create callback.
    static ::tensorflow::TensorProto GenerateTensorProto() {
      ::tensorflow::TensorProto tensor_proto;
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(3);
      tensor_proto.set_dtype(DT_INT32);
      tensor_proto.add_int_val(1);
      tensor_proto.add_int_val(2);
      tensor_proto.add_int_val(3);
      tensor_proto.add_int_val(4);
      tensor_proto.add_int_val(5);
      tensor_proto.add_int_val(6);
      return tensor_proto;
    }
  };

  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    TF_Tensor* val;
    TF_Status* status = TF_NewStatus();
    // Tensor attributes report no meaningful sizes.
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrTensor(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    ::tensorflow::Tensor expected_tensor;
    EXPECT_TRUE(
        expected_tensor.FromProto(TensorProtoHelpers::GenerateTensorProto()));

    ::tensorflow::Tensor actual_tensor;
    EXPECT_TRUE(TF_TensorToTensor(val, &actual_tensor).ok());

    EXPECT_EQ(actual_tensor.tensor_data(), expected_tensor.tensor_data());
    EXPECT_EQ(actual_tensor.shape(), expected_tensor.shape());
    EXPECT_EQ(actual_tensor.dtype(), expected_tensor.dtype());

    TF_DeleteStatus(status);
    TF_DeleteTensor(val);
    return static_cast<void*>(s);
  };

  AttrValue v;
  ::tensorflow::TensorProto* tensor_proto = v.mutable_tensor();
  *tensor_proto = TensorProtoHelpers::GenerateTensorProto();

  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTensor", v);
}
452
// Reads a list(tensor) attribute with two tensors of different dtypes and
// shapes, and verifies each element round-trips through the C API intact.
TEST_F(TestKernelAttr, TensorList) {
  struct TensorProtoHelpers {
   public:
    // 2x2 int32 tensor [1..4].
    static ::tensorflow::TensorProto GenerateTensorProto1() {
      ::tensorflow::TensorProto tensor_proto;
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
      tensor_proto.set_dtype(DT_INT32);
      tensor_proto.add_int_val(1);
      tensor_proto.add_int_val(2);
      tensor_proto.add_int_val(3);
      tensor_proto.add_int_val(4);
      return tensor_proto;
    }

    // 2x3 float tensor [5..10].
    static ::tensorflow::TensorProto GenerateTensorProto2() {
      ::tensorflow::TensorProto tensor_proto;
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
      tensor_proto.mutable_tensor_shape()->add_dim()->set_size(3);
      tensor_proto.set_dtype(DT_FLOAT);
      tensor_proto.add_float_val(5.0f);
      tensor_proto.add_float_val(6.0f);
      tensor_proto.add_float_val(7.0f);
      tensor_proto.add_float_val(8.0f);
      tensor_proto.add_float_val(9.0f);
      tensor_proto.add_float_val(10.0f);
      return tensor_proto;
    }
  };

  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const size_t list_size = 2;
    TF_Tensor* values[list_size];

    TF_Status* status = TF_NewStatus();
    // list(tensor): list size is known, total byte size is not reported.
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrTensorList(ctx, "Attr", values, list_size,
                                              status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

    ::tensorflow::Tensor expected_tensor1;
    EXPECT_TRUE(
        expected_tensor1.FromProto(TensorProtoHelpers::GenerateTensorProto1()));

    ::tensorflow::Tensor actual_tensor1;
    EXPECT_TRUE(TF_TensorToTensor(values[0], &actual_tensor1).ok());

    EXPECT_EQ(actual_tensor1.tensor_data(), expected_tensor1.tensor_data());
    EXPECT_EQ(actual_tensor1.shape(), expected_tensor1.shape());
    EXPECT_EQ(actual_tensor1.dtype(), expected_tensor1.dtype());

    ::tensorflow::Tensor expected_tensor2;
    EXPECT_TRUE(
        expected_tensor2.FromProto(TensorProtoHelpers::GenerateTensorProto2()));

    ::tensorflow::Tensor actual_tensor2;
    EXPECT_TRUE(TF_TensorToTensor(values[1], &actual_tensor2).ok());

    EXPECT_EQ(actual_tensor2.tensor_data(), expected_tensor2.tensor_data());
    EXPECT_EQ(actual_tensor2.shape(), expected_tensor2.shape());
    EXPECT_EQ(actual_tensor2.dtype(), expected_tensor2.dtype());

    TF_DeleteStatus(status);
    TF_DeleteTensor(values[0]);
    TF_DeleteTensor(values[1]);
    return static_cast<void*>(s);
  };

  AttrValue v;
  ::tensorflow::TensorProto* tensor_proto1 = v.mutable_list()->add_tensor();
  *tensor_proto1 = TensorProtoHelpers::GenerateTensorProto1();

  ::tensorflow::TensorProto* tensor_proto2 = v.mutable_list()->add_tensor();
  *tensor_proto2 = TensorProtoHelpers::GenerateTensorProto2();

  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTensorList", v);
}
535
// Reads a scalar int attribute via TF_OpKernelConstruction_GetAttrInt64.
TEST_F(TestKernelAttr, Int) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    int64_t val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1234, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_i(1234);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrInt", v);
}
557
// Reads a list(int) attribute into a fixed-size array and compares
// element-wise against the values the test set.
TEST_F(TestKernelAttr, IntList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const int64_t list[] = {1, 2, 3, 4};
    const size_t list_size = TF_ARRAYSIZE(list);
    int64_t values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  int64_t attr_in[] = {1, 2, 3, 4};
  SetAttrValue(gtl::ArraySlice<int64_t>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrIntList", v);
}
585
// Reads a scalar float attribute; uses EXPECT_FLOAT_EQ since the value makes
// a double->float round trip through the proto.
TEST_F(TestKernelAttr, Float) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    float val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_FLOAT_EQ(2.718, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_f(2.718);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloat", v);
}
607
// Reads a list(float) attribute and compares element-wise.
TEST_F(TestKernelAttr, FloatList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const float list[] = {1.414, 2.718, 3.1415};
    const size_t list_size = TF_ARRAYSIZE(list);
    float values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size,
                                             status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    // Exact equality is fine here: both sides went through the same
    // float32 storage, so the bit patterns match.
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  float attr_in[] = {1.414, 2.718, 3.1415};
  SetAttrValue(gtl::ArraySlice<float>(attr_in, 3), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloatList", v);
}
635
// Reads a scalar bool attribute; the C API surfaces bools as unsigned char
// (1 == true).
TEST_F(TestKernelAttr, Bool) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    unsigned char val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(1, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_b(true);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBool", v);
}
657
// Reads a list(bool) attribute (unsigned char per element) and compares
// element-wise.
TEST_F(TestKernelAttr, BoolList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const unsigned char list[] = {1, 0, 1, 0};
    const size_t list_size = TF_ARRAYSIZE(list);
    unsigned char values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  bool attr_in[] = {true, false, true, false};
  SetAttrValue(gtl::ArraySlice<bool>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBoolList", v);
}
685
// Reads a scalar type attribute (DT_FLOAT in the proto maps to TF_FLOAT in
// the C API).
TEST_F(TestKernelAttr, Type) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    TF_DataType val;
    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_EQ(TF_FLOAT, val);
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  v.set_type(DT_FLOAT);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrType", v);
}
707
// Reads a list(type) attribute and checks the DT_* -> TF_* mapping holds for
// every element.
TEST_F(TestKernelAttr, TypeList) {
  auto my_create_func = [](TF_OpKernelConstruction* ctx) {
    struct MyCustomKernel* s = new struct MyCustomKernel;
    s->created = true;
    s->compute_called = false;

    const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128};
    const size_t list_size = TF_ARRAYSIZE(list);
    TF_DataType values[list_size];

    TF_Status* status = TF_NewStatus();
    EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size,
                   /*expected_total_size*/ -1);
    TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size,
                                            status);
    EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    EXPECT_TRUE(
        std::equal(std::begin(list), std::end(list), std::begin(values)));
    TF_DeleteStatus(status);
    return static_cast<void*>(s);
  };

  AttrValue v;
  DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128};
  SetAttrValue(gtl::ArraySlice<DataType>(attr_in, 4), &v);
  CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTypeList", v);
}
#undef EXPECT_TF_SIZE
736
737 class DummyDevice : public DeviceBase {
738 public:
DummyDevice(Env * env)739 explicit DummyDevice(Env* env) : DeviceBase(env) {}
GetAllocator(AllocatorAttributes)740 Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
741 return cpu_allocator();
742 }
743 };
744
// Exercises the TF_OpKernelContext input/output surface: counts, TF_GetInput
// (including out-of-range indices), TF_SetOutput, and
// TF_ExpectedOutputDataType (including death on invalid indices).
TEST(TestKernel, TestInputAndOutputCount) {
  const char* node_name = "InputOutputCounterKernel";
  const char* op_name = "BarOp";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Attr("SomeDataTypeAttr: type");

  // Statics so the capture-less lambda below can report back to the test.
  static int num_inputs = 0;
  static int num_outputs = 0;

  // A kernel whose Compute function has a side-effect of updating num_inputs
  // and num_outputs. Various functions on TF_OpKernelContext are also
  // exercised.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    num_inputs = TF_NumInputs(ctx);
    num_outputs = TF_NumOutputs(ctx);

    TF_Tensor* input = nullptr;
    TF_Status* s = TF_NewStatus();
    TF_GetInput(ctx, 0, &input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
    EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
    // Out-of-range indices must fail without touching `input`.
    TF_GetInput(ctx, -1, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
    TF_GetInput(ctx, 3, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    // Copy the input tensor to output.
    TF_SetOutput(ctx, 0, input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));

    TF_SetOutput(ctx, 24, input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));

    EXPECT_EQ(TF_UINT8, TF_ExpectedOutputDataType(ctx, 0));

    // Invalid output indices are CHECK failures, not status errors.
    EXPECT_DEATH({ TF_ExpectedOutputDataType(ctx, 1); },
                 "Check failed: i < cc_ctx->num_outputs");

    EXPECT_DEATH({ TF_ExpectedOutputDataType(ctx, -1); },
                 "Check failed: i >= 0");

    TF_DeleteStatus(s);
    if (input != nullptr) {
      TF_DeleteTensor(input);
    }
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    // Build a minimal OpKernelContext with one real scalar input (123) and
    // one empty slot, then run the kernel against it.
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    p.step_id = 43;  // checked by MyComputeFunc-style kernels via TF_StepId

    Tensor t(tensorflow::uint8(123));

    gtl::InlinedVector<TensorValue, 4> inputs;
    // Simulate 2 inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);

    ASSERT_EQ(2, num_inputs);
    ASSERT_EQ(1, num_outputs);
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<tensorflow::uint8>()());
  }
}
836
// TF_DeleteKernelBuilder must be a no-op on nullptr (mirrors free()/delete
// semantics) rather than crashing.
TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
  TF_DeleteKernelBuilder(nullptr);
}
840
// Returns the expected KernelList text-proto for a type-constrained kernel
// registered on op "TypeOp<type>" with constraint T == <type>. Built with
// plain std::string concatenation so this trivial helper has no formatting-
// library dependency (behavior is identical to the previous
// absl::StrFormat("%s", ...) version).
std::string ExpectedString(const char* type) {
  const std::string t(type);
  return "kernel {\n"
         "  op: \"TypeOp" + t + "\"\n"
         "  device_type: \"FakeDeviceName1\"\n"
         "  constraint {\n"
         "    name: \"T\"\n"
         "    allowed_values {\n"
         "      list {\n"
         "        type: " + t + "\n"
         "      }\n"
         "    }\n"
         "  }\n"
         "}\n";
}
857
// Generates a TEST that registers a kernel with a type constraint
// T == `dtype` via TF_KernelBuilder_TypeConstraint and verifies the
// resulting registry entry matches ExpectedString(#dtype) exactly.
#define TEST_KERNEL_TYPE_CONSTRAINT(tf_type, dtype)                          \
  TEST(TestKernel, TestTypeConstraint##tf_type) {                            \
    const char* node_name = "SomeNodeName";                                  \
    const char* op_name = "TypeOp" #dtype;                                   \
    const char* device_name = "FakeDeviceName1";                             \
                                                                             \
    REGISTER_OP(op_name)                                                     \
        .Input("input1: double")                                             \
        .Input("input2: uint8")                                              \
        .Output("output1: uint8")                                            \
        .Attr("T: type");                                                    \
                                                                             \
    TF_KernelBuilder* builder = TF_NewKernelBuilder(                         \
        op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); \
    TF_Status* status = TF_NewStatus();                                      \
    TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::tf_type,      \
                                    status);                                 \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
    TF_RegisterKernelBuilder(node_name, builder, status);                    \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
                                                                             \
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);          \
    EXPECT_EQ(TF_OK, TF_GetCode(status));                                    \
    KernelList list;                                                         \
    list.ParseFromArray(buf->data, buf->length);                             \
    KernelList expected_proto;                                               \
    protobuf::TextFormat::ParseFromString(ExpectedString(#dtype),            \
                                          &expected_proto);                  \
    ASSERT_EQ(expected_proto.DebugString(), list.DebugString());             \
                                                                             \
    TF_DeleteBuffer(buf);                                                    \
    TF_DeleteStatus(status);                                                 \
    TF_DeleteKernelBuilder(builder);                                         \
    ASSERT_TRUE(delete_called);                                              \
  }

// One constraint test per TF_DataType / DataType pairing.
TEST_KERNEL_TYPE_CONSTRAINT(TF_HALF, DT_HALF);
TEST_KERNEL_TYPE_CONSTRAINT(TF_BFLOAT16, DT_BFLOAT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_FLOAT, DT_FLOAT);
TEST_KERNEL_TYPE_CONSTRAINT(TF_DOUBLE, DT_DOUBLE);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT64, DT_UINT64);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT32, DT_UINT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT16, DT_UINT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT8, DT_UINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_INT8, DT_INT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_INT32, DT_INT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX64, DT_COMPLEX64);
TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX128, DT_COMPLEX128);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT8, DT_QINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT8, DT_QUINT8);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT32, DT_QINT32);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT16, DT_QINT16);
TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT16, DT_QUINT16);
911
// Registers a kernel with "input2" and "output1" pinned to host memory, then
// verifies both the in-compute queries (TF_IsHostMemoryInput/Output,
// including out-of-range index handling) and the host_memory_arg entries in
// the registered KernelDef.
TEST(TestKernel, TestHostMemory) {
  const char* node_name = "SomeNodeName";
  const char* op_name = "HostMemoryOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8")
      .Output("output2: uint8")
      .Attr("T: type");

  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    MyComputeFunc(kernel, ctx);

    TF_Status* status = TF_NewStatus();

    // Only input2 (index 1) and output1 (index 0) were registered as host
    // memory via TF_KernelBuilder_HostMemory below.
    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(false, TF_IsHostMemoryInput(ctx, 0, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(true, TF_IsHostMemoryInput(ctx, 1, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(true, TF_IsHostMemoryOutput(ctx, 0, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    EXPECT_EQ(false, TF_IsHostMemoryOutput(ctx, 1, status));
    EXPECT_EQ(TF_OK, TF_GetCode(status));

    // Indices outside the valid input/output ranges must set TF_OUT_OF_RANGE.
    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryInput(ctx, -1, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryInput(ctx, 2, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryOutput(ctx, -1, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_SetStatus(status, TF_OK, "");
    TF_IsHostMemoryOutput(ctx, 2, status);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(status));

    TF_DeleteStatus(status);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, my_compute_func, &MyDeleteFunc);
  TF_KernelBuilder_HostMemory(builder, "input2");
  TF_KernelBuilder_HostMemory(builder, "output1");
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder(node_name, builder, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  // The registered KernelDef must record both host_memory_arg entries.
  TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  KernelList list;
  list.ParseFromArray(buf->data, buf->length);
  KernelList expected_proto;
  protobuf::TextFormat::ParseFromString(
      R"str(kernel {
  op: "HostMemoryOp"
  device_type: "FakeDeviceName1"
  host_memory_arg: "input2"
  host_memory_arg: "output1"
}
)str",
      &expected_proto);
  // NOTE(review): argument order here is (actual, expected), the reverse of
  // the (expected, actual) convention used in TEST_KERNEL_TYPE_CONSTRAINT.
  ASSERT_EQ(list.DebugString(), expected_proto.DebugString());

  TF_DeleteBuffer(buf);
  TF_DeleteStatus(status);
  TF_DeleteKernelBuilder(builder);
  ASSERT_TRUE(delete_called);
}
993
// Test fixture that registers a C-API kernel (no create/delete callbacks)
// for a given op and initializes a node running it.  On CUDA/ROCm builds the
// kernel runs on a GPU device; otherwise on CPU.
class DeviceKernelOpTest : public OpsTestBase {
 protected:
  // Registers `compute_func` as the kernel for `op_name` on device_name_,
  // then builds a NodeDef for the op and initializes the kernel.
  void SetupOp(const char* op_name, const char* node_name,
               void (*compute_func)(void*, TF_OpKernelContext*)) {
    TF_KernelBuilder* builder = TF_NewKernelBuilder(
        op_name, device_name_, nullptr, compute_func, nullptr);
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    // GPU builds instantiate a real GPU device for the test harness.
    std::unique_ptr<Device> device(
        DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0"));
    OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
#endif
    TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  const char* device_name_ = tensorflow::DEVICE_GPU;
#else
  const char* device_name_ = tensorflow::DEVICE_CPU;
#endif
};
1020
// Validates that `tensor` has shape `dims` (length `num_dims`) and element
// type `dtype`.  Defined after the tests below.
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
                     TF_DataType dtype);

// Copies `tensor_size_bytes` bytes from `values` into `tensor`'s buffer,
// routing through the device on GPU builds.  Defined after the tests below.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx);
1030
1031 REGISTER_OP("StreamOp").Output("output1: float");
1032
TEST_F(DeviceKernelOpTest,TestStream)1033 TEST_F(DeviceKernelOpTest, TestStream) {
1034 auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1035 TF_Status* s = TF_NewStatus();
1036 SP_Stream stream = TF_GetStream(ctx, s);
1037 // Stream is always null if device is not a pluggable device. More test
1038 // cases will be added when pluggable device mechanism is supported.
1039 EXPECT_EQ(stream, nullptr);
1040 EXPECT_NE(TF_OK, TF_GetCode(s));
1041 TF_DeleteStatus(s);
1042 };
1043
1044 SetupOp("StreamOp", "StreamOp", my_compute_func);
1045 TF_ASSERT_OK(RunOpKernel());
1046 }
1047
1048 REGISTER_OP("AllocateOutputOp1").Output("output1: float");
1049
TEST_F(DeviceKernelOpTest,TestAllocateOutputSizeOne)1050 TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
1051 auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1052 // Allocate output
1053 TF_Status* s = TF_NewStatus();
1054 int64_t dim = 1;
1055 size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
1056 TF_Tensor* output = TF_AllocateOutput(
1057 /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
1058 /*num_dims=*/1, /*len=*/tensor_size_bytes, s);
1059 validate_tensor(output, &dim, 1, TF_FLOAT);
1060
1061 // Set output to 3
1062 float values[1] = {3.0f};
1063 set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
1064 TF_DeleteStatus(s);
1065 TF_DeleteTensor(output);
1066 };
1067
1068 SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func);
1069
1070 TF_ASSERT_OK(RunOpKernel());
1071 Tensor* output = GetOutput(0);
1072 EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
1073 output->DebugString(100));
1074 }
1075
1076 REGISTER_OP("AllocateOutputOp0").Output("output1: float");
1077
TEST_F(DeviceKernelOpTest,TestAllocateEmptyOutput)1078 TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
1079 auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1080 TF_Status* s = TF_NewStatus();
1081 // Allocate empty output
1082 int64_t dim = 0;
1083 TF_Tensor* output = TF_AllocateOutput(
1084 /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
1085 /*num_dims=*/1, /*len=*/0, s);
1086 EXPECT_EQ(TF_OK, TF_GetCode(s));
1087 validate_tensor(output, &dim, 1, TF_FLOAT);
1088 TF_DeleteStatus(s);
1089 TF_DeleteTensor(output);
1090 };
1091
1092 SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func);
1093
1094 TF_ASSERT_OK(RunOpKernel());
1095 Tensor* output = GetOutput(0);
1096 EXPECT_EQ("Tensor<type: float shape: [0] values: >",
1097 output->DebugString(100));
1098 }
1099
REGISTER_OP("AllocateOutputOp2x3").Output("output1: float");

// Allocates a 2x3 float output with TF_AllocateOutput, fills it with
// [1..6], and verifies the resulting output tensor's shape and values.
TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    // Allocate 2x3 output
    int64_t dim[2] = {2, 3};
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
    TF_Tensor* output = TF_AllocateOutput(
        /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*len=*/tensor_size_bytes, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set output to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}
1128
REGISTER_OP("AllocateTempOp1").Output("output1: float");

// Allocates a single-element float temp tensor with TF_AllocateTemp, writes
// the value 3, and forwards it to output 0 via TF_SetOutput.
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    // Allocate scalar TF_Tensor
    TF_Status* s = TF_NewStatus();
    int64_t dim = 1;
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    // Device memory on GPU builds; host memory otherwise.
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, &dim, 1, TF_FLOAT);

    // Set TF_Tensor value to 3
    float values[1] = {3.0f};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
            output->DebugString(100));
}
1165
1166 REGISTER_OP("AllocateTempOp0").Output("output1: float");
1167
TEST_F(DeviceKernelOpTest,TestAllocateTempEmpty)1168 TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
1169 auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
1170 TF_Status* s = TF_NewStatus();
1171 // Allocate empty TF_Tensor
1172 int64_t dim = 0;
1173 TF_AllocatorAttributes alloc_attrs;
1174 alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
1175 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1176 alloc_attrs.on_host = 0;
1177 #else
1178 alloc_attrs.on_host = 1;
1179 #endif
1180 TF_Tensor* output = TF_AllocateTemp(
1181 /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
1182 /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
1183 EXPECT_EQ(TF_OK, TF_GetCode(s));
1184 validate_tensor(output, &dim, 1, TF_FLOAT);
1185 TF_SetOutput(ctx, 0, output, s);
1186 TF_DeleteStatus(s);
1187 TF_DeleteTensor(output);
1188 };
1189
1190 SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);
1191
1192 TF_ASSERT_OK(RunOpKernel());
1193 Tensor* output = GetOutput(0);
1194 EXPECT_EQ("Tensor<type: float shape: [0] values: >",
1195 output->DebugString(100));
1196 }
1197
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");

// Allocates a 2x3 float temp tensor with TF_AllocateTemp, writes [1..6],
// and forwards it to output 0 via TF_SetOutput.
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
    // Allocate 2x3 TF_Tensor
    int64_t dim[2] = {2, 3};
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    // Device memory on GPU builds; host memory otherwise.
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    validate_tensor(output, dim, 2, TF_FLOAT);

    // Set TF_Tensor values to [1 2 3 4 5 6]
    float values[6] = {1, 2, 3, 4, 5, 6};
    set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);

  TF_ASSERT_OK(RunOpKernel());
  Tensor* output = GetOutput(0);
  EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
            output->DebugString(100));
}
1234
// Op with a variadic third input (NumInput3 floats), used below to exercise
// the kernel-introspection C APIs.
REGISTER_OP("DoNothingOp")
    .Input("input1: float")
    .Input("input2: float")
    .Attr("NumInput3: int >= 0")
    .Input("input3: NumInput3 * float")
    .Output("output1: float")
    .Attr("SomeDataTypeAttr: type");

// Exercises TF_GetOpKernelName, TF_GetOpKernelRequestedInput and
// TF_InputRange from inside a compute callback, driving the kernel through a
// manually constructed OpKernelContext.
TEST_F(DeviceKernelOpTest, TestGetKernelInfo) {
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    int64_t dim[1] = {1};
    TF_AllocatorAttributes alloc_attrs;
    alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    alloc_attrs.on_host = 0;
#else
    alloc_attrs.on_host = 1;
#endif

    // Test if the C API returns expected strings.
    TF_StringView sv = TF_GetOpKernelName(ctx);
    EXPECT_STREQ(sv.data, "TestGetKernelInfoNode");

    sv = TF_GetOpKernelRequestedInput(ctx, 0);
    EXPECT_STREQ(sv.data, "input1");

    sv = TF_GetOpKernelRequestedInput(ctx, 1);
    EXPECT_STREQ(sv.data, "input2");

    // The variadic "input3" spans indices [2, 5): three tensors (t2_0..t2_2)
    // are supplied for it below, after the two scalar inputs.
    TF_InputRange_Args args;
    args.status = s;
    TF_InputRange(ctx, "input3", &args);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    EXPECT_EQ(args.start, 2);
    EXPECT_EQ(args.stop, 5);

    // NOTE(review): `s` is not checked after TF_AllocateTemp here, unlike
    // the dedicated AllocateTemp tests above.
    TF_Tensor* output = TF_AllocateTemp(
        /*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
        /*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
    TF_SetOutput(ctx, 0, output, s);
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  const char* node_name = "TestGetKernelInfoNode";
  const char* op_name = "DoNothingOp";
  const char* device_name = "FakeDeviceName";
  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder(node_name, builder, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_DeleteStatus(status);

  {
    // Build an OpKernelContext by hand: two scalar inputs plus three tensors
    // for the variadic input3 list.
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    AllocatorAttributes alloc_attrs;
    p.output_attr_array = &alloc_attrs;

    gtl::InlinedVector<TensorValue, 4> inputs;
    Tensor t0(1.0f);
    Tensor t1(2.0f);
    Tensor t2_0(2.0f);
    Tensor t2_1(2.1f);
    Tensor t2_2(2.2f);
    inputs.emplace_back(&t0);
    inputs.emplace_back(&t1);
    inputs.emplace_back(&t2_0);
    inputs.emplace_back(&t2_1);
    inputs.emplace_back(&t2_2);

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel2(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    p.inputs = inputs;
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);
  }
}
1322
// Exercises TF_ForwardInputOrAllocateOutput: with a forwardable scalar input
// at index 0, the output must reuse that input's buffer (forwarded index 0)
// and keep its value (123).
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
  const char* node_name = "TestForwardInputOrAllocateOutputKernel";
  const char* op_name = "BazOp";
  const char* device_name = "FakeDeviceName";

  REGISTER_OP(op_name)
      .Input("input1: float")
      .Input("input2: float")
      .Output("output1: float")
      .Attr("SomeDataTypeAttr: type");

  // A kernel whose Compute function forwards a scalar input to the output.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    TF_Status* s = TF_NewStatus();
    int candidate_input_indices[1] = {0};
    int forwarded_input;
    int64_t output_dims[1] = {};
    TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
        /*context=*/ctx, candidate_input_indices,
        /*num_candidate_input_indices=*/1,
        /*output_index=*/0, output_dims, /*output_num_dims=*/0,
        &forwarded_input, /*status=*/s);
    EXPECT_EQ(TF_OK, TF_GetCode(s));
    // Input 0 was the only forwarding candidate, so it must be the one used.
    EXPECT_EQ(forwarded_input, 0);
    EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
    EXPECT_EQ(0, TF_NumDims(output));
    TF_DeleteStatus(s);
    TF_DeleteTensor(output);
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(node_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    AllocatorAttributes alloc_attrs;
    p.output_attr_array = &alloc_attrs;

    Tensor t(123.0f);

    gtl::InlinedVector<TensorValue, 4> inputs;
    // GetFakeKernel requires a NodeDef with two inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, node_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);
    // The forwarded output must still hold the input's value.
    ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
  }
}
1390
validate_tensor(TF_Tensor * tensor,int64_t * dims,int64_t num_dims,TF_DataType dtype)1391 void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
1392 TF_DataType dtype) {
1393 EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
1394 EXPECT_EQ(num_dims, TF_NumDims(tensor));
1395 for (int i = 0; i < num_dims; ++i) {
1396 EXPECT_EQ(dims[i], TF_Dim(tensor, i));
1397 }
1398 }
1399
// Copies `tensor_size_bytes` bytes from `values` into `tensor`'s backing
// buffer.  On CUDA/ROCm builds the destination is device memory, so the copy
// goes through the context's Eigen GPU device; otherwise a plain memcpy.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
                     TF_OpKernelContext* ctx) {
  T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // The reinterpret_cast relies on TF_OpKernelContext wrapping the C++
  // OpKernelContext; this recovers it to reach the GPU device for a
  // host-to-device copy.
  OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
  cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
                                                tensor_size_bytes);
#else
  memcpy(data, values, tensor_size_bytes);
#endif
}
1412 } // namespace tensorflow
1413