xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/custom_device_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/custom_device.h"
17 
18 #include "tensorflow/core/common_runtime/device_mgr.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
21 #include "tensorflow/core/common_runtime/eager/placement_utils.h"
22 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
23 #include "tensorflow/core/framework/device_factory.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace eager {
29 namespace {
30 
31 using ::testing::ContainsRegex;
32 using ::testing::HasSubstr;
33 
34 class TestCustomDevice : public CustomDevice {
35  public:
TestCustomDevice(std::string name)36   explicit TestCustomDevice(std::string name) : name_(name) {}
name()37   const std::string& name() override { return name_; }
CopyTensorToDevice(ImmediateExecutionTensorHandle * tensor,ImmediateExecutionTensorHandle ** result)38   Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor,
39                             ImmediateExecutionTensorHandle** result) override {
40     tensor->Ref();
41     *result = tensor;
42     return OkStatus();
43   }
CopyTensorFromDevice(ImmediateExecutionTensorHandle * tensor,const std::string & target_device_name,ImmediateExecutionTensorHandle ** result)44   Status CopyTensorFromDevice(
45       ImmediateExecutionTensorHandle* tensor,
46       const std::string& target_device_name,
47       ImmediateExecutionTensorHandle** result) override {
48     tensor->Ref();
49     *result = tensor;
50     return OkStatus();
51   }
Execute(const ImmediateExecutionOperation * op,ImmediateExecutionTensorHandle ** retvals,int * num_retvals)52   Status Execute(const ImmediateExecutionOperation* op,
53                  ImmediateExecutionTensorHandle** retvals,
54                  int* num_retvals) override {
55     return errors::Unimplemented("Not implemented");
56   }
57 
Pack(absl::Span<ImmediateExecutionTensorHandle * > handles,ImmediateExecutionTensorHandle ** result)58   Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
59               ImmediateExecutionTensorHandle** result) override {
60     return errors::Unimplemented("Packing is not implemented");
61   }
62 
63  private:
64   std::string name_;
65 };
66 
67 class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
68  public:
TestCustomDeviceTensorHandle(ImmediateExecutionContext * context,TestCustomDevice * device,tensorflow::DataType dtype,int64_t length)69   TestCustomDeviceTensorHandle(ImmediateExecutionContext* context,
70                                TestCustomDevice* device,
71                                tensorflow::DataType dtype, int64_t length)
72       : CustomDeviceTensorHandle(context, device, dtype), length_(length) {}
73 
DevicePointer() const74   void* DevicePointer() const override { return nullptr; }
NumDims(int * num_dims) const75   Status NumDims(int* num_dims) const override {
76     *num_dims = 1;
77     return OkStatus();
78   }
Dim(int dim_index,int64_t * dim) const79   Status Dim(int dim_index, int64_t* dim) const override {
80     if (dim_index == 0) {
81       *dim = length_;
82       return OkStatus();
83     } else {
84       return errors::Internal("Dim out of bounds");
85     }
86   }
87 
SummarizeValue(std::string & summary) const88   Status SummarizeValue(std::string& summary) const override {
89     summary = std::string("TestValue");
90     return OkStatus();
91   }
92 
93  private:
94   const int64_t length_;
95 };
96 
// Checks the metadata exposed by a custom-device tensor handle of length 3
// that lives on a device named ".../device:CUSTOM:15": the reported device
// type and index must match that name, the element count must be 3, and the
// debug string must contain the summarized value, shape and dtype.
TEST(CustomDevice, TestTensorHandle) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> ctx(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      &device_mgr, false, nullptr, nullptr));
  const std::string device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice device(device_name);
  core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
      new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT,
                                       /*length=*/3));

  // Status checks use TF_ASSERT_OK, like the other tests in this file, so a
  // failure reports the full status.
  Status s;
  const std::string device_type = tensor->DeviceType(&s);
  TF_ASSERT_OK(s);
  EXPECT_EQ("CUSTOM", device_type);

  const int device_index = tensor->DeviceId(&s);
  TF_ASSERT_OK(s);
  EXPECT_EQ(15, device_index);

  int64_t num_elements = 0;
  TF_ASSERT_OK(tensor->NumElements(&num_elements));
  EXPECT_EQ(3, num_elements);

  EXPECT_THAT(
      tensor->DebugString(),
      ContainsRegex(
          R"re(TensorHandle\(TestValue, shape=\[3\], dtype=DT_FLOAT, device=.*\))re"));
}
125 
// A handle whose only dimension has size -1 must fail NumElements with an
// error message that mentions "representing varying shapes".
TEST(CustomDevice, TestTensorHandleUnknownDimNumElements) {
  StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
      "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
  core::RefCountPtr<EagerContext> ctx(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
      &device_mgr, false, nullptr, nullptr));
  const std::string device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:15";
  TestCustomDevice device(device_name);
  core::RefCountPtr<TestCustomDeviceTensorHandle> tensor(
      new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT,
                                       /*length=*/-1));

  // Initialized so the local never holds an indeterminate value, whatever the
  // failing call does with its output parameter.
  int64_t num_elements = 0;
  const Status s = tensor->NumElements(&num_elements);
  EXPECT_FALSE(s.ok());
  EXPECT_THAT(s.error_message(), HasSubstr("representing varying shapes"));
}
143 
TEST(CustomDevice,TestResourcePlacement)144 TEST(CustomDevice, TestResourcePlacement) {
145   StaticDeviceMgr device_mgr(DeviceFactory::NewDevice(
146       "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
147   core::RefCountPtr<EagerContext> ctx(new EagerContext(
148       SessionOptions(),
149       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, false,
150       &device_mgr, false, nullptr, nullptr));
151   std::string custom_device_name =
152       "/job:localhost/replica:0/task:0/device:CUSTOM:15";
153   TestCustomDevice custom_device(custom_device_name);
154   core::RefCountPtr<TestCustomDeviceTensorHandle> custom_float_tensor(
155       new TestCustomDeviceTensorHandle(ctx.get(), &custom_device, DT_FLOAT,
156                                        /*length=*/3));
157   core::RefCountPtr<TestCustomDeviceTensorHandle> custom_resource_tensor(
158       new TestCustomDeviceTensorHandle(ctx.get(), &custom_device, DT_RESOURCE,
159                                        /*length=*/3));
160 
161   Tensor resource_tensor(DT_RESOURCE, {});
162   Device* physical_device = device_mgr.ListDevices().at(0);
163   core::RefCountPtr<TensorHandle> physical_resource_tensor(
164       TensorHandle::CreateLocalHandle(std::move(resource_tensor),
165                                       physical_device, physical_device,
166                                       physical_device, ctx.get()));
167   Tensor float_tensor(DT_FLOAT, {});
168   core::RefCountPtr<TensorHandle> physical_float_tensor(
169       TensorHandle::CreateLocalHandle(std::move(float_tensor), physical_device,
170                                       physical_device, physical_device,
171                                       ctx.get()));
172   EagerOperation op(ctx.get());
173   TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
174   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
175   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
176   CustomDevice* placed_device = nullptr;
177   TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
178       &placed_device, op));
179   // MaybePinToCustomDevice has no opinion about ops which have physical
180   // resource-dtype inputs. They'll get placed on physical devices.
181   EXPECT_EQ(nullptr, placed_device);
182 
183   op.Clear();
184   TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
185   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
186   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
187   placed_device = nullptr;
188   TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
189       &placed_device, op));
190   // Explicit placement onto a custom device also doesn't trigger custom device
191   // placement if there's a physical device resource input.
192   EXPECT_EQ(nullptr, placed_device);
193 
194   op.Clear();
195   TF_ASSERT_OK(
196       op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
197   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
198   placed_device = nullptr;
199   TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
200       &placed_device, op));
201   // Explicit placements typically override input-based placement onto a custom
202   // device.
203   EXPECT_EQ(nullptr, placed_device);
204 
205   op.Clear();
206   TF_ASSERT_OK(op.Reset("AssignVariableOp",
207                         "/job:localhost/replica:0/task:0/device:CPU:0"));
208   TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
209   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
210   placed_device = nullptr;
211   TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
212       &placed_device, op));
213   // Even with an explicit physical device placement, custom device resource
214   // inputs place the op on the custom device.
215   ASSERT_NE(placed_device, nullptr);
216   EXPECT_EQ(&custom_device, placed_device);
217 }
218 
219 }  // namespace
220 }  // namespace eager
221 }  // namespace tensorflow
222