/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tfrt/eager/core_runtime/op_handler_selector.h"

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tfrt/cpu/core_runtime/null_op_handler.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/host_context/host_allocator.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime

namespace tfrt {
namespace tf {
namespace {

using ::tensorflow::AbstractTensorHandle;
using ::tensorflow::Allocator;
using ::tensorflow::AllocatorAttributes;
using ::tensorflow::AttrBuilder;
using ::tensorflow::DataType;
using ::tensorflow::DEVICE_CPU;
using ::tensorflow::DeviceAttributes;
using ::tensorflow::DynamicDeviceMgr;
using ::tensorflow::EagerContext;
using ::tensorflow::ImmediateExecutionOperation;
using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;
using ::tensorflow::SessionOptions;
using ::tensorflow::Status;

constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:CPU:0";
constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:FakeGPU:0";

////////////////////////////////////////////////////////////////////////////////
//
// Op and kernel definitions used to set up the test environment.
//
// The Placer uses information about the op (input types) and the kernel
// (device constraints). To avoid depending on the full runtime, we define
// dummy implementations of these and register them with the runtime.
//
////////////////////////////////////////////////////////////////////////////////

// A dummy OpKernel that is used to register ops on different devices.
class DummyOp : public OpKernel {
 public:
  explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

// Register the following ops so they can be added to a Graph, and
// kernels so that they can be placed on particular device types.
REGISTER_OP("InvalidOp").Output("o: Ref(float)");

REGISTER_OP("TestOp").Output("o: Ref(float)");
REGISTER_KERNEL_BUILDER(Name("TestOp").Device(DEVICE_CPU).Priority(1), DummyOp);
REGISTER_KERNEL_BUILDER(Name("TestOp").Device("FakeGPU").Priority(2), DummyOp);

static tensorflow::Device* CreateDevice(const char* type, const char* name) {
  class FakeDevice : public tensorflow::Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
    Status Sync() override { return ::tensorflow::OkStatus(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
  };
  DeviceAttributes attr;
  attr.set_name(name);
  attr.set_device_type(type);
  return new FakeDevice(attr);
}

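// A minimal tensor handle that only reports the device name and dtype the
// selector inspects; the remaining methods are intentionally unimplemented.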
class FakeTensorHandle : public tensorflow::ImmediateExecutionTensorHandle {
 public:
  explicit FakeTensorHandle(string_view device_name, tensorflow::DataType dtype)
      : ImmediateExecutionTensorHandle(kTfrt),
        device_name_(device_name),
        dtype_(dtype) {}

  void Release() override { Unref(); }

  tensorflow::DataType DataType() const override { return dtype_; }
  Status Shape(tensorflow::PartialTensorShape* shape) const override {
    int64_t dim_sizes[] = {1};
    return tensorflow::PartialTensorShape::MakePartialShape(dim_sizes, 1,
                                                            shape);
  }
  Status NumDims(int* num_dims) const override {
    *num_dims = 1;
    return ::tensorflow::OkStatus();
  }
  Status NumElements(int64_t* num_elements) const override {
    *num_elements = 1;
    return ::tensorflow::OkStatus();
  }
  Status Dim(int dim_index, int64_t* dim) const override {
    llvm_unreachable("unimplemented method.");
  }

  const char* DeviceName(Status* status) const override {
    return device_name_.c_str();
  }
  const char* BackingDeviceName(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  const char* DeviceType(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  int DeviceId(Status* status) const override {
    llvm_unreachable("unimplemented method.");
  }
  tensorflow::AbstractTensorInterface* Resolve(Status* status) override {
    llvm_unreachable("unimplemented method.");
  }
  ImmediateExecutionTensorHandle* Copy() override {
    Ref();
    return this;
  }

  static bool classof(const AbstractTensorHandle* ptr) { return true; }

 private:
  std::string device_name_;
  tensorflow::DataType dtype_;
};

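// A minimal operation that records the op name, requested device name, inputs,
// and attributes; methods the selector does not exercise are unimplemented.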
class FakeOperation : public ImmediateExecutionOperation {
 public:
  explicit FakeOperation() : ImmediateExecutionOperation(kTfrt) {}
  ~FakeOperation() override {}

  void Release() override { delete this; }

  void Clear() override { args_.clear(); }

  tensorflow::ImmediateExecutionContext* GetContext() const override {
    return nullptr;
  }

  bool HasCustomDeviceInput() const override { return false; }

  Status Reset(const char* op, const char* raw_device_name) override {
    op_name_ = op;
    device_name_ = raw_device_name;
    attrs_.Reset(op);
    args_.clear();
    return ::tensorflow::OkStatus();
  }
  const std::string& Name() const override { return op_name_; }
  const std::string& DeviceName() const override { return device_name_; }
  tensorflow::Status SetDeviceName(const char* name) override {
    device_name_ = name;
    return ::tensorflow::OkStatus();
  }

  Status AddInput(AbstractTensorHandle* input) override {
    input->Ref();
    args_.push_back(tensorflow::core::RefCountPtr<FakeTensorHandle>(
        static_cast<FakeTensorHandle*>(input)));
    attrs_.NumInputs(args_.size());
    return ::tensorflow::OkStatus();
  }
  Status SetInput(size_t index,
                  tensorflow::ImmediateExecutionTensorHandle* input) override {
    llvm_unreachable("unimplemented method.");
  }
  Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
    llvm_unreachable("unimplemented method.");
  }
  absl::Span<tensorflow::ImmediateExecutionTensorHandle* const> GetInputs()
      const override {
    return absl::MakeSpan(
        reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle* const*>(
            args_.data()),
        args_.size());
  }
  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                 int* num_retvals) override {
    llvm_unreachable("unimplemented method.");
  }
  const tensorflow::OpDef* OpDef() const override {
    llvm_unreachable("unimplemented method.");
  }
  const tensorflow::AbstractOpAttrs* GetOpAttrs() const override {
    llvm_unreachable("unimplemented method.");
  }
  void AddAttrs(const tensorflow::AbstractOpAttrs* op_attrs) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrString(const char* attr_name, const char* data,
                       size_t length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrInt(const char* attr_name, int64_t value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFloat(const char* attr_name, float value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrBool(const char* attr_name, bool value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrType(const char* attr_name,
                     tensorflow::DataType value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrShape(const char* attr_name, const int64_t* dims,
                      const int num_dims) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunction(const char* attr_name,
                         const AbstractOperation* value) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunctionName(const char* attr_name, const char* data,
                             size_t length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrTensor(const char* attr_name,
                       tensorflow::AbstractTensorInterface* tensor) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrStringList(const char* attr_name, const void* const* values,
                           const size_t* lengths, int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFloatList(const char* attr_name, const float* values,
                          int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrIntList(const char* attr_name, const int64_t* values,
                        int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrTypeList(const char* attr_name,
                         const tensorflow::DataType* values,
                         int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
                         int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                          const int* num_dims, int num_values) override {
    llvm_unreachable("unimplemented method.");
  }
  Status SetAttrFunctionList(
      const char* attr_name,
      absl::Span<const AbstractOperation*> values) override {
    llvm_unreachable("unimplemented method.");
  }

  Status InputLength(const char* input_name, int* length) override {
    llvm_unreachable("unimplemented method.");
  }
  Status OutputLength(const char* output_name, int* length) override {
    llvm_unreachable("unimplemented method.");
  }

  void SetCancellationManager(
      tensorflow::CancellationManager* cancellation_manager) override {
    llvm_unreachable("unimplemented method.");
  }

  void SetStackTrace(tensorflow::ManagedStackTrace stack_trace) override {
    llvm_unreachable("unimplemented method.");
  }

  absl::optional<tensorflow::ManagedStackTrace> GetStackTrace() override {
    llvm_unreachable("unimplemented method.");
  }

  void SetStepId(int64_t step_id) override {
    llvm_unreachable("unimplemented method.");
  }

  static bool classof(const AbstractOperation* ptr) { return true; }

  AttrBuilder* GetAttrs() { return &attrs_; }

 private:
  std::string op_name_;
  std::string device_name_;
  llvm::SmallVector<tensorflow::core::RefCountPtr<FakeTensorHandle>, 8> args_;
  AttrBuilder attrs_;
};

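// Creates a CoreRuntime with a malloc allocator and a small multi-threaded
// work queue, using kFullCPU as the CPU device name.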
static std::unique_ptr<CoreRuntime> CreateCoreRuntime() {
  auto diag_handler = [](const DecodedDiagnostic& diag) {
    LOG(ERROR) << "Encountered runtime error: " << diag.message << "\n";
  };
  auto corert =
      CoreRuntime::Create(diag_handler, tfrt::CreateMallocAllocator(),
                          tfrt::CreateMultiThreadedWorkQueue(
                              /*num_threads=*/4, /*num_blocking_threads=*/64),
                          kFullCPU);

  assert(corert);
  return std::move(*corert);
}

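// Test fixture: registers a real CPU device and a FakeGPU device, creates an
// EagerContext with soft placement enabled, a CoreRuntime with op handlers
// registered for both devices, and the EagerOpHandlerSelector under test
// (with pin_small_ops_to_cpu enabled).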
class SelectorTest : public ::testing::Test {
 public:
  SelectorTest() {
    device_manager_ = new DynamicDeviceMgr();
    std::vector<std::unique_ptr<tensorflow::Device>> added_devices;
    SessionOptions opts;

    // Have to use a real CPU device. Otherwise, ctx->HostCPU() will return an
    // invalid device.
    added_devices.emplace_back(CreateDevice(tensorflow::DEVICE_CPU, kFullCPU));
    added_devices.emplace_back(CreateDevice("FakeGPU", kFullGPU));

    TF_CHECK_OK(device_manager_->AddDevices(std::move(added_devices)));

    SessionOptions options;
    options.config.set_log_device_placement(true);
    options.config.set_allow_soft_placement(true);
    eager_context_ = new EagerContext(
        options,
        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
        /* async */ false, device_manager_,
        /* device_mgr_owned */ false, /* rendezvous */ nullptr,
        /* cluster_flr */ nullptr);
    corert_ = CreateCoreRuntime();
    fallback_op_handler_ = CreateOpHandler();
    cpu_op_handler_ = CreateOpHandler();
    gpu_op_handler_ = CreateOpHandler();
    corert_->RegisterOpHandler(kFullCPU, cpu_op_handler_);
    corert_->RegisterOpHandler(kFullGPU, gpu_op_handler_);

    selector_ = std::make_unique<EagerOpHandlerSelector>(
        corert_.get(), eager_context_, fallback_op_handler_,
        /*pin_small_ops_to_cpu=*/true);
  }

  ~SelectorTest() override {
    delete device_manager_;
    if (eager_context_) {
      eager_context_->Unref();
    }
  }

  EagerOpHandlerSelector* selector() { return selector_.get(); }

  void Init() {}

 protected:
  OpHandler* CreateOpHandler() {
    auto expected_op_handler = tfrt::CreateNullOpHandler(corert_.get());
    assert(expected_op_handler);
    return std::move(expected_op_handler.get());
  }

  DynamicDeviceMgr* device_manager_;
  EagerContext* eager_context_;
  std::unique_ptr<CoreRuntime> corert_;
  OpHandler* fallback_op_handler_;
  OpHandler* cpu_op_handler_;
  OpHandler* gpu_op_handler_;
  std::unique_ptr<EagerOpHandlerSelector> selector_;
};

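// Verifies pin_small_ops_to_cpu behavior: a small int32 input resident on the
// CPU makes SelectFromArguments pick the CPU op handler even though the op
// requests the GPU, while a GPU-resident input leaves the choice to
// SelectFromNodeDef, which then picks the GPU op handler.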
TEST_F(SelectorTest, PinSmallOpToCpuTest) {
  auto op = std::make_unique<FakeOperation>();
  tensorflow::core::RefCountPtr<FakeTensorHandle> cpu_tensor(
      new FakeTensorHandle(kFullCPU, tensorflow::DT_INT32));
  tensorflow::core::RefCountPtr<FakeTensorHandle> gpu_tensor(
      new FakeTensorHandle(kFullGPU, tensorflow::DT_INT32));

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(cpu_tensor.get()));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, cpu_op_handler_);

  op_handler = nullptr;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(gpu_tensor.get()));
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_FALSE(static_cast<bool>(op_handler));
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

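// Verifies that a DT_RESOURCE input pins the op to the handler of the device
// holding the resource, regardless of the device requested by the op.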
TEST_F(SelectorTest, PinResourceTest) {
  auto op = std::make_unique<FakeOperation>();
  tensorflow::core::RefCountPtr<FakeTensorHandle> cpu_tensor(
      new FakeTensorHandle(kFullCPU, tensorflow::DT_RESOURCE));
  tensorflow::core::RefCountPtr<FakeTensorHandle> gpu_tensor(
      new FakeTensorHandle(kFullGPU, tensorflow::DT_RESOURCE));

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", kFullGPU));
  TF_ASSERT_OK(op->AddInput(cpu_tensor.get()));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, cpu_op_handler_);

  op_handler = nullptr;
  TF_ASSERT_OK(op->Reset("TestOp", kFullCPU));
  TF_ASSERT_OK(op->AddInput(gpu_tensor.get()));
  s = selector()->SelectFromArguments(*op, &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

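// Verifies that an unparsable device name makes SelectFromNodeDef return
// INVALID_ARGUMENT and leave the op handler unset.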
TEST_F(SelectorTest, InvalidDeviceNameTest) {
  auto op = std::make_unique<FakeOperation>();

  TF_ASSERT_OK(op->Reset("TestOp", "invalid_device_name"));

  tensorflow::Status s;
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s.code(), tensorflow::error::INVALID_ARGUMENT);
  ASSERT_FALSE(static_cast<bool>(op_handler));
  EXPECT_TRUE(
      absl::StrContains(s.error_message(), "Failed to parse device name"));
}

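// Verifies soft placement: requesting a FakeGPU device id that does not exist
// still resolves to the registered FakeGPU op handler.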
TEST_F(SelectorTest, SoftPlacementTest) {
  auto op = std::make_unique<FakeOperation>();

  TF_ASSERT_OK(op->Reset("TestOp", "/device:FakeGPU:99"));
  tensorflow::Status s;
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler)) << StrCat(s.error_message());
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

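// Verifies that, with no device requested, the selector prefers the device
// whose kernel is registered with the higher priority (FakeGPU, priority 2).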
TEST_F(SelectorTest, HigherPriorityDeviceTest) {
  auto op = std::make_unique<FakeOperation>();

  tensorflow::Status s;
  TF_ASSERT_OK(op->Reset("TestOp", ""));
  OpHandler* op_handler = nullptr;
  s = selector()->SelectFromNodeDef(*op, &op->GetAttrs()->BuildNodeDef(),
                                    &op_handler);
  ASSERT_EQ(s, ::tensorflow::OkStatus());
  ASSERT_TRUE(static_cast<bool>(op_handler));
  ASSERT_EQ(op_handler, gpu_op_handler_);
}

}  // namespace
}  // namespace tf
}  // namespace tfrt