xref: /aosp_15_r20/external/tensorflow/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
17 
18 #include <array>
19 
20 #include "tensorflow/c/c_api.h"
21 #include "tensorflow/c/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api.h"
23 #include "tensorflow/c/eager/c_api_experimental.h"
24 #include "tensorflow/core/platform/test.h"
25 
26 // NOTE(allenl): These tests currently go through TFE_Execute and so are
27 // integration testing rather than purely testing the parallel device. They
28 // correspond fairly well to the implementation, but testing the C++ directly is
29 // another option.
30 
31 namespace tensorflow {
32 namespace parallel_device {
33 
Create(TFE_Context * context,TF_DataType type,const int64_t * dims,const int num_dims,const char * device,TF_Status * status)34 Variable* Variable::Create(TFE_Context* context, TF_DataType type,
35                            const int64_t* dims, const int num_dims,
36                            const char* device, TF_Status* status) {
37   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
38       TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
39   if (TF_GetCode(status) != TF_OK) return nullptr;
40   TFE_OpSetAttrType(op.get(), "dtype", type);
41   TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
42   TFE_OpSetAttrString(op.get(), "container", "", 0);
43   // Use the special GUID for no buffer sharing
44   //
45   // TODO(allenl): Should we provide a better API for this? AFAIK this is the
46   // only reasonable way to make variables with no aliasing using the eager C
47   // API.
48   std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
49   TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
50                       no_sharing.length());
51   TFE_OpSetDevice(op.get(), device, status);
52   if (TF_GetCode(status) != TF_OK) return nullptr;
53   TFE_TensorHandle* var_handle = nullptr;
54   int num_retvals = 1;
55   TFE_Execute(op.get(), &var_handle, &num_retvals, status);
56   if (TF_GetCode(status) != TF_OK) return nullptr;
57   return new Variable(var_handle, type);
58 }
59 
// Frees the variable's backing buffer by running DestroyResourceOp on the
// device the handle lives on, then deletes the handle itself. Errors are
// reported through `status`; callers are expected to assert on it.
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
  // Free the backing buffer for the variable.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Place the destroy op on the same device as the resource handle.
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Delete the variable handle itself.
  // NOTE(review): on any early return above, `handle_` is intentionally left
  // alive (and leaks); acceptable for a test library whose callers ASSERT on
  // `status` immediately afterwards.
  TFE_DeleteTensorHandle(handle_);
}
77 
Read(TFE_Context * context,TF_Status * status)78 TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
79   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
80       TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
81   if (TF_GetCode(status) != TF_OK) return nullptr;
82   TFE_OpAddInput(op.get(), handle_, status);
83   if (TF_GetCode(status) != TF_OK) return nullptr;
84   const char* device = TFE_TensorHandleDeviceName(handle_, status);
85   if (TF_GetCode(status) != TF_OK) return nullptr;
86   TFE_OpSetDevice(op.get(), device, status);
87   if (TF_GetCode(status) != TF_OK) return nullptr;
88   TFE_OpSetAttrType(op.get(), "dtype", type_);
89   int num_retvals = 1;
90   TFE_TensorHandle* var_value = nullptr;
91   TFE_Execute(op.get(), &var_value, &num_retvals, status);
92   if (TF_GetCode(status) != TF_OK) return nullptr;
93   return TensorHandlePtr(var_value);
94 }
95 
GeneralAssignment(const char * op_name,TFE_Context * context,TFE_TensorHandle * value,TF_Status * status)96 void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
97                                  TFE_TensorHandle* value, TF_Status* status) {
98   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
99       TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
100   if (TF_GetCode(status) != TF_OK) return;
101   TFE_OpSetAttrType(op.get(), "dtype", type_);
102   TFE_OpAddInput(op.get(), handle_, status);
103   if (TF_GetCode(status) != TF_OK) return;
104   TFE_OpAddInput(op.get(), value, status);
105   if (TF_GetCode(status) != TF_OK) return;
106   const char* device = TFE_TensorHandleDeviceName(handle_, status);
107   if (TF_GetCode(status) != TF_OK) return;
108   TFE_OpSetDevice(op.get(), device, status);
109 
110   int num_retvals = 0;
111   TFE_Execute(op.get(), nullptr, &num_retvals, status);
112   if (TF_GetCode(status) != TF_OK) return;
113 }
114 
// Adds `value` in place to the variable's current value (AssignAddVariableOp).
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}
119 
// Overwrites the variable's value with `value` (AssignVariableOp).
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}
124 
// Deallocator passed to `TF_NewTensor`: releases a `new float[]` buffer once
// the tensor no longer needs it. The length and opaque argument are unused.
static void FloatDeallocator(void* data, size_t, void* arg) {
  float* buffer = static_cast<float*>(data);
  delete[] buffer;
}
130 
131 // Creates a TFE_TensorHandle with value `v`.
FloatTensorHandle(float v,TF_Status * status)132 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
133   const int num_bytes = sizeof(float);
134   float* values = new float[1];
135   values[0] = v;
136   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
137       TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
138                    nullptr),
139       TF_DeleteTensor);
140   return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
141 }
142 
143 // Creates a rank-one TFE_TensorHandle with value `v`.
VectorFloatTensorHandle(const std::vector<float> & v,TF_Status * status)144 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
145                                         TF_Status* status) {
146   const int num_bytes = v.size() * sizeof(float);
147   float* values = new float[v.size()];
148   memcpy(values, v.data(), num_bytes);
149   int64_t dims = v.size();
150   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
151       TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
152                    &FloatDeallocator, nullptr),
153       TF_DeleteTensor);
154   return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
155 }
156 
157 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
158 template <std::size_t num_replicas>
ExtractPerDeviceValues(TFE_Context * context,TFE_TensorHandle * input,std::array<TensorHandlePtr,num_replicas> * components,TF_Status * status)159 void ExtractPerDeviceValues(
160     TFE_Context* context, TFE_TensorHandle* input,
161     std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
162   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
163       TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
164   if (TF_GetCode(status) != TF_OK) return;
165   TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
166   TFE_OpAddInput(op.get(), input, status);
167   if (TF_GetCode(status) != TF_OK) return;
168   const char* device = TFE_TensorHandleDeviceName(input, status);
169   if (TF_GetCode(status) != TF_OK) return;
170   TFE_OpSetDevice(op.get(), device, status);
171   if (TF_GetCode(status) != TF_OK) return;
172 
173   TFE_TensorHandle* result_handles[num_replicas];
174   int num_retvals = num_replicas;
175   TFE_Execute(op.get(), result_handles, &num_retvals, status);
176   if (TF_GetCode(status) != TF_OK) return;
177   for (int i = 0; i < num_replicas; ++i) {
178     (*components)[i].reset(result_handles[i]);
179   }
180 }
181 
Multiply(TFE_Context * context,TFE_TensorHandle * first,TFE_TensorHandle * second,TF_Status * status)182 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
183                          TFE_TensorHandle* second, TF_Status* status) {
184   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
185       TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
186   if (TF_GetCode(status) != TF_OK) return nullptr;
187   TFE_OpAddInput(op.get(), first, status);
188   if (TF_GetCode(status) != TF_OK) return nullptr;
189   TFE_OpAddInput(op.get(), second, status);
190   if (TF_GetCode(status) != TF_OK) return nullptr;
191   const char* first_device = TFE_TensorHandleDeviceName(first, status);
192   if (TF_GetCode(status) != TF_OK) return nullptr;
193   TFE_OpSetDevice(op.get(), first_device, status);
194 
195   TFE_TensorHandle* result_handle;
196   int num_retvals = 1;
197   TFE_Execute(op.get(), &result_handle, &num_retvals, status);
198   if (TF_GetCode(status) != TF_OK) return nullptr;
199   return TensorHandlePtr(result_handle);
200 }
201 
202 // Create and modify a variable placed on a parallel device which composes
203 // `first_device` and `second_device`.
BasicTestsForTwoDevices(TFE_Context * context,const char * first_device,const char * second_device)204 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
205                              const char* second_device) {
206   // Register the custom device
207   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
208       TF_NewStatus(), TF_DeleteStatus);
209   const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
210   std::array<const char*, 2> underlying_devices{first_device, second_device};
211   RegisterParallelDevice(context, device_name, underlying_devices,
212                          status.get());
213   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
214 
215   // Create a variable handle (uninitialized to start) placed on the parallel
216   // device.
217   std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
218     to_delete->Destroy(context, status.get());
219     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
220     delete to_delete;
221   };
222   std::unique_ptr<Variable, decltype(variable_deleter)> variable(
223       Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
224                        status.get()),
225       variable_deleter);
226   ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
227 
228   // Assign an initial value to the variable, mirroring it to each component
229   // device.
230   {
231     TensorHandlePtr initial_value_cpu = FloatTensorHandle(20., status.get());
232     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
233     std::array<TFE_TensorHandle*, 2> components{initial_value_cpu.get(),
234                                                 initial_value_cpu.get()};
235     TensorHandlePtr initial_value =
236         CreatePerDeviceValues(context, components, device_name, status.get());
237     variable->Assign(context, initial_value.get(), status.get());
238     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
239   }
240 
241   // Read from the variable and verify that we have a parallel tensor.
242   {
243     TensorHandlePtr read = variable->Read(context, status.get());
244     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
245     std::array<TensorHandlePtr, 2> components;
246     ExtractPerDeviceValues(context, read.get(), &components, status.get());
247     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
248 
249     ExpectScalarEq<float>(components[0].get(), 20.);
250     ExpectScalarEq<float>(components[1].get(), 20.);
251 
252     std::string first_device =
253         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
254     ASSERT_EQ(underlying_devices[0], first_device);
255     std::string second_device =
256         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
257     ASSERT_EQ(underlying_devices[1], second_device);
258   }
259 
260   // Add a parallel tensor with different values on each device to the variable.
261   {
262     TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
263     TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
264     std::array<TFE_TensorHandle*, 2> components{value_one.get(),
265                                                 value_two.get()};
266     TensorHandlePtr combined_value =
267         CreatePerDeviceValues(context, components, device_name, status.get());
268     variable->AssignAdd(context, combined_value.get(), status.get());
269   }
270 
271   // Read the variable and verify that each component has the right modified
272   // value.
273   {
274     TensorHandlePtr read = variable->Read(context, status.get());
275     std::array<TensorHandlePtr, 2> components;
276     ExtractPerDeviceValues(context, read.get(), &components, status.get());
277     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
278 
279     ExpectScalarEq<float>(components[0].get(), 23.);
280     ExpectScalarEq<float>(components[1].get(), 18.);
281 
282     std::string first_device =
283         TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
284     ASSERT_EQ(underlying_devices[0], first_device);
285     std::string second_device =
286         TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
287     ASSERT_EQ(underlying_devices[1], second_device);
288   }
289 }
290 
291 }  // namespace parallel_device
292 }  // namespace tensorflow
293