/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tfrt/cpu/core_runtime/null_op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime

namespace tfrt {
namespace tf {
namespace {

using ::tensorflow::AbstractTensorHandle;
using ::tensorflow::Allocator;
using ::tensorflow::AllocatorAttributes;
using ::tensorflow::AttrBuilder;
using ::tensorflow::DataType;
using ::tensorflow::DEVICE_CPU;
using ::tensorflow::DeviceAttributes;
using ::tensorflow::DynamicDeviceMgr;
using ::tensorflow::EagerContext;
using ::tensorflow::ImmediateExecutionOperation;
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
using ::tensorflow::SessionOptions;
using ::tensorflow::Status;

constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:CPU:0";
constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:FakeGPU:0";

////////////////////////////////////////////////////////////////////////////////
//
// Op and kernel definitions to set up the test environment.
//
// The Placer uses information about the op (input types) and the kernel
// (device constraints). To avoid depending on the full runtime, we define
// dummy implementations of these and register them with the runtime.
//
////////////////////////////////////////////////////////////////////////////////

// A dummy OpKernel that is used to register ops on different devices.
class DummyOp : public OpKernel {
 public:
  explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

// Register the following ops so they can be added to a Graph, and
// kernels so that they can be placed on particular device types.
REGISTER_OP("InvalidOp").Output("o: Ref(float)");

REGISTER_OP("TestOp").Output("o: Ref(float)");
REGISTER_KERNEL_BUILDER(Name("TestOp").Device(DEVICE_CPU).Priority(1), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestOp").Device("FakeGPU").Priority(2), DummyOp);

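// Returns a fake device of the given type and name. The device does no real
// work; it only carries the DeviceAttributes that placement needs.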
static tensorflow::Device* CreateDevice(const char* type, const char* name) {
  class FakeDevice : public tensorflow::Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
    Status Sync() override { return ::tensorflow::OkStatus(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };
  DeviceAttributes attr;
  attr.set_name(name);
  attr.set_device_type(type);
  return new FakeDevice(attr);
}

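// A minimal tensor handle that reports only a device name, a dtype, and a
// single-element shape; all other methods are intentionally unimplemented.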
class FakeTensorHandle : public tensorflow::ImmediateExecutionTensorHandle {
 public:
  explicit FakeTensorHandle(string_view device_name, tensorflow::DataType dtype)
      : ImmediateExecutionTensorHandle(kTfrt),
        device_name_(device_name),
        dtype_(dtype) {}

  void Release() override { Unref(); }

  tensorflow::DataType DataType() const override { return dtype_; }
  Status Shape(tensorflow::PartialTensorShape* shape) const override {
    int64_t dim_sizes[] = {1};
    return tensorflow::PartialTensorShape::MakePartialShape(dim_sizes, 1,
                                                            shape);
  }
  Status NumDims(int* num_dims) const override {
    *num_dims = 1;
    return ::tensorflow::OkStatus();
  }
  Status NumElements(int64_t* num_elements) const override {
    *num_elements = 1;
    return ::tensorflow::OkStatus();
  }
  Status Dim(int dim_index, int64_t* dim) const override {
    llvm_unreachable("unimplemented method.");
  }

  const char* DeviceName(Status* status) const override {
    return device_name_.c_str();
  }
  const char* BackingDeviceName(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  const char* DeviceType(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  int DeviceId(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  tensorflow::AbstractTensorInterface* Resolve(Status* status) override {
    llvm_unreachable("unimplemented method.");
  }
  ImmediateExecutionTensorHandle* Copy() override {
    Ref();
    return this;
  }

  static bool classof(const AbstractTensorHandle* ptr) { return true; }

 private:
  std::string device_name_;
  tensorflow::DataType dtype_;
};

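// A minimal ImmediateExecutionOperation that records the op name, device name,
// attributes, and inputs needed by the op handler selector; all other methods
// are intentionally unimplemented.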
class FakeOperation : public ImmediateExecutionOperation {
 public:
  explicit FakeOperation() : ImmediateExecutionOperation(kTfrt) {}
  ~FakeOperation() override {}

  void Release() override { delete this; }

  void Clear() override { args_.clear(); }

  tensorflow::ImmediateExecutionContext* GetContext() const override {
    return nullptr;
  }

  bool HasCustomDeviceInput() const override { return false; }

  Status Reset(const char* op, const char* raw_device_name) override {
    op_name_ = op;
    device_name_ = raw_device_name;
    attrs_.Reset(op);
    args_.clear();
    return ::tensorflow::OkStatus();
  }
  const std::string& Name() const override { return op_name_; }
  const std::string& DeviceName() const override { return device_name_; }
  tensorflow::Status SetDeviceName(const char* name) override {
    device_name_ = name;
    return ::tensorflow::OkStatus();
  }

  Status AddInput(AbstractTensorHandle* input) override {
    input->Ref();
    args_.push_back(tensorflow::core::RefCountPtr<FakeTensorHandle>(
        static_cast<FakeTensorHandle*>(input)));
    attrs_.NumInputs(args_.size());
    return ::tensorflow::OkStatus();
  }
  Status SetInput(size_t index,
                  tensorflow::ImmediateExecutionTensorHandle* input) override {
    llvm_unreachable("unimplemented method.");
  }
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
    llvm_unreachable("unimplemented method.");
  }
  absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
      const override {
    return absl::MakeSpan(
        reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
            args_.data()),
        args_.size());
  }
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override {
    llvm_unreachable("unimplemented method.");
  }
  const tensorflow::OpDef* OpDef() const override {
    llvm_unreachable("unimplemented method.");
  }
  const tensorflow::AbstractOpAttrs* GetOpAttrs() const override {
    llvm_unreachable("unimplemented method.");
  }
  void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrInt(const char* attr_name, int64_t value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFloat(const char* attr_name, float value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrBool(const char* attr_name, bool value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrType(const char* attr_name,
                     tensorflow::DataType value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunctionName(const char* attr_name, const char* data,
                             size_t length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrTensor(const char* attr_name,
                       tensorflow::AbstractTensorInterface* tensor) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrTypeList(const char* attr_name,
                         const tensorflow::DataType* values,
                         int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override {
    llvm_unreachable("unimplemented method.");
  }

  Status InputLength(const char* input_name, int* length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status OutputLength(const char* output_name, int* length) override {
    llvm_unreachable("unimplemented method.");
  }

  void SetCancellationManager(
      tensorflow::CancellationManager* cancellation_manager) override {
    llvm_unreachable("unimplemented method.");
  }

  void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
    llvm_unreachable("unimplemented method.");
  }

  absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
    llvm_unreachable("unimplemented method.");
  }

  void SetStepId(int64_t step_id) override {
    llvm_unreachable("unimplemented method.");
  }

  static bool classof(const AbstractOperation* ptr) { return true; }

  AttrBuilder* GetAttrs() { return &attrs_; }

 private:
  std::string op_name_;
  std::string device_name_;
  llvm::SmallVector<tensorflow::core::RefCountPtr<FakeTensorHandle>, 8> args_;
  AttrBuilder attrs_;
};

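// Creates a CoreRuntime with a malloc-based allocator and a small
// multi-threaded work queue, using kFullCPU as the host device name.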
static std::unique_ptr<CoreRuntime> CreateCoreRuntime() {
  auto diag_handler = [](const DecodedDiagnostic& diag) {
    LOG(ERROR) << "Encountered runtime error: " << diag.message << "\n";
  };
  auto corert =
      CoreRuntime::Create(diag_handler, tfrt::CreateMallocAllocator(),
                          tfrt::CreateMultiThreadedWorkQueue(
                              /*num_threads=*/4, /*num_blocking_threads=*/64),
                          kFullCPU);

  assert(corert);
  return std::move(*corert);
}

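// Test fixture that wires up a DynamicDeviceMgr with a real CPU device and a
// FakeGPU device, an EagerContext with soft placement enabled, and an
// EagerOpHandlerSelector with pin_small_ops_to_cpu set to true.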
class SelectorTest : public ::testing::Test {
 public:
  SelectorTest() {
    device_manager_ = new DynamicDeviceMgr();
    std::vector<std::unique_ptr<tensorflow::Device>> added_devices;
    SessionOptions opts;

    // Have to use a real CPU device; otherwise ctx->HostCPU() will return an
    // invalid device.
    added_devices.emplace_back(CreateDevice(tensorflow::DEVICE_CPU, kFullCPU));
    added_devices.emplace_back(CreateDevice("FakeGPU", kFullGPU));

    TF_CHECK_OK(device_manager_->AddDevices(std::move(added_devices)));

    SessionOptions options;
    options.config.set_log_device_placement(true);
    options.config.set_allow_soft_placement(true);
    eager_context_ = new EagerContext(
        options,
        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
        /* async */ false, device_manager_,
        /* device_mgr_owned */ false, /* rendezvous */ nullptr,
        /* cluster_flr */ nullptr);
    corert_ = CreateCoreRuntime();
    fallback_op_handler_ = CreateOpHandler();
    cpu_op_handler_ = CreateOpHandler();
    gpu_op_handler_ = CreateOpHandler();
    corert_->RegisterOpHandler(kFullCPU, cpu_op_handler_);
    corert_->RegisterOpHandler(kFullGPU, gpu_op_handler_);

    selector_ = std::make_unique<EagerOpHandlerSelector>(
        corert_.get(), eager_context_, fallback_op_handler_,
        /*pin_small_ops_to_cpu=*/true);
  }

  ~SelectorTest() override {
    delete device_manager_;
    if (eager_context_) {
      eager_context_->Unref();
    }
  }

  EagerOpHandlerSelector* selector() { return selector_.get(); }

  void Init() {}

 protected:
  OpHandler* CreateOpHandler() {
    auto expected_op_handler = tfrt::CreateNullOpHandler(corert_.get());
    assert(expected_op_handler);
    return std::move(expected_op_handler.get());
  }

  DynamicDeviceMgr* device_manager_;
  EagerContext* eager_context_;
  std::unique_ptr<CoreRuntime> corert_;
  OpHandler* fallback_op_handler_;
  OpHandler* cpu_op_handler_;
  OpHandler* gpu_op_handler_;
  std::unique_ptr<EagerOpHandlerSelector> selector_;
};

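// A small int32 input on CPU should pin the op to the CPU op handler, while a
// GPU-resident input leaves argument-based selection undecided, so the test
// falls through to NodeDef-based selection, which picks the GPU op handler.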
TEST_F(SelectorTest, PinSmallOpToCpuTest) {
  auto op = std::make_unique<FakeOperation>();
  tensorflow::core::RefCountPtr<FakeTensorHandle> cpu_tensor(
      new FakeTensorHandle(kFullCPU, tensorflow::DT_INT32));
  tensorflow::core::RefCountPtr<FakeTensorHandle> gpu_tensor(
      new FakeTensorHandle(kFullGPU, tensorflow::DT_INT32));

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(cpu_tensor.get()));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, cpu_op_handler_);

  op_handler = nullptr;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(gpu_tensor.get()));
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_FALSE(static_cast<bool>(op_handler));
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

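// A resource-typed input should pin the op to the op handler of the device
// that holds the resource, regardless of the op's requested device.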
TEST_F(SelectorTest, PinResourceTest) {
  auto op = std::make_unique<FakeOperation>();
  tensorflow::core::RefCountPtr<FakeTensorHandle> cpu_tensor(
      new FakeTensorHandle(kFullCPU, tensorflow::DT_RESOURCE));
  tensorflow::core::RefCountPtr<FakeTensorHandle> gpu_tensor(
      new FakeTensorHandle(kFullGPU, tensorflow::DT_RESOURCE));

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(cpu_tensor.get()));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, cpu_op_handler_);

  op_handler = nullptr;
  TF_ASSERT_OK(op->Reset("TestOp", kFullCPU));
  TF_ASSERT_OK(op->AddInput(gpu_tensor.get()));
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

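// An unparsable device name should surface an INVALID_ARGUMENT error from
// NodeDef-based selection without selecting any op handler.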
TEST_F(SelectorTest, InvalidDeviceNameTest) {
  auto op = std::make_unique<FakeOperation>();

  TF_ASSERT_OK(op->Reset("TestOp", "invalid_device_name"));

  tensorflow::Status s;
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s.code(), tensorflow::error::INVALID_ARGUMENT);
  ASSERT_FALSE(static_cast<bool>(op_handler));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Failed to parse device name"));
}

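// With soft placement enabled, a request for a nonexistent device ordinal
// (/device:FakeGPU:99) should still resolve to the registered FakeGPU handler.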
TEST_F(SelectorTest, SoftPlacementTest) {
  auto op = std::make_unique<FakeOperation>();

  TF_ASSERT_OK(op->Reset("TestOp", "/device:FakeGPU:99"));
  tensorflow::Status s;
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler)) << StrCat(s.error_message());
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

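// With no device requested, the selector should prefer the device whose kernel
// is registered with the higher priority (FakeGPU, priority 2).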
TEST_F(SelectorTest, HigherPriorityDeviceTest) {
  auto op = std::make_unique<FakeOperation>();

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", ""));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

}  // namespace
}  // namespace tf
}  // namespace tfrt