1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/dtensor/cc/dtensor_device.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/base/attributes.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/match.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/string_view.h"
33 #include "absl/strings/strip.h"
34 #include "tensorflow/c/c_api_experimental.h"
35 #include "tensorflow/c/eager/c_api.h"
36 #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
37 #include "tensorflow/c/eager/tfe_context_internal.h"
38 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
39 #include "tensorflow/c/tf_datatype.h"
40 #include "tensorflow/c/tf_status.h"
41 #include "tensorflow/c/tf_status_helper.h"
42 #include "tensorflow/c/tf_tensor_internal.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
45 #include "tensorflow/core/common_runtime/device_set.h"
46 #include "tensorflow/core/common_runtime/eager/context.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/shape_refiner.h"
50 #include "tensorflow/core/framework/attr_value.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/function.pb.h"
53 #include "tensorflow/core/framework/graph_to_functiondef.h"
54 #include "tensorflow/core/framework/node_def_builder.h"
55 #include "tensorflow/core/framework/node_def_util.h"
56 #include "tensorflow/core/framework/op.h"
57 #include "tensorflow/core/framework/tensor_shape.h"
58 #include "tensorflow/core/graph/algorithm.h"
59 #include "tensorflow/core/graph/graph.h"
60 #include "tensorflow/core/lib/strings/proto_serialization.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/fingerprint.h"
63 #include "tensorflow/core/platform/types.h"
64 #include "tensorflow/core/profiler/lib/traceme.h"
65 #include "tensorflow/core/util/dump_graph.h"
66 #include "tensorflow/dtensor/cc/constants.h"
67 #include "tensorflow/dtensor/cc/dstatus.h"
68 #include "tensorflow/dtensor/cc/dtensor_device_util.h"
69 #include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h"
70 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
71 #include "tensorflow/dtensor/cc/tensor_layout.h"
72 #include "tensorflow/dtensor/cc/tpu_system_interface.h"
73 #include "tensorflow/dtensor/proto/layout.pb.h"
74 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
75 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
76
77 namespace tensorflow {
78 namespace dtensor {
79
80 // TODO(b/189332820): Replace this with a Partitioner stub swapped in by the
81 // Copybara workflow.
82 StatusOr<ExecutionFunctions> ABSL_ATTRIBUTE_WEAK PipeliningPartitionerRun(
83 const absl::flat_hash_map<std::string, const MeshWithParallelDevice*>*
84 device_name_to_mesh_device,
85 FunctionLibraryDefinition* flib_def, DTensorMlirPassRunner* pass_runner,
86 const FunctionDef& fdef, const NameAttrList& eager_attributes,
87 const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set,
88 int num_outputs) {
89 // The actual definition is in the pipelining package.
90 return errors::Unimplemented("DTensor pipelining is unavailable.");
91 }
92
93 class DTensorDevice {
94 public:
95 explicit DTensorDevice(absl::string_view name)
96 : name_(name),
97 same_shape_policy_enabled_(false),
98 cancellation_manager_(std::make_unique<CancellationManager>()) {}
99
100 void AddMesh(std::unique_ptr<MeshWithParallelDevice> mesh,
101 bool is_host_mesh) {
102 // TODO(b/168730933): Consider passing a cheaper int64_t mesh identifier.
103 if (is_host_mesh) {
104 std::string& tpu_host_mesh = Mesh::tpu_host_mesh();
105 const std::string new_tpu_host_mesh = mesh->mesh_config().ToString();
106 if (!tpu_host_mesh.empty()) {
107 // TODO(b/180046115): Add per-TPU-mesh host mesh bookkeeping.
108 LOG(WARNING)
109 << "A new TPU host mesh is overwriting the old TPU host mesh. The "
110 "old TPU mesh cannot be used in sea of donuts mode anymore.";
111 }
112 tpu_host_mesh.assign(new_tpu_host_mesh);
113 }
114 // For idempotency, don't register the same mesh twice.
115 if (!mesh_to_device_map_.insert({mesh->mesh_config(), std::move(mesh)})
116 .second)
117 return;
118 if (!default_mesh_) {
119 global_default_mesh_ = mesh_to_device_map_.begin()->second.get();
120 default_mesh_ = global_default_mesh_;
121 }
122 }
123
124 // Returns the sub-meshes used for pipelining.
125 // Each key is the name of a composite device.
126 StatusOr<absl::flat_hash_map<std::string, const MeshWithParallelDevice*>>
127 PipelineSubMeshes(TFE_Context* context) {
128 absl::flat_hash_map<std::string, const MeshWithParallelDevice*>
129 device_to_mesh;
130 for (const auto& pair : mesh_to_device_map_) {
131 TF_ASSIGN_OR_RETURN(CompositeDevice * device,
132 pair.second->FindOrCreateCompositeDevice(context));
133 if (device != nullptr) {
134 device_to_mesh[pair.second->composite_device()->name()] =
135 pair.second.get();
136 }
137 }
138 return device_to_mesh;
139 }
140
141 // Runs an operation on the DTensorDevice.
142 //
143 // The placement of the original op (TFE_OpGetDevice(original_op)) is ignored.
144 // That placement indicates whether the user explicitly placed the op on the
145 // DTensor device (vs. having it placed there because an input was placed
146 // there), but DTensor does type-based dispatch and so currently handles both
147 // cases identically.
148 void Execute(const TFE_Op* original_op, int* num_outputs,
149 TFE_TensorHandle** outputs, TF_Status* status);
150
151 void SetDefaultLayout(Layout layout) { default_layout_.emplace(layout); }
152 void ClearDefaultLayout() { default_layout_.reset(); }
153 void SetDefaultMesh(Mesh mesh) {
154 default_mesh_ = mesh_to_device_map_.at(mesh).get();
155 }
156 void ClearDefaultMesh() { default_mesh_ = global_default_mesh_; }
157 void SetSameShapePolicy(bool enabled) {
158 same_shape_policy_enabled_ = enabled;
159 }
160
161 Status SetTPUCoreIDs(const std::string& mesh_name,
162 const std::vector<int>& tpu_core_ids) {
163 if (VLOG_IS_ON(1)) {
164 LOG(INFO) << "Setting TPU core IDs for "
165 << (mesh_name.empty() ? "default mesh" : mesh_name) << ": ";
166 for (auto i : tpu_core_ids) {
167 LOG(INFO) << i;
168 }
169 }
170 // Setting the default mesh under an empty name repeatedly is fine; this
171 // happens when dtensor_initialize_tpu_system is called multiple times,
172 // especially in tests. All the resulting mappings should be the same anyway.
173 if (!mesh_name.empty() && Mesh::tpu_core_ids().count(mesh_name) > 0) {
174 return errors::AlreadyExists("Mesh name already in use: ", mesh_name);
175 }
176 Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(),
177 tpu_core_ids.end());
178 return OkStatus();
179 }
180
181 void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); }
182
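// Maps each TPU core ID to its physical location {chip_x, chip_y, chip_z,
// core_index} using the registered TPU topology. Illustrative example with
// assumed values (not taken from a real topology): on a single donut,
// TPUCoreIDsToLocations(context, {0, 1}) might return
// {{0, 0, 0, 0}, {0, 0, 0, 1}}.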
183 std::vector<std::vector<int>> TPUCoreIDsToLocations(
184 TFE_Context* context, const std::vector<int>& tpu_core_ids) {
185 TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
186 if (tpu_system == nullptr) {
187 VLOG(1) << "Calling TPUCoreIDsToLocations on the default TPU system.";
188 std::vector<std::vector<int>> tpu_core_locations;
189 tpu_core_locations.reserve(tpu_core_ids.size());
190 tpu::TpuPlatformInterface* tpu_platform =
191 tpu::TpuPlatformInterface::GetRegisteredPlatform();
192 if (tpu_platform == nullptr) {
193 LOG(WARNING) << "No TPU platform is found.";
194 return {{}};
195 }
196 if (!tpu_platform->Initialized()) {
197 LOG(WARNING) << "TPU platform is not initialized.";
198 return {{}};
199 }
200 tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
201
202 for (const int& tpu_core_id : tpu_core_ids) {
203 tpu::TpuCoreLocationExternal core =
204 tpu_topology.CoreForId(TpuCoreTypeEnum::kTensorCore, tpu_core_id);
205 tpu::TpuDimensionsExternal tpu_chip_location = core.chip_coordinates();
206 tpu_core_locations.push_back({tpu_chip_location.x, tpu_chip_location.y,
207 tpu_chip_location.z, core.index()});
208 }
209 return tpu_core_locations;
210 } else {
211 VLOG(1) << "Calling TPUCoreIDsToLocations on a preferred TPU system.";
212 return tpu_system->TPUCoreIDsToLocations(context, tpu_core_ids);
213 }
214 }
215
216 std::vector<int> TPUCoreLocationsToIDs(
217 TFE_Context* context,
218 const std::vector<std::vector<int>>& tpu_core_locations) {
219 TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
220 if (tpu_system == nullptr) {
221 VLOG(1) << "Calling TPUCoreLocationsToIDs on the default TPU system.";
222 std::vector<int> tpu_core_ids;
223 tpu_core_ids.reserve(tpu_core_locations.size());
224 tpu::TpuPlatformInterface* tpu_platform =
225 tpu::TpuPlatformInterface::GetRegisteredPlatform();
226 if (tpu_platform == nullptr) {
227 LOG(WARNING) << "No TPU platform is found.";
228 return {};
229 }
230 if (!tpu_platform->Initialized()) {
231 LOG(WARNING) << "TPU platform is not initialized.";
232 return {};
233 }
234 tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
235
236 for (const std::vector<int>& tpu_core_location : tpu_core_locations) {
237 tpu::TpuCoreLocationExternal core = tpu_topology.Core(
238 TpuCoreTypeEnum::kTensorCore, tpu_core_location[0],
239 tpu_core_location[1], tpu_core_location[2], tpu_core_location[3]);
240 tpu_core_ids.push_back(core.Id());
241 }
242 return tpu_core_ids;
243 } else {
244 VLOG(1) << "Calling TPUCoreLocationsToIDs on a preferred TPU system.";
245 return tpu_system->TPUCoreLocationsToIDs(context, tpu_core_locations);
246 }
247 }
248
249 // Waits for ops to finish in ALL meshes as we share the cancellation manager.
250 void AsyncWait(TFE_Context* context, TF_Status* status) {
251 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> first_bad_status(
252 nullptr, TF_DeleteStatus);
253
254 for (const auto& pair : mesh_to_device_map_) {
255 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
256 TF_NewStatus(), TF_DeleteStatus);
257
258 pair.second->parallel_device().AsyncWait(context,
259 async_wait_status.get());
260
261 TF_Code error_code = TF_GetCode(async_wait_status.get());
262 if (error_code != TF_OK &&
263 (first_bad_status == nullptr ||
264 TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
265 first_bad_status.reset(TF_NewStatus());
266 TF_SetStatus(first_bad_status.get(), error_code,
267 TF_Message(async_wait_status.get()));
268 }
269 }
270
271 if (first_bad_status != nullptr) {
272 TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
273 TF_Message(first_bad_status.get()));
274 }
275
276 // Reset the global function rendezvous, which otherwise stores a failure
277 // state.
278 tensorflow::unwrap(context)->ResetGlobalRendezvousForFunction();
279
280 // Reset the cancellation manager on (potential) failure so we don't cancel
281 // future ops. This is only safe because we have just cleared pending async
282 // nodes, which may have had a reference to the cancellation manager.
283 cancellation_manager_ = std::make_unique<CancellationManager>();
284 }
285
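// Packs `num_inputs` per-device tensors into a single DTensor handle with the
// layout parsed from `string_layout`. Illustrative sketch (assumed values): for
// a mesh with 4 local devices and a fully replicated rank-2 layout, the caller
// passes 4 identically shaped tensors, one per local device, and receives one
// handle placed on this DTensorDevice.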
286 TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
287 TFE_TensorHandle** inputs,
288 const std::string& string_layout, TF_Status* status);
289
290 std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
291 TFE_TensorHandle* input,
292 TF_Status* status);
293
294 // Return the layout for the input tensor.
295 std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
296 TF_Status* status);
297
298 TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
299 TFE_TensorHandle** indices,
300 TFE_TensorHandle** values,
301 TFE_TensorHandle** shapes,
302 const std::string& string_layout,
303 TF_Status* status);
304
305 bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
306 TF_Status* status);
307
308 std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
309 TFE_Context* context, TF_Status* status) const;
310
311 private:
312 // If the `operation_name` of an op indicates a custom DTensor op (e.g.
313 // CopyToMesh), handles that op separately instead of running the default
314 // DTensor graph compilation.
315 void MaybeHandleDTensorCustomOps(
316 const char* operation_name, const int num_inputs,
317 const TFE_OpAttrs* attributes, TFE_Context* context,
318 TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
319 bool* is_custom_dtensor_op, TF_Status* status);
320
321 // Copies a non-DTensor eager tensor or a DTensor to the mesh specified by
322 // `attributes`.
323 // Currently, only copying to a replicated layout on the target mesh is supported.
324 void CopyToMesh(TFE_Context* context, int num_inputs,
325 TFE_TensorHandle** inputs, const TFE_OpAttrs* attributes,
326 TFE_TensorHandle** outputs, int* num_outputs,
327 TF_Status* status);
328
329 // Update output layouts for eager ops based on same shape policy.
330 void UpdateOutputLayoutsWithSameShapePolicy(
331 const std::vector<PartialTensorShape>& global_output_shapes,
332 const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
333 tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
334 TF_Status* status);
335
336 // Takes the description of an operation and makes a function out of it with
337 // the same signature, running DTensor MLIR passes. Registers that function
338 // with `context`. `translated_function_name` is set to the name of the
339 // function.
340 //
341 // The resulting function expects a device ID as its first input.
342 void LowerToSPMDFunction(TFE_Context* context,
343 const std::vector<TensorWithLayout*>& inputs,
344 const DTensorOperation& doperation,
345 const TFE_OpAttrs* attributes, const int num_outputs,
346 const ExecutionFunctions** execution_functions,
347 TF_Status* status);
348
349 // Execute a given function.
350 void ExecuteFunctionAndWait(
351 TFE_Context* context, const TranslatedFunction* function_ptr,
352 const MeshWithParallelDevice* parallel_device_mesh,
353 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
354 const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status);
355
356 // Implements `Execute` for operations which aren't special-cased in `Execute`.
357 void ExecuteRegularOperation(TFE_Context* context,
358 const std::vector<TensorWithLayout*>& inputs,
359 const DTensorOperation& doperation,
360 const TFE_OpAttrs* attributes, int* num_outputs,
361 TFE_TensorHandle** outputs, TF_Status* status);
362
363 // Wraps a TensorWithLayout into a TFE_TensorHandle.
364 TFE_TensorHandle* MakeLayoutTensorHandle(TFE_Context* context,
365 std::unique_ptr<TensorWithLayout> t,
366 TF_Status* status);
367
368 void RecordInShapeLayoutCache(const TensorWithLayout& tensor);
369
370 // Returns whether a given mesh is a remote mesh.
371 bool is_remote_mesh(const Mesh& mesh) const;
372
373 // The name of the device (the custom device)
374 std::string name_;
375 // Mesh configs with matching parallel devices.
376 //
377 // For now we just consider the first entry added to dtensor_device as the
378 // default mesh. Until we reach an agreement on this, we'll leave it as is.
379 absl::flat_hash_map<Mesh, std::unique_ptr<MeshWithParallelDevice>>
380 mesh_to_device_map_;
381 // TODO(hthu): Consider whether we want to preserve the default_mesh semantic.
382 // The current default mesh is consistent with default_layout_. If
383 // default_layout_ is not set, it equals global_default_mesh_.
384 const MeshWithParallelDevice* default_mesh_ = nullptr;
385 // The default mesh of a DTensorDevice, which cannot be modified once it is
386 // set.
387 const MeshWithParallelDevice* global_default_mesh_ = nullptr;
388 // If the user has specified a default output layout.
389 absl::optional<Layout> default_layout_;
390
391 // Determines whether tensors with a shape previously associated with only one
392 // layout use that layout if nothing else can be inferred.
393 bool same_shape_policy_enabled_;
394
395 DTensorMlirPassRunner pass_runner_;
396
397 struct CachedLayout {
398 // The first layout seen with this shape
399 Layout layout;
400 // Whether the layout is unique for this shape
401 bool is_unique;
402 };
403 absl::flat_hash_map<int64_t, CachedLayout> shape_layout_cache_;
404
405 FunctionManager function_manager_;
406
407 // Records the function compilation cache hits and misses.
408 std::unordered_map<std::string, int> function_compilation_hits_and_misses_;
409
410 // Coordinates cancelling ops across meshes on error. Must outlive any queued
411 // async op launches, so we only reset it after seeing a failure status.
412 std::unique_ptr<CancellationManager> cancellation_manager_;
413
414 // Maps each function_mesh_fingerprint (based on the set of meshes involved)
415 // to the number of times the function has been executed. The
416 // function_mesh_fingerprint and the counter together are used for generating
417 // the step id, which is used for rendezvous creation.
418 absl::flat_hash_map<uint64, uint64> func_mesh_fingerprint_to_step_counter_;
419 };
420
421 int64_t FingerprintShape(const absl::Span<const int64_t> shape) {
422 int64_t fprint = 0;
423 for (int64_t dim : shape) {
424 fprint = FingerprintCat64(fprint, dim);
425 }
426 return fprint;
427 }
428
429 parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs(
430 TFE_Context* context, TF_Status* status) const {
431 if (device_ids_tensor_ == nullptr) {
432 // Global device IDs sequentially increase.
433 //
434 // This is the assumption in the dtensor software stack. The MLIR pass relies
435 // on this assumption to generate mesh coordinates for each core efficiently.
436 //
437 // The rule for setting local ids, and the mapping from global ids to real
438 // physical core indices (e.g., on TPU), is unfortunately nontrivial. It is
439 // possible to use the identity mapping, but collective operation
440 // performance is terrible in most cases.
441 //
442 // - For an ICI-connected TPU slice, see go/dtensor-device-assignment-summary
443 // for a guide on creating efficient core assignments that reach peak
444 // performance.
445 //
446 // The global id to core assignment mapping is bridged by
447 // `Mesh::tpu_core_ids()` and consumed by `UpdateTPUCompileMetadata`.
448 //
449 // - For a DCN-connected topology, we need to map different sections of the
450 // global ids to their real physical cores separately according to the
451 // runtime requirements. For example, for a 4x32 mesh in which the outer
452 // dimension is connected via DCN and the inner dimension is connected by ICI,
453 // the device assignments for the inner dimension should typically form their
454 // own ring order (not the plain physical core index) in each sub-mesh, and
455 // the outer dimension should be assigned according to the real physical
456 // ring of DCN hosts.
457 //
458 // Note: In order to change this assumption, the MLIR pass needs adjustment.
459 // One possible approach is to take an N-D mapping vector for an N-D mesh and
460 // look up the coordinates in MLIR, consulting the tensor layout as well,
461 // rather than calculating them on the fly.
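// Illustrative example (assumed values): global_device_ids {8, 9, 10, 11}
// passes the consecutiveness check below, while {8, 10, 9, 11} would be
// rejected with an internal error.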
462
463 // LINT.IfChange
464 for (int64_t i = 0; i < mesh_config_.global_device_ids().size(); ++i) {
465 if (mesh_config_.global_device_ids()[i] - i !=
466 mesh_config_.global_device_ids()[0]) {
467 TF_SetStatus(
468 status, TF_INTERNAL,
469 absl::StrCat("Global device IDs should be consecutive: ",
470 absl::StrJoin(mesh_config_.global_device_ids(), ", "))
471 .c_str());
472 return nullptr;
473 }
474 }
475 // LINT.ThenChange(//tensorflow/dtensor/python/layout.py)
476
477 // Local device IDs are a subset of global device IDs, arranged in device
478 // ordinal order.
479 std::vector<int32_t> ids;
480 for (int64_t id : mesh_config_.local_device_ids()) {
481 ids.push_back(id);
482 }
483 VLOG(1) << "Parallel device IDs: " << absl::StrJoin(ids, ", ");
484 device_ids_tensor_ =
485 parallel_device_->ScalarsFromSequence<int32_t>(ids, context, status);
486 if (TF_GetCode(status) != TF_OK) return nullptr;
487 }
488 return device_ids_tensor_.get();
489 }
490
491 int TensorWithLayoutNumDims(void* data, TF_Status* status) {
492 return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size();
493 }
494
495 int64_t TensorWithLayoutDim(void* data, int dim_index, TF_Status* status) {
496 return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index];
497 }
498
499 void TensorWithLayoutDeallocator(void* data) {
500 delete reinterpret_cast<TensorWithLayout*>(data);
501 }
502
503 TF_Buffer* TensorWithLayoutSummarize(void* data, TF_Status* status) {
504 std::string summary =
505 reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue();
506 return TF_NewBufferFromString(summary.data(), summary.size());
507 }
508
509 TFE_TensorHandle* DTensorDevice::MakeLayoutTensorHandle(
510 TFE_Context* context, std::unique_ptr<TensorWithLayout> t,
511 TF_Status* status) {
512 TF_DataType dtype = t->dtype();
513 TFE_CustomDeviceTensorHandleMethods handle_methods;
514 handle_methods.num_dims = &TensorWithLayoutNumDims;
515 handle_methods.dim = &TensorWithLayoutDim;
516 handle_methods.deallocator = &TensorWithLayoutDeallocator;
517 handle_methods.summarize = &TensorWithLayoutSummarize;
518 return TFE_NewCustomDeviceTensorHandle(context, name_.c_str(), dtype,
519 /*data=*/t.release(), handle_methods,
520 status);
521 }
522
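// Records (shape fingerprint -> layout) so the same-shape policy can later
// reuse a previously seen layout. Illustrative example (assumed shapes and
// layouts): after recording a tensor of global shape [4, 8] with layout L1,
// recording another [4, 8] tensor with a different layout L2 marks the cached
// entry as non-unique, and the cache stops suggesting a layout for that shape.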
523 void DTensorDevice::RecordInShapeLayoutCache(const TensorWithLayout& tensor) {
524 auto existing = shape_layout_cache_.insert(
525 {FingerprintShape(tensor.global_shape()),
526 CachedLayout{tensor.layout(), /*is_unique=*/true}});
527
528 if (!existing.second) {
529 // There is an entry already; if the layout doesn't match we should record
530 // the fact that it's not unique.
531 if (tensor.layout() != existing.first->second.layout) {
532 existing.first->second.is_unique = false;
533 }
534 }
535 }
536
537 bool DTensorDevice::is_remote_mesh(const Mesh& mesh) const {
538 // An empty mesh might be assigned to VarHandleOp during DTensor MLIR lowering
539 // pass. Decide whether the empty mesh is remote based on the current default
540 // mesh.
541 return mesh.is_remote() ||
542 (mesh.IsEmpty() && default_mesh_->mesh_config().is_remote());
543 }
544
545 StatusOr<NameAttrList> FetchAttributes(const TFE_OpAttrs* attributes) {
546 // TODO(allenl): Should we just give up on the public C API to save on
547 // serialization/deserialization? We need all of the attributes and to treat
548 // them generically, which isn't going to be pleasant with typed attribute
549 // methods.
550 std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> serialized_attributes(
551 TF_NewBuffer(), TF_DeleteBuffer);
552
553 TF_Status* status = TF_NewStatus();
554 TFE_OpAttrsSerialize(attributes, serialized_attributes.get(), status);
555 if (TF_GetCode(status) == TF_OK) {
556 TF_DeleteStatus(status);
557 } else {
558 Status failure_status = StatusFromTF_Status(status);
559 TF_DeleteStatus(status);
560 return failure_status;
561 }
562
563 NameAttrList name_and_attrs;
564 if (!name_and_attrs.ParseFromArray(serialized_attributes->data,
565 serialized_attributes->length)) {
566 return tensorflow::errors::Unknown("Could not parse attributes");
567 }
568 return name_and_attrs;
569 }
570
571 StatusOr<Layout> FetchLayoutFromAttributes(const TFE_OpAttrs* attributes,
572 absl::string_view attribute_name) {
573 // Get attributes.
574 TF_ASSIGN_OR_RETURN(NameAttrList name_and_attrs, FetchAttributes(attributes));
575
576 // Get layout string from attributes.
577 absl::string_view layout_str =
578 name_and_attrs.attr().find(std::string(attribute_name))->second.s();
579
580 // This would probably be slow at the moment without caching.
581 // We should consider making this faster in the future.
582 return Layout::FromString(string(layout_str));
583 }
584
585 std::string DTensorDevice::FetchLayout(TFE_Context* context,
586 TFE_TensorHandle* input,
587 TF_Status* status) {
588 VLOG(1) << "Checking layout...";
589 const char* input_device = TFE_TensorHandleDeviceName(input, status);
590 if (input_device != name_) {
591 TF_SetStatus(status, TF_INVALID_ARGUMENT,
592 "FetchLayout expects a tensor placed on the layout device.");
593 return {};
594 }
595 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
596 TFE_TensorHandleDevicePointer(input, status));
597 if (TF_GetCode(status) != TF_OK) return {};
598 return t->layout().ToString();
599 }
600
601 std::vector<TFE_TensorHandle*> DTensorDevice::Unpack(TFE_Context* context,
602 TFE_TensorHandle* input,
603 TF_Status* status) {
604 std::vector<TFE_TensorHandle*> outputs;
605
606 const char* input_device = TFE_TensorHandleDeviceName(input, status);
607 if (TF_GetCode(status) != TF_OK) return outputs;
608 if (input_device != name_) {
609 TF_SetStatus(
610 status, TF_INVALID_ARGUMENT,
611 absl::StrCat(
612 "DTensorUnpack expects a tensor placed on the DTensor device: ",
613 name_, ", but input was placed on device: ", input_device)
614 .c_str());
615 return outputs;
616 }
617 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
618 TFE_TensorHandleDevicePointer(input, status));
619 if (TF_GetCode(status) != TF_OK) return outputs;
620
621 if (is_remote_mesh(t->mesh().mesh_config())) {
622 TF_SetStatus(status, TF_UNIMPLEMENTED,
623 "DTensorUnpack is not supported on a remote mesh.");
624 return outputs;
625 }
626 const int output_size = t->num_tensors();
627 outputs.resize(output_size);
628
629 for (int output_index = 0; output_index < output_size; ++output_index) {
630 outputs[output_index] =
631 TFE_TensorHandleCopySharingTensor(t->get_tensor(output_index), status);
632 if (TF_GetCode(status) != TF_OK) {
633 return outputs;
634 }
635 }
636 return outputs;
637 }
638
639 void DTensorDevice::MaybeHandleDTensorCustomOps(
640 const char* operation_name, const int num_inputs,
641 const TFE_OpAttrs* attributes, TFE_Context* context,
642 TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
643 bool* is_custom_dtensor_op, TF_Status* status) {
644 *is_custom_dtensor_op = true;
645 if (operation_name == std::string("_EagerConst")) {
646 // Op-by-op const has no obvious layout. DTensor skips an SPMD expansion and
647 // instead relies on copying the value onto a mesh when it is used later.
648 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
649 TFE_NewOp(context, operation_name, status), TFE_DeleteOp);
650 if (TF_GetCode(status) != TF_OK) return;
651 for (int input_index = 0; input_index < num_inputs; ++input_index) {
652 TFE_OpAddInput(op.get(), inputs[input_index], status);
653 if (TF_GetCode(status) != TF_OK) return;
654 }
655 TFE_OpAddAttrs(op.get(), attributes);
656 TFE_Execute(op.get(), outputs, num_outputs, status);
657 return;
658 }
659 if (operation_name == std::string("CopyToMesh")) {
660 CopyToMesh(context, num_inputs, inputs, attributes, outputs, num_outputs,
661 status);
662 return;
663 }
664
665 *is_custom_dtensor_op = false;
666 }
667
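// Illustrative flow (assumed, based on the implementation below): a regular
// eager tensor passed to CopyToMesh with a fully replicated target layout is
// broadcast via TensorWithLayout::Broadcast to every local device of the
// target mesh and returned as a single handle on this DTensorDevice.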
668 void DTensorDevice::CopyToMesh(TFE_Context* context, int num_inputs,
669 TFE_TensorHandle** inputs,
670 const TFE_OpAttrs* attributes,
671 TFE_TensorHandle** outputs, int* num_outputs,
672 TF_Status* status) {
673 if (num_inputs != 1) {
674 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
675 "DTensor CopyToMesh requires exactly 1 input.");
676 }
677 if (*num_outputs < 1) {
678 RETURN_STATUS(status, TF_INTERNAL,
679 "DTensor CopyToMesh must have output buffer to allocate at "
680 "least 1 output.");
681 }
682
683 // Assign layout.
684 StatusOr<Layout> target_layout_or =
685 FetchLayoutFromAttributes(attributes, kQualifiedLayoutAttr);
686 if (!target_layout_or.ok()) {
687 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
688 "DTensor CopyToMesh requires valid layout attribute for "
689 "destination DTensor.");
690 }
691
692 const Layout target_layout = *target_layout_or;
693 const Mesh& target_mesh = target_layout.mesh();
694
695 // TODO(b/193443769): Support sharded layout for eager copy to mesh.
696 if (!target_layout.IsFullyReplicated()) {
697 RETURN_STATUS(status, TF_UNIMPLEMENTED,
698 "Target layout of DTensor CopyToMesh must be replicated. "
699 "Consider changing the target layout to replicated layout or "
700 "file a bug to the DTensor team (b/193443769).");
701 }
702
703 TFE_TensorHandle* input_tensor = inputs[0];
704
705 // Check that if input tensor is DTensor, then input layout of the DTensor
706 // must be replicated.
707 const char* input_device = TFE_TensorHandleDeviceName(input_tensor, status);
708 if (TF_GetCode(status) != TF_OK) return;
709
710 if (name_ == input_device) {
711 // Handle input which is on DTensor device already.
712 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
713 TFE_TensorHandleDevicePointer(input_tensor, status));
714 if (TF_GetCode(status) != TF_OK) return;
715
716 if (!t->layout().IsFullyReplicated())
717 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
718 "Input tensor to CopyToMesh must be replicated DTensor or "
719 "normal eager Tensor.");
720
721 // If input to CopyToMesh is a DTensor, we use the first local tensor as
722 // input tensor handle to invoke copy.
723 input_tensor = t->get_tensor(0);
724 }
725
726 auto it = mesh_to_device_map_.find(target_mesh);
727 if (it == mesh_to_device_map_.end()) {
728 RETURN_STATUS(
729 status, TF_INTERNAL,
730 "DTensor CopyToMesh target mesh is not registered. Meshes should be "
731 "automatically registered. Please file a bug. (component id: 833864)");
732 }
733
734 const MeshWithParallelDevice* target_parallel_mesh = it->second.get();
735
736 // Broadcast non-dtensor value to dtensor.
737 std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
738 context, input_tensor, *target_parallel_mesh, name_, status);
739 if (TF_GetCode(status) != TF_OK) return;
740
741 RecordInShapeLayoutCache(*wrapper);
742 *num_outputs = 1;
743 *outputs = MakeLayoutTensorHandle(context, std::move(wrapper), status);
744 }
745
746 namespace {
747
748 // Verifies that all components have the same dtype and shape.
749 // The component shape will be set upon success.
750 void VerifyPackTensorShapeAndDtype(
751 std::vector<parallel_device::TensorHandlePtr>& components,
752 std::vector<int64_t>* component_shape, TF_Status* status) {
753 TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
754 auto size = TFE_TensorHandleNumDims(components[0].get(), status);
755 if (TF_GetCode(status) != TF_OK) return;
756 component_shape->clear();
757 component_shape->reserve(size);
758 for (int i = 0; i < size; ++i) {
759 component_shape->push_back(
760 TFE_TensorHandleDim(components[0].get(), i, status));
761 if (TF_GetCode(status) != TF_OK) return;
762 }
763
764 // Verify that the TensorHandle's shape and dtype match all of the component
765 // shapes and dtypes.
766 for (const auto& component : components) {
767 for (int i = 0; i < component_shape->size(); ++i) {
768 int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
769 if (TF_GetCode(status) != TF_OK) return;
770 if (tensor_dim != (*component_shape)[i]) {
771 TF_SetStatus(status, TF_UNIMPLEMENTED,
772 "Components of a PackedTensor must currently all have "
773 "the same shape");
774 return;
775 }
776 if (TFE_TensorHandleDataType(component.get()) != dtype) {
777 TF_SetStatus(status, TF_INTERNAL,
778 "Components of a PackedTensor must all have "
779 "the same dtype");
780 return;
781 }
782 }
783 }
784 }
785
786 // Verifies that all TensorHandles have rank `expected_rank` and dtype `expected_dtype`.
787 void VerifyTensorRankAndDType(TFE_TensorHandle** tensors, int num_input,
788 int expected_rank, TF_DataType* expected_dtype,
789 TF_Status* status) {
790 for (int i = 0; i < num_input; ++i) {
791 auto actual_rank = TFE_TensorHandleNumDims(tensors[i], status);
792 if (TF_GetCode(status) != TF_OK)
793 RETURN_STATUS(status, TF_INTERNAL, "Error getting rank of tensor.");
794 if (actual_rank != expected_rank)
795 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
796 "Rank of tensor did not match the expected rank.");
797 if (expected_dtype != nullptr &&
798 TFE_TensorHandleDataType(tensors[i]) != *expected_dtype)
799 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
800 "Dtype of tensor did not match the expected dtype.");
801 }
802 }
803
804 } // namespace
805
806 TFE_TensorHandle* DTensorDevice::Pack(TFE_Context* context, int num_inputs,
807 TFE_TensorHandle** inputs,
808 const std::string& string_layout,
809 TF_Status* status) {
810 if (num_inputs < 1) {
811 TF_SetStatus(status, TF_INVALID_ARGUMENT,
812 "DTensorPack requires 1 or more inputs");
813 return nullptr;
814 }
815 StatusOr<Layout> target_layout = Layout::FromString(string_layout);
816 if (!target_layout.ok()) {
817 TF_SetStatus(status, TF_INVALID_ARGUMENT,
818 "Failed to parse layout from string layout");
819 return nullptr;
820 }
821 const Mesh& target_mesh = target_layout->mesh();
822 const MeshWithParallelDevice* target_parallel_device =
823 mesh_to_device_map_[target_mesh].get();
824 if (target_parallel_device == nullptr) {
825 TF_SetStatus(status, TF_INVALID_ARGUMENT,
826 absl::StrCat("Required mesh : ", target_mesh.ToString(),
827 "is not registered with DTensor ")
828 .c_str());
829 return nullptr;
830 }
831
832 std::unique_ptr<TensorWithLayout> packed_tensor;
833 if (is_remote_mesh(target_parallel_device->mesh_config())) {
834 // Create a dummy output for DTensorPack if inputs are on a remote mesh.
835 TF_DataType dtype = TFE_TensorHandleDataType(inputs[0]);
836 auto size = TFE_TensorHandleNumDims(inputs[0], status);
837 if (TF_GetCode(status) != TF_OK) return nullptr;
838 std::vector<int64_t> component_shape;
839 component_shape.reserve(size);
840 for (int i = 0; i < size; ++i) {
841 component_shape.push_back(TFE_TensorHandleDim(inputs[0], i, status));
842 if (TF_GetCode(status) != TF_OK) return nullptr;
843 }
844 packed_tensor = TensorWithLayout::Dummy(
845 component_shape, dtype, *target_parallel_device, *target_layout);
846
847 } else {
848 auto local_devices = target_parallel_device->mesh_config().local_devices();
849
850 if (num_inputs !=
851 target_parallel_device->parallel_device().num_underlying_devices()) {
852 TF_SetStatus(status, TF_INVALID_ARGUMENT,
853 absl::StrCat("The dtensor device ", name_, " expected ",
854 local_devices.size(),
855 " inputs to DTensorPack, but got ", num_inputs)
856 .c_str());
857 return nullptr;
858 }
859
860 std::vector<parallel_device::TensorHandlePtr> components;
861 components.reserve(num_inputs);
862 for (int i = 0; i < num_inputs; ++i) {
863 TFE_TensorHandle* input = inputs[i];
864 const char* input_device = TFE_TensorHandleDeviceName(input, status);
865 if (TF_GetCode(status) != TF_OK) return nullptr;
866 if (name_ == input_device) {
867 TF_SetStatus(status, TF_INVALID_ARGUMENT,
868 "Does not support packing a Tensor that is already on "
869 "dtensor device");
870 return nullptr;
871 }
872 // If `input` is on the target device, this creates a new handle sharing
873 // the underlying data; otherwise, async copies are invoked.
874 components.emplace_back(TFE_TensorHandleCopyToDevice(
875 input, context, local_devices[i].c_str(), status));
876 if (TF_GetCode(status) != TF_OK) return nullptr;
877 }
878
879 std::vector<int64_t> component_shape;
880 VerifyPackTensorShapeAndDtype(components, &component_shape, status);
881 if (TF_GetCode(status) != TF_OK) return nullptr;
882
883 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
884 parallel_device::ParallelTensor::FromTensorHandles(
885 target_parallel_device->parallel_device(), std::move(components),
886 status);
887 if (TF_GetCode(status) != TF_OK) return nullptr;
888
889 if (target_layout->rank() != component_shape.size()) {
890 TF_SetStatus(
891 status, TF_INVALID_ARGUMENT,
892 absl::StrCat(
893 "Packed layout should have the same rank as the rank for each "
894 "component. The rank of each component is: ",
895 component_shape.size(),
896 ", while layout has rank: ", target_layout->rank(),
897 "\nLayout: ", target_layout->ToString(), "\n")
898 .c_str());
899 return nullptr;
900 }
901
902 packed_tensor =
903 TensorWithLayout::Wrap(std::move(parallel_tensor),
904 *target_parallel_device, *target_layout)
905 .ValueOrDie();
906 }
907
908 RecordInShapeLayoutCache(*packed_tensor);
909 TFE_TensorHandle* output =
910 MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
911 if (TF_GetCode(status) != TF_OK) return nullptr;
912 return output;
913 }
914
915 TFE_TensorHandle* DTensorDevice::SparsePack(
916 TFE_Context* context, int num_inputs, TFE_TensorHandle** indices,
917 TFE_TensorHandle** values, TFE_TensorHandle** shapes,
918 const std::string& string_layout, TF_Status* status) {
919 StatusOr<Layout> target_layout = Layout::FromString(string_layout);
920 if (!target_layout.ok()) {
921 TF_SetStatus(status, TF_INVALID_ARGUMENT,
922 "Failed to parse layout from string layout");
923 return nullptr;
924 }
925 const Mesh& target_mesh = target_layout->mesh();
926 const MeshWithParallelDevice* target_parallel_device =
927 mesh_to_device_map_[target_mesh].get();
928 if (target_parallel_device == nullptr) {
929 TF_SetStatus(status, TF_INVALID_ARGUMENT,
930 absl::StrCat("Required mesh : ", target_mesh.ToString(),
931 "is not registered with DTensor ")
932 .c_str());
933 return nullptr;
934 }
935
936 TF_DataType tf_int64 = TF_INT64;
937 // Verify rank and dtype of shapes.
938 VerifyTensorRankAndDType(shapes, num_inputs, /*expected_rank=*/1,
939 /*expected_dtype=*/&tf_int64, status);
940 if (TF_GetCode(status) != TF_OK) return nullptr;
941
942 // Verify rank and dtype of indices.
943 VerifyTensorRankAndDType(indices, num_inputs, /*expected_rank=*/2,
944 /*expected_dtype=*/&tf_int64, status);
945 if (TF_GetCode(status) != TF_OK) return nullptr;
946
947 // Verify rank of values.
948 VerifyTensorRankAndDType(values, num_inputs, /*expected_rank=*/1,
949 /*expected_dtype=*/nullptr, status);
950 if (TF_GetCode(status) != TF_OK) return nullptr;
951
952 // Compute the local shape from a shape tensor.
953 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> shape_tensor(
954 TFE_TensorHandleResolve(shapes[0], status), TF_DeleteTensor);
955 if (TF_GetCode(status) != TF_OK) {
956 TF_SetStatus(
957 status, TF_GetCode(status),
958 absl::StrCat("Error resolving the tensor handle of shape tensor"
959 ". Original message: ",
960 TF_Message(status))
961 .c_str());
962 return nullptr;
963 }
964 int shape_tensor_size = TFE_TensorHandleDim(shapes[0], 0, status);
965 if (TF_GetCode(status) != TF_OK || shape_tensor_size <= 0) {
966 TF_SetStatus(status, TF_GetCode(status),
967 absl::StrCat("Error computing the num dims of shape tensor",
968 TF_Message(status))
969 .c_str());
970 return nullptr;
971 }
972
973 const int64_t* data =
974 static_cast<int64_t*>(TF_TensorData(shape_tensor.get()));
975 std::vector<int64_t> local_shape(data, data + shape_tensor_size);
976 if (local_shape.size() != target_layout->rank()) {
977 TF_SetStatus(
978 status, TF_INVALID_ARGUMENT,
979 absl::StrCat(
980 "Packed layout should have the same rank as the rank for each "
981 "component. The rank of each component is: ",
982 local_shape.size(),
983 ", while layout has rank: ", target_layout->rank(),
984 "\nLayout: ", target_layout->ToString(), "\n")
985 .c_str());
986 return nullptr;
987 }
988
989 // Create the SparseTensorWithLayout.
990 std::unique_ptr<TensorWithLayout> packed_tensor;
991 if (is_remote_mesh(target_parallel_device->mesh_config())) {
992 // Create a dummy SparseTensorWithLayout.
993 packed_tensor = SparseTensorWithLayout::Dummy(
994 local_shape, *target_parallel_device, target_layout.ValueOrDie());
995 } else {
996 // Parse the indices, values, and dense_shape tensors and put them into
997 // parallel tensors, and then pack it into a single SparseTensorWithLayout.
998 auto local_devices = target_parallel_device->mesh_config().local_devices();
999
1000 std::vector<parallel_device::TensorHandlePtr> indices_components;
1001 std::vector<parallel_device::TensorHandlePtr> values_components;
1002 std::vector<parallel_device::TensorHandlePtr> dense_shapes_components;
1003
1004 // A small trick to keep the code cleaner: pack each of indices, values,
1005 // and shapes with the same loop.
1006 std::vector<std::vector<parallel_device::TensorHandlePtr>*> components{
1007 &indices_components, &values_components, &dense_shapes_components};
1008 std::vector<TFE_TensorHandle**> input_vectors{indices, values, shapes};
1009 for (int component_index = 0; component_index < 3; ++component_index) {
1010 components[component_index]->reserve(num_inputs);
1011 TFE_TensorHandle** inputs = input_vectors[component_index];
1012 for (int i = 0; i < num_inputs; ++i) {
1013 const char* input_device =
1014 TFE_TensorHandleDeviceName(inputs[i], status);
1015 if (TF_GetCode(status) != TF_OK) return nullptr;
1016 if (name_ == input_device) {
1017 TF_SetStatus(status, TF_INVALID_ARGUMENT,
1018 "Does not support packing a Tensor that is already on "
1019 "dtensor device.");
1020 return nullptr;
1021 }
1022
1023 components[component_index]->emplace_back(TFE_TensorHandleCopyToDevice(
1024 inputs[i], context, local_devices[i].c_str(), status));
1025 if (TF_GetCode(status) != TF_OK) return nullptr;
1026 }
1027 }
1028 std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor =
1029 parallel_device::ParallelTensor::FromTensorHandles(
1030 target_parallel_device->parallel_device(),
1031 std::move(indices_components), status);
1032
1033 std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor =
1034 parallel_device::ParallelTensor::FromTensorHandles(
1035 target_parallel_device->parallel_device(),
1036 std::move(values_components), status);
1037
1038 std::unique_ptr<parallel_device::ParallelTensor>
1039 parallel_dense_shapes_tensor =
1040 parallel_device::ParallelTensor::FromTensorHandles(
1041 target_parallel_device->parallel_device(),
1042 std::move(dense_shapes_components), status);
1043
1044 if (TF_GetCode(status) != TF_OK) return nullptr;
1045 packed_tensor =
1046 SparseTensorWithLayout::Wrap(std::move(parallel_indices_tensor),
1047 std::move(parallel_values_tensor),
1048 std::move(parallel_dense_shapes_tensor),
1049 *target_parallel_device,
1050 target_layout.ValueOrDie(), local_shape)
1051 .ValueOrDie();
1052 }
1053
1054 RecordInShapeLayoutCache(*packed_tensor);
1055 TFE_TensorHandle* output =
1056 MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
1057 if (TF_GetCode(status) != TF_OK) return nullptr;
1058 return output;
1059 }
1060
1061 bool DTensorDevice::IsSparseDTensor(TFE_Context* context,
1062 TFE_TensorHandle* input,
1063 TF_Status* status) {
1064 const char* input_device = TFE_TensorHandleDeviceName(input, status);
1065 if (input_device != name_) {
1066 TF_SetStatus(
1067 status, TF_INVALID_ARGUMENT,
1068 "DTensorSparseUnpack expects a tensor placed on the DTensor device.");
1069 return false;
1070 }
1071 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1072 TFE_TensorHandleDevicePointer(input, status));
1073 if (TF_GetCode(status) != TF_OK) return false;
1074 return t->tensor_type() == TensorType::kSparse;
1075 }
1076
1077 void DTensorDevice::UpdateOutputLayoutsWithSameShapePolicy(
1078 const std::vector<PartialTensorShape>& global_output_shapes,
1079 const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
1080 tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
1081 TF_Status* status) {
1082 if (!same_shape_policy_enabled_) return;
1083 // Simply do not hint if inputs span across multiple meshes.
1084 if (input_meshes.size() > 1) return;
1085
1086 for (Node* node : graph->op_nodes()) {
1087 if (!node->IsRetval()) {
1088 continue;
1089 }
1090 int output_index;
1091 RETURN_C_STATUS_IF_NOT_OK(
1092 GetNodeAttr(node->attrs(), "index", &output_index), status);
1093 if (output_layouts->at(output_index)) {
1094 continue;
1095 }
1096
1097 const auto& global_output_shape = global_output_shapes.at(output_index);
1098 const Layout* layout = nullptr;
1099 // TODO(b/180022708): This is useful information, we should be
1100 // able to hint to layout propagation without making it a hard
1101 // requirement
1102 //
1103 // Special cases at the moment:
1104 // - Relayout needs an exemption.
1105 // - VarHandleOp does not need a hint. VarHandleOp has a scalar shape, so its
1106 // layout is trivial. On the other hand, the downstream system "thinks" a
1107 // Variable has the same shape as the value it points to, so providing a
1108 // layout based on the (scalar) VarHandleOp might confuse the downstream system.
1109 if (op_name != std::string("Relayout") &&
1110 op_name != std::string("VarHandleOp")) {
1111 // TODO(b/162009702): Support matching between partially-known shapes.
1112 if (global_output_shape.IsFullyDefined()) {
1113 gtl::InlinedVector<int64, 4> shape_vector(
1114 global_output_shape.dim_sizes());
1115 auto layout_iterator =
1116 shape_layout_cache_.find(FingerprintShape(shape_vector));
1117 if (layout_iterator != shape_layout_cache_.end() &&
1118 layout_iterator->second.is_unique) {
1119 // We have a cached layout for this shape. Send it to MLIR.
1120 layout = &layout_iterator->second.layout;
1121 VLOG(3) << op_name << ": found a cached layout for shape "
1122 << global_output_shape.DebugString() << ": \""
1123 << layout->ToString() << "\"";
1124 if (input_meshes.empty() &&
1125 layout->mesh() != default_mesh_->mesh_config()) {
1126 VLOG(3) << "But we can't infer a input mesh and cached layout: "
1127 << "mesh \"" << (layout->mesh().ToString()) << " "
1128 << "is different than the default mesh : \""
1129 << default_mesh_->mesh_config().ToString() << "\"\n"
1130 << "Not applying the cached layout.";
1131 } else if (!input_meshes.empty() &&
1132 layout->mesh() != *input_meshes.begin()) {
1133 VLOG(3)
1134 << "But the layout mesh is different than the executing mesh: "
1135 << "\"" << (*input_meshes.begin()).ToString() << "\"\n"
1136 << "Not applying the cached layout.";
1137 } else {
1138 (*output_layouts)[output_index] = layout;
1139 node->AddAttr(kDefaultLayoutAttr, layout->ToString());
1140 }
1141 } else if (layout_iterator == shape_layout_cache_.end()) {
1142 VLOG(3) << op_name << ": no cached layout found for shape "
1143 << global_output_shape.DebugString();
1144 } else {
1145 VLOG(3) << op_name << ": found multiple layouts for shape "
1146 << global_output_shape.DebugString();
1147 }
1148 } else {
1149 VLOG(3) << op_name
1150 << ": not applying same-shape-same-layout due to "
1151 "not-fully-known shape "
1152 << global_output_shape.DebugString();
1153 }
1154 }
1155 }
1156 }
1157
1158 std::unordered_map<std::string, int>
1159 DTensorDevice::GetFunctionCacheHitAndMissCount(TFE_Context* context,
1160 TF_Status* status) const {
1161 return function_compilation_hits_and_misses_;
1162 }
1163
1164 // From `graph`, which contains the computation for all meshes, extracts the
1165 // computation for the mesh specified in `function`. The returned graph is a
1166 // clone containing only the ops needed for single-mesh execution.
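// Illustrative example (assumed meshes): if `graph` holds one
// StatefulPartitionedCall per mesh, say one for a CPU sub-mesh and one for a
// TPU sub-mesh, the clone returned for the TPU function keeps only the TPU
// call node, the retvals reachable from it, and the args it actually uses.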
1167 StatusOr<std::unique_ptr<Graph>> SelectGraphToExecute(
1168 const TranslatedFunction& function, const Graph& graph,
1169 std::string* stateful_partitioned_call_name) {
1170 auto new_graph = std::make_unique<Graph>(graph.flib_def());
1171 CopyGraph(graph, new_graph.get());
1172 std::vector<Node*> arg_nodes;
1173 std::vector<Node*> retval_nodes;
1174 for (Node* node : new_graph->nodes()) {
1175 if (node->IsArg()) arg_nodes.emplace_back(node);
1176 if (node->IsRetval()) retval_nodes.emplace_back(node);
1177 }
1178
1179 // Remove irrelevant function calls.
1180 for (Node* node : new_graph->nodes()) {
1181 if (node->op_def().name() != "StatefulPartitionedCall") continue;
1182
1183 if (node->name() != function.node_to_execute->name()) {
1184 // Remove function call that does not match mesh specification and all
1185 // output retval nodes connected to the function call node.
1186 std::queue<Node*> nodes_to_remove;
1187 nodes_to_remove.push(node);
1188 while (!nodes_to_remove.empty()) {
1189 Node* n = nodes_to_remove.front();
1190 for (const Edge* out_edge : n->out_edges()) {
1191 if (out_edge->IsControlEdge()) continue;
1192 Node* out_node = out_edge->dst();
1193 if (!out_node->IsSink()) nodes_to_remove.push(out_node);
1194 }
1195 if (n->IsRetval()) {
1196 auto pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1197 TF_RET_CHECK(pos != retval_nodes.end());
1198 retval_nodes.erase(pos);
1199 }
1200 nodes_to_remove.pop();
1201 new_graph->RemoveNode(n);
1202 }
1203 }
1204 }
1205
1206 *stateful_partitioned_call_name = function.node_to_execute->name();
1207 VLOG(1) << "Selected call " << *stateful_partitioned_call_name;
1208
1209 // Remove unused arg nodes in graph.
1210 for (auto it = arg_nodes.begin(); it != arg_nodes.end(); it++) {
1211 Node* arg_node = *it;
1212 bool arg_unused = true;
1213 for (const Edge* e : arg_node->out_edges()) {
1214 if (e->dst()->IsOp()) {
1215 arg_unused = false;
1216 }
1217 }
1218 if (!arg_unused) continue;
1219
1220 new_graph->RemoveNode(arg_node);
1221 arg_nodes.erase(it--);
1222 }
1223
1224 // Reset index attributes for arg and retval nodes.
1225 for (Node* n : new_graph->nodes()) {
1226 // Reset each arg node's index attribute to its position among all the arg
1227 // nodes. Indices should simply increase from 0 to n-1, where n
1228 // is the total number of arguments. Note that this definition of
1229 // the `index` attribute is different from the definition we set in
1230 // PrepareGraphForMLIR.
1231 // This attribute is needed for each arg node when converting a Graph to
1232 // a FunctionDef.
1233 if (n->IsArg()) {
1234 auto pos = std::find(arg_nodes.begin(), arg_nodes.end(), n);
1235 TF_RET_CHECK(pos != arg_nodes.end());
1236 const int new_index = std::distance(arg_nodes.begin(), pos);
1237 n->AddAttr("index", new_index);
1238 }
1239
1240 // Reset retval nodes index attributes.
1241 if (n->IsRetval()) {
1242 auto retval_pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1243 TF_RET_CHECK(retval_pos != retval_nodes.end());
1244 const int new_index = std::distance(retval_nodes.begin(), retval_pos);
1245 n->AddAttr("index", new_index);
1246 }
1247 }
1248
1249 VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_to_execute_",
1250 *new_graph);
1251
1252 return new_graph;
1253 }
1254
1255 // Adds processed graph to run for each mesh computation in
1256 // `execution_functions` to function definition library.
1257 Status AddExecutionFunctionDefsToFunctionDefLibrary(
1258 const absl::flat_hash_set<Node*>& control_ret_nodes, TFE_Context* context,
1259 const Graph& graph, ExecutionFunctions* execution_functions) {
1260 // Note: We use node names instead of node pointers for comparison because
1261 // node addresses in the new graph differ from those in the original graph.
1262 absl::flat_hash_set<std::string> control_ret_names;
1263 for (auto* n : control_ret_nodes) {
1264 control_ret_names.emplace(n->name());
1265 }
1266 for (TranslatedFunction& function : execution_functions->function_list) {
1267 std::string selected_call_node_name;
1268 // TODO(bfontain): We should just try to call the functions directly rather
1269 // than wrap
1270 // Construct graph that executes only computation for `function`.
1271 TF_ASSIGN_OR_RETURN(
1272 std::unique_ptr<Graph> new_graph,
1273 SelectGraphToExecute(function, graph, &selected_call_node_name));
1274 VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_", *new_graph);
1275
1276 // Add unique identifier based on the function we are executing to the
1277 // function/graph and convert graph to functiondef.
1278 NameAttrList func;
1279 TF_RETURN_IF_ERROR(
1280 GetNodeAttr(function.node_to_execute->attrs(), "f", &func));
1281
1282 static std::atomic<int64_t> unique_function_number(0);
1283 function.translated_function_name =
1284 absl::StrCat(func.name(), "_", unique_function_number.fetch_add(1));
1285 auto control_ret_node_names =
1286 [&control_ret_names, &selected_call_node_name](
1287 const Node* node) -> absl::optional<std::string> {
1288 // Add the stateful partitioned call node as a control return as we need
1289 // to process any control deps inside the inner function.
1290 if (control_ret_names.contains(node->name()) ||
1291 node->name() == selected_call_node_name) {
1292 return node->name();
1293 }
1294 return absl::nullopt;
1295 };
1296
1297 tensorflow::FunctionDef to_run;
1298 TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
1299 *new_graph, function.translated_function_name, control_ret_node_names,
1300 &to_run));
1301
1302 for (const auto& out : to_run.signature().output_arg()) {
1303 function.output_dtypes.emplace_back(static_cast<TF_DataType>(out.type()));
1304 }
1305
1306 AddDTensorFunctionAttr(to_run);
1307 TF_RETURN_IF_ERROR(tensorflow::unwrap(context)->AddFunctionDef(to_run));
1308 }
1309
1310 return OkStatus();
1311 }
1312
1313 void DTensorDevice::LowerToSPMDFunction(
1314 TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1315 const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1316 const int num_outputs, const ExecutionFunctions** execution_functions,
1317 TF_Status* status) {
1318 profiler::TraceMe activity(
1319 [&] { return "DTensorDevice::LowerToSPMDFunction"; },
1320 profiler::TraceMeLevel::kInfo);
1321 FunctionLibraryDefinition* flib_def =
1322 tensorflow::unwrap(context)->FuncLibDef();
1323 auto graph(std::make_unique<tensorflow::Graph>(flib_def));
1324 NameAttrList eager_attributes;
1325 ASSIGN_OR_RETURN_C_STATUS(eager_attributes, FetchAttributes(attributes),
1326 status);
1327
1328 std::vector<PartialTensorShape> global_output_shapes;
1329 std::vector<const Layout*> output_layouts;
1330 const FunctionDef* function_def = doperation.function_def;
1331 if (!function_def) {
1332 // Output layouts of an eager op (e.g. fill) must be inferred before cache
1333 // key computation, since they might depend on the current DTensorDevice
1334 // state.
1335 Status s = PrepareGraphForMlir(
1336 function_manager_, inputs, doperation, *flib_def, eager_attributes,
1337 default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1338 RETURN_C_STATUS_IF_NOT_OK(s, status);
1339
1340 // Find all meshes the inputs lie on.
1341 absl::flat_hash_set<Mesh> input_meshes;
1342 for (const TensorWithLayout* tensor : inputs) {
1343 if (!tensor->layout().mesh().IsEmpty()) {
1344 input_meshes.insert(tensor->layout().mesh());
1345 }
1346 }
1347 // Currently we only provide layout hints for op-by-op, since
1348 // they interact badly with layout propagation.
1349 UpdateOutputLayoutsWithSameShapePolicy(global_output_shapes, input_meshes,
1350 doperation.name, graph.get(),
1351 &output_layouts, status);
1352 if (TF_GetCode(status) != TF_OK) return;
1353 }
1354
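// The cache key is a fingerprint derived from the operation, its attributes,
// the input layouts/shapes, and (for eager ops) the output layouts inferred
// above; a non-null ExecutionFunctions pointer below means a cache hit.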
1355 std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
1356 cache_key_and_func = function_manager_.GetCachedFunction(
1357 doperation, eager_attributes, inputs, output_layouts);
1358 *execution_functions = cache_key_and_func.second;
1359 if (*execution_functions != nullptr) {
1360 function_compilation_hits_and_misses_["hit"]++;
1361 return;
1362 } else if (function_def) {
1363 function_compilation_hits_and_misses_["miss"]++;
1364 LOG(INFO) << "DTensor cache key lookup missed for " << doperation.name
1365 << ". DTensor is (re-)computing its SPMD transformation.";
1366 }
1367
1368 // The returned device list includes remote devices when the coordination service is enabled.
1369 const auto device_list = tensorflow::unwrap(context)->ListAllTfDevices();
1370 DeviceSet device_set;
1371 for (const auto device : device_list) device_set.AddDevice(device);
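// The DeviceSet built here is what the MLIR pass runner (and, for pipelining,
// the partitioner) consults for the available devices when lowering the graph.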
1372
1373 if (function_def) {
1374 ASSIGN_OR_RETURN_C_STATUS(auto device_name_to_mesh_device,
1375 PipelineSubMeshes(context), status);
1376 const bool is_pipelining_function = !device_name_to_mesh_device.empty();
1377 // For a multi-mesh function for pipelining, we take a different execution
1378 // path. Call the partitioner to lower and partition the graph into multiple
1379 // sub functions to execute (one per sub mesh).
1380 if (is_pipelining_function) {
1381 ASSIGN_OR_RETURN_C_STATUS(
1382 ExecutionFunctions functions,
1383 PipeliningPartitionerRun(&device_name_to_mesh_device, flib_def,
1384 &pass_runner_, *doperation.function_def,
1385 eager_attributes, inputs, device_set,
1386 num_outputs),
1387 status);
1388 *execution_functions = function_manager_.AddCachedFunction(
1389 doperation, cache_key_and_func.first, std::move(functions));
1390 return;
1391 }
1392 // Output layouts of a function are inferred by MLIR lowering. They are
1393 // not necessary for cache key computation, so run PrepareGraphForMlir after
1394 // cache key computation to reduce the overheads of running the same
1395 // function multiple times.
1396 Status s = PrepareGraphForMlir(
1397 function_manager_, inputs, doperation, *flib_def, eager_attributes,
1398 default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1399 RETURN_C_STATUS_IF_NOT_OK(s, status);
1400 }
1401
1402 absl::flat_hash_set<Node*> control_ret_nodes;
1403 // Run DTensor MLIR passes that convert input graph to SPMD version.
1404 {
1405 profiler::TraceMe activity([&] { return "DTensorDevice::RunMLIRPasses"; },
1406 profiler::TraceMeLevel::kInfo);
1407 RETURN_C_STATUS_IF_NOT_OK(
1408 pass_runner_.RunOnGraph(device_set, doperation.is_func(), flib_def,
1409 &graph, control_ret_nodes,
1410 cache_key_and_func.first),
1411 status);
1412 }
1413 VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph,
1414 flib_def);
1415 if (flib_def->Contains(kLoadEmbeddingFn)) {
1416 Status s = InsertFunctionForTPUEmbeddingCheckpoint(
1417 status, graph.get(), inputs, kLoadEmbeddingFn);
1418 RETURN_C_STATUS_IF_NOT_OK(s, status);
1419 }
1420
1421 // After MLIR transformations, exactly one StatefulPartitionedCall op is
1422 // returned for each mesh cluster in the computation. Identify all functions
1423 // to execute for each mesh and the relevant input and output information.
1424 ASSIGN_OR_RETURN_C_STATUS(
1425 ExecutionFunctions functions,
1426 IdentifyAllFunctionsToExecute(*graph.get(), global_output_shapes),
1427 status);
1428
1429 // To ensure that all resource assign operations as well as other
1430 // side-effecting ops are executed, we add identity ops before function
1431 // outputs with control rets.
1432 RETURN_C_STATUS_IF_NOT_OK(MaybeInsertIdentityNodes(function_def, graph.get()),
1433 status);
1434
1435 VLOG(4) << tensorflow::DumpGraphToFile("after_post_processing_graph", *graph,
1436 flib_def);
1437
1438 RETURN_C_STATUS_IF_NOT_OK(
1439 AddExecutionFunctionDefsToFunctionDefLibrary(control_ret_nodes, context,
1440 *graph.get(), &functions),
1441 status);
1442 functions.num_device_ids = 1;
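// The mesh fingerprint below combines the global fingerprints of every mesh
// the function touches; it is later combined with a per-fingerprint execution
// counter to derive a step_id in ExecuteRegularOperation.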
1443 if (function_def) {
1444 for (TranslatedFunction& function : functions.function_list) {
1445 functions.function_mesh_fingerprint =
1446 FingerprintCat64(functions.function_mesh_fingerprint,
1447 function.function_mesh.GlobalFingerprint());
1448 }
1449 }
1450
1451 *execution_functions = function_manager_.AddCachedFunction(
1452 doperation, cache_key_and_func.first, std::move(functions));
1453 }
1454
1455 void DTensorDevice::ExecuteFunctionAndWait(
1456 TFE_Context* context, const TranslatedFunction* function_ptr,
1457 const MeshWithParallelDevice* parallel_device_mesh,
1458 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
1459 const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status) {
1460 const std::string mesh_str = function_ptr->function_mesh.ToString();
1461 VLOG(4) << "Launching computation for mesh : " << mesh_str;
1462 parallel_device_mesh->parallel_device().StartExecute(
1463 context,
1464 /*inputs=*/parallel_inputs,
1465 /*operation_name=*/function_ptr->translated_function_name.c_str(),
1466 /*attributes=*/attributes,
1467 /*expected_max_outputs=*/function_ptr->local_output_shapes.size(),
1468 /*cancellation_manager=*/*cancellation_manager_,
1469 /*step_id=*/step_id);
1470
1471 VLOG(4) << "Joining computation result from mesh : " << mesh_str;
1472 parallel_device_mesh->parallel_device().Join(
1473 function_ptr->local_output_shapes, status);
1474 VLOG(4) << "Joining status: " << TF_Message(status);
1475 if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_CANCELLED) {
1476 LOG(ERROR) << "Encountered error while executing function: "
1477 << function_ptr->translated_function_name
1478 << " for mesh : " << mesh_str
1479 << " / error : " << TF_Message(status);
1480 }
1481
1482 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1483 TF_NewStatus(), TF_DeleteStatus);
1484 AsyncWait(context, async_wait_status.get());
1485 TF_Code error_code = TF_GetCode(async_wait_status.get());
1486 if (error_code != TF_OK && error_code != TF_CANCELLED) {
1487 LOG(ERROR) << "Async status: " << TF_Message(async_wait_status.get());
1488 }
1489 }
1490
1491 void DTensorDevice::ExecuteRegularOperation(
1492 TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1493 const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1494 int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status) {
1495 const ExecutionFunctions* execution_functions = nullptr;
1496
1497 LowerToSPMDFunction(context, inputs, doperation, attributes, *num_outputs,
1498 &execution_functions, status);
1499 if (TF_GetCode(status) != TF_OK) return;
1500
1501 // Update input layouts for resource arguments.
1502 for (const TranslatedFunction& function :
1503 execution_functions->function_list) {
1504 for (const auto& entry : function.resource_input_layouts) {
1505 // TODO(hthu): Add a TensorWithLayout in the inputs vector at location 0
1506 // for DeviceId. This is done as the first arg is always DeviceId, and it
1507 // isn't mapped to input Tensors.
1508 const int resource_index_to_update = entry.first - 1;
1509 inputs[resource_index_to_update]->UpdateLayout(entry.second, status);
1510 if (TF_GetCode(status) != TF_OK) {
1511 RETURN_STATUS(status, TF_GetCode(status),
1512 absl::StrCat("Attempt to update layout input arg: ",
1513 resource_index_to_update,
1514 ". Original message: ", TF_Message(status))
1515 .c_str());
1516 }
1517 }
1518 }
1519
1520 int num_global_outputs = 0;
1521
1522 // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1523 // object. Ideally we'd just use a fingerprinted int64_t as a unique
1524 // identifier for a mesh.
1525 std::map<std::string, const MeshWithParallelDevice*>
1526 function_name_and_mesh_mapping;
1527 absl::flat_hash_set<std::string> excluded_fn_names;
1528 std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr;
1529 for (const TranslatedFunction& function :
1530 execution_functions->function_list) {
1531 StatusOr<Mesh> maybe_converted_mesh = function.function_mesh;
1532 if (function.function_mesh.is_epu_mesh()) {
1533 maybe_converted_mesh = function.function_mesh.ToDeviceType("CPU");
1534 }
1535
1536 if (!maybe_converted_mesh.ok()) {
1537 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
1538 absl::StrCat("Failed to convert mesh, get error: ",
1539 maybe_converted_mesh.status().error_message())
1540 .c_str());
1541 }
1542 const Mesh& mesh = *maybe_converted_mesh;
1543 // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1544 // object. Ideally we'd just use a fingerprinted int64_t as a unique
1545 // identifier for a mesh.
1546 const MeshWithParallelDevice* parallel_device_mesh =
1547 mesh_to_device_map_.contains(mesh) ? mesh_to_device_map_[mesh].get()
1548 : default_mesh_;
1549 if (parallel_device_mesh == nullptr) {
1550 RETURN_STATUS(status, TF_INTERNAL,
1551 "required mesh is not registered with DTensor device");
1552 }
1553 function_name_and_mesh_mapping[function.translated_function_name] =
1554 parallel_device_mesh;
1555
1556 if (function.function_mesh.is_epu_mesh()) {
1557 if (epu_fn_ptr != nullptr) {
1558 RETURN_STATUS(status, TF_INTERNAL,
1559 "There are more than one function defined on EPU mesh.");
1560 }
1561 epu_fn_ptr = std::make_unique<const TranslatedFunction>(function);
1562 excluded_fn_names.insert(function.translated_function_name);
1563 }
1564 if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) {
1565 if (load_embedding_ptr != nullptr) {
1566 RETURN_STATUS(status, TF_INTERNAL,
1567 "There are more than one function defined on EPU mesh.");
1568 }
1569 load_embedding_ptr = std::make_unique<const TranslatedFunction>(function);
1570 excluded_fn_names.insert(function.translated_function_name);
1571 }
1572 }
1573
1574 // Compute the step_id based on the function_mesh_fingerprint and the
1575 // corresponding function execution counter.
1576 uint64 function_mesh_fingerprint =
1577 execution_functions->function_mesh_fingerprint;
1578 if (func_mesh_fingerprint_to_step_counter_.contains(
1579 function_mesh_fingerprint)) {
1580 func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)++;
1581 } else {
1582 func_mesh_fingerprint_to_step_counter_.insert(
1583 {function_mesh_fingerprint, 0});
1584 }
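// The step_id mixes the mesh fingerprint with the per-fingerprint execution
// counter, so every execution of the same set of meshes gets a distinct id.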
1585 const uint64 step_id = FingerprintCat64(
1586 function_mesh_fingerprint,
1587 func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint));
1588
1589 // Execute excluded functions in sequence.
1590 if (epu_fn_ptr != nullptr) {
1591 ExecuteFunctionAndWait(
1592 context,
1593 /*function_ptr=*/epu_fn_ptr.get(),
1594 /*parallel_device_mesh=*/
1595 function_name_and_mesh_mapping[epu_fn_ptr->translated_function_name],
1596 /*parallel_inputs=*/{}, /*step_id=*/step_id, /*attributes=*/attributes,
1597 /*status=*/status);
1598 }
1599
1600 if (load_embedding_ptr != nullptr) {
1601 StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs =
1602 PrepareEmbeddingInputs(inputs);
1603 if (!parallel_inputs.ok()) {
1604 RETURN_STATUS(status, TF_INTERNAL,
1605 parallel_inputs.status().error_message().c_str());
1606 }
1607 ExecuteFunctionAndWait(
1608 context,
1609 /*function_ptr=*/load_embedding_ptr.get(),
1610 /*parallel_device_mesh=*/
1611 function_name_and_mesh_mapping[load_embedding_ptr
1612 ->translated_function_name],
1613 /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id,
1614 /*attributes=*/attributes, /*status=*/status);
1615 }
1616
1617 // Extract the global parallel inputs and flatten SparseTensors
1618 // into the three component tensors.
1619 std::vector<parallel_device::ParallelTensor*> global_parallel_inputs;
1620 std::vector<parallel_device::ParallelTensor*> global_parallel_sparse_inputs;
1621 absl::flat_hash_set<int> global_sparse_input_indices;
1622 for (auto input : inputs) {
1623 if (input->tensor_type() == TensorType::kSparse) {
1624 SparseTensorWithLayout* sparse_input =
1625 dynamic_cast<SparseTensorWithLayout*>(input);
1626 global_parallel_sparse_inputs.push_back(sparse_input->indices());
1627 global_parallel_sparse_inputs.push_back(sparse_input->dense_shapes());
1628 global_parallel_sparse_inputs.push_back(sparse_input->values());
1629 } else {
1630 global_parallel_inputs.push_back(input->tensor());
1631 }
1632 }
1633 // Insert SparseTensor components at the end; this is because
1634 // in the MLIR handling of SparseTensors, we place SparseTensor components
1635 // to the end of the main func arguments for a fixed ordering.
1636 global_parallel_inputs.insert(global_parallel_inputs.end(),
1637 global_parallel_sparse_inputs.begin(),
1638 global_parallel_sparse_inputs.end());
1639
1640 // Execute all functions in parallel.
1641 for (const TranslatedFunction& function :
1642 execution_functions->function_list) {
1643 const Mesh& mesh = function.function_mesh;
1644 const std::string& translated_function_name =
1645 function.translated_function_name;
1646
1647 num_global_outputs += function.local_output_shapes.size();
1648
1649 if (is_remote_mesh(mesh) ||
1650 (excluded_fn_names.find(translated_function_name) !=
1651 excluded_fn_names.end())) {
1652 // Skip execution for a translated function that has a remote mesh or is
1653 // excluded.
1654 continue;
1655 }
1656
1657 const MeshWithParallelDevice* parallel_device_mesh =
1658 function_name_and_mesh_mapping[translated_function_name];
1659
1660 // Gather the local inputs for this function.
1661 std::vector<parallel_device::ParallelTensor*> parallel_inputs;
1662 parallel_inputs.reserve(inputs.size() + 1);
1663 auto input_mapping = function.input_index_map;
1664
1665 // We sort here because by this time, the function graph we are executing
1666 // is a reduced version of the main function, which includes the
1667 // StatefulPartitionedCall that we are executing for this mesh.
1668 // Thus, the ordering is the same as the main function ordering, which
1669 // is sorted in increasing order.
1670 std::sort(input_mapping.begin(), input_mapping.end());
1671
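// Global index values below `num_device_ids` refer to the implicit device-id
// argument(s) that come first in the lowered function's signature; the
// remaining indices are offset by `num_device_ids` before indexing into
// `global_parallel_inputs`.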
1672 for (const int global_index : input_mapping) {
1673 auto input_index = global_index - execution_functions->num_device_ids;
1674
1675 if (global_index < execution_functions->num_device_ids) {
1676 parallel_inputs.push_back(
1677 parallel_device_mesh->DeviceIDs(context, status));
1678 if (TF_GetCode(status) != TF_OK) return;
1679 } else {
1680 parallel_inputs.push_back(global_parallel_inputs[input_index]);
1681 }
1682 }
1683
1684 VLOG(4) << "Launching computation for mesh : " << mesh.ToString();
1685 parallel_device_mesh->parallel_device().StartExecute(
1686 context, parallel_inputs, translated_function_name.c_str(), attributes,
1687 /*expected_max_outputs=*/function.local_output_shapes.size(),
1688 *cancellation_manager_, /*step_id=*/step_id);
1689 }
1690
1691 *num_outputs = num_global_outputs;
1692 std::vector<std::unique_ptr<TensorWithLayout>> typed_outputs;
1693 typed_outputs.resize(num_global_outputs);
1694
1695 // Join all mesh computation together.
1696 // TODO(b/177932563): Expose cancel logic to handle failures.
1697 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> join_status(
1698 TF_NewStatus(), TF_DeleteStatus);
1699 for (const TranslatedFunction& function :
1700 execution_functions->function_list) {
1701 // Skip execution for a function when it's excluded.
1702 if (excluded_fn_names.contains(function.translated_function_name)) {
1703 continue;
1704 }
1705 const Mesh& mesh = function.function_mesh;
1706 // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1707 // object. Ideally we'd just use a fingerprinted int64_t as a unique
1708 // identifier for a mesh.
1709 const MeshWithParallelDevice* parallel_device_mesh =
1710 function_name_and_mesh_mapping[function.translated_function_name];
1711
1712 std::vector<std::unique_ptr<TensorWithLayout>> output_with_layout;
1713 output_with_layout.reserve(function.output_index_map.size());
1714 if (is_remote_mesh(mesh)) {
1715 // Create dummy outputs on a remote mesh.
1716 for (int i = 0; i < function.output_index_map.size(); ++i) {
1717 const auto dim_sizes = function.local_output_shapes.at(i).dim_sizes();
1718 std::vector<int64_t> local_shape =
1719 std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
1720 TF_DataType dtype =
1721 static_cast<TF_DataType>(function.output_dtypes.at(i));
1722 auto remote_output =
1723 TensorWithLayout::Dummy(local_shape, dtype, *parallel_device_mesh,
1724 function.output_layouts[i]);
1725 output_with_layout.push_back(std::move(remote_output));
1726 }
1727 } else {
1728 VLOG(4) << "Joining computation result from mesh : " << mesh.ToString();
1729 auto result = parallel_device_mesh->parallel_device().Join(
1730 function.local_output_shapes, status);
1731 if (TF_GetCode(join_status.get()) != TF_OK &&
1732 // Preserve the first failure we see, but only if it is a real failure
1733 // and not a cancellation (which was probably triggered by the error
1734 // we want to propagate).
1735 (TF_GetCode(status) == TF_OK ||
1736 TF_GetCode(join_status.get()) != TF_CANCELLED)) {
1737 continue;
1738 }
1739 if (TF_GetCode(status) != TF_OK) {
1740 if (TF_GetCode(status) != TF_CANCELLED) {
1741 LOG(ERROR) << "Encountered error while executing function: "
1742 << function.translated_function_name
1743 << " for mesh : " << mesh.ToString()
1744 << " / error : " << TF_Message(status);
1745 }
1746 TF_SetStatus(join_status.get(), TF_GetCode(status), TF_Message(status));
1747 continue;
1748 }
1749
1750 for (int i = 0; i < result->size(); ++i) {
1751 ASSIGN_OR_RETURN_C_STATUS(
1752 auto local_output,
1753 TensorWithLayout::Wrap(std::move((*result)[i]),
1754 *parallel_device_mesh,
1755 function.output_layouts[i]),
1756 status);
1757 output_with_layout.push_back(std::move(local_output));
1758 }
1759 }
1760
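// `output_index_map` scatters this function's local outputs into their slots
// in the global output ordering (`typed_outputs`).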
1761 for (int i = 0; i < function.output_index_map.size(); ++i) {
1762 // TODO(b/162744844): Generalize this pattern so that the extraction is
1763 // not special cased.
1764 if (function.shape_output_metadata.find(i) !=
1765 function.shape_output_metadata.end()) {
1766 output_with_layout[i]->set_input_layout_for_shape_op_result(
1767 function.shape_output_metadata.at(i));
1768 }
1769
1770 RecordInShapeLayoutCache(*output_with_layout[i]);
1771 typed_outputs[function.output_index_map[i]] =
1772 std::move(output_with_layout[i]);
1773 }
1774 }
1775 if (TF_GetCode(join_status.get()) != TF_OK) {
1776 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1777 TF_NewStatus(), TF_DeleteStatus);
1778 AsyncWait(context, async_wait_status.get());
1779 TF_Code error_code = TF_GetCode(async_wait_status.get());
1780 if (error_code != TF_OK && error_code != TF_CANCELLED) {
1781 // Ignore the AsyncWait() status return since we already have a bad status
1782 // to propagate. We've just canceled a bunch of operations, so we expect
1783 // cancellation status returns. We'll log anything else just to be safe.
1784 LOG(ERROR) << "Error executing " << doperation.name << " "
1785 << TF_Message(async_wait_status.get());
1786 }
1787
1788 TF_SetStatus(status, TF_GetCode(join_status.get()),
1789 TF_Message(join_status.get()));
1790 return;
1791 }
1792 if (VLOG_IS_ON(2)) {
1793 LOG(INFO) << "Executed " << doperation.name << ", got "
1794 << typed_outputs.size() << " outputs:";
1795 for (const std::unique_ptr<TensorWithLayout>& output : typed_outputs) {
1796 LOG(INFO) << " " << output->DebugString();
1797 }
1798 }
1799 if (doperation.name == std::string("VarHandleOp")) {
1800 // For new variables, set the dereferenced shape/dtype so we can pass it in
1801 // as _handle_dtype and _handle_shape in the future.
1802 //
1803 // Note that VarHandleOps generated by `tf.Variable` objects are always run
1804 // eagerly, which is almost all of the op's usage in TF2. Theoretically a
1805 // user could run it in a tf.function via tf.raw_ops.VarHandleOp, return it
1806 // from that function, and add it as an input to another, and it would
1807 // currently be missing handle information.
1808 if (typed_outputs.size() != 1) {
1809 RETURN_STATUS(status, TF_INTERNAL,
1810 "Expected one output from VarHandleOp");
1811 }
1812 NameAttrList name_and_attrs;
1813 ASSIGN_OR_RETURN_C_STATUS(name_and_attrs, FetchAttributes(attributes),
1814 status);
1815
1816 typed_outputs[0]->UpdateShapeAndDType(
1817 name_and_attrs.attr().at("shape").shape(),
1818 name_and_attrs.attr().at("dtype").type(), status);
1819 if (TF_GetCode(status) != TF_OK) return;
1820 }
1821
1822 for (int i = 0; i < *num_outputs; ++i) {
1823 outputs[i] =
1824 MakeLayoutTensorHandle(context, std::move(typed_outputs[i]), status);
1825 if (TF_GetCode(status) != TF_OK) return;
1826 }
1827 }
1828
1829 void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs,
1830 TFE_TensorHandle** outputs, TF_Status* status) {
1831 TFE_Context* context = TFE_OpGetContext(original_op, status);
1832 if (TF_GetCode(status) != TF_OK) return;
1833 const char* operation_name = TFE_OpGetName(original_op, status);
1834 if (TF_GetCode(status) != TF_OK) return;
1835 const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
1836 int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
1837 if (TF_GetCode(status) != TF_OK) return;
1838 std::vector<TFE_TensorHandle*> inputs_vector;
1839 inputs_vector.reserve(num_inputs);
1840 for (int input_index = 0; input_index < num_inputs; ++input_index) {
1841 TFE_TensorHandle* input =
1842 TFE_OpGetFlatInput(original_op, input_index, status);
1843 if (TF_GetCode(status) != TF_OK) return;
1844 inputs_vector.push_back(input);
1845 }
1846 TFE_TensorHandle** inputs = inputs_vector.data();
1847
1848 DTensorOperation dtensor_operation{};
1849 dtensor_operation.name = operation_name;
1850 {
1851 dtensor_operation.function_def =
1852 tensorflow::unwrap(context)->FindFunctionDef(operation_name);
1853 }
1854
1855 // First handle DTensor-specific virtual operations.
1856 bool is_op_handled = false;
1857 MaybeHandleDTensorCustomOps(operation_name, num_inputs, attributes, context,
1858 inputs, num_outputs, outputs, &is_op_handled,
1859 status);
1860 if (is_op_handled) return;
1861
1862 // This isn't a special op, so we'll defer to TFE_Execute to actually execute
1863 // it, but we'll also run DTensor MLIR passes and propagate the layout.
1864 std::vector<TensorWithLayout*> typed_inputs;
1865 std::vector<std::unique_ptr<TensorWithLayout>> inputs_with_no_layout;
1866
1867 // Record the unique meshes identified across all inputs that are already on
1868 // the DTensor device. If a single mesh is identified, the same mesh is used
1869 // as the mesh to broadcast non-DTensor inputs to.
1870 absl::flat_hash_set<Mesh> input_meshes;
1871 std::vector<int> not_on_device_input_indices;
1872
1873 typed_inputs.resize(num_inputs);
1874 for (int j = 0; j < num_inputs; ++j) {
1875 TFE_TensorHandle* input = inputs[j];
1876 const char* input_device = TFE_TensorHandleDeviceName(input, status);
1877 if (TF_GetCode(status) != TF_OK) return;
1878 if (name_ != input_device) {
1879 not_on_device_input_indices.push_back(j);
1880 continue;
1881 }
1882 // Handle input which is on DTensor device already.
1883 TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1884 TFE_TensorHandleDevicePointer(input, status));
1885 if (TF_GetCode(status) != TF_OK) return;
1886
1887 // VarHandleOp runs on an empty mesh, which isn't registered with the device.
1888 if (!t->layout().mesh().IsEmpty()) {
1889 input_meshes.insert(t->layout().mesh());
1890 }
1891 // Remote mesh inputs are not able to be read and evaluated.
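// For small eligible inputs we also capture their concrete value, so that the
// SPMD lowering can fold them as constants (e.g. small integer tensors used as
// shape-like arguments).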
1892 if (!is_remote_mesh(t->layout().mesh()) && !t->const_value().has_value()) {
1893 std::optional<NodeDef> const_value =
1894 ExtractSmallTensorValue(context, input, t->layout(), status);
1895 if (TF_GetCode(status) != TF_OK) return;
1896 if (const_value.has_value()) {
1897 t->set_const_value(const_value.value());
1898 }
1899 }
1900 typed_inputs[j] = t;
1901 }
1902
1903 // If a unique mesh is identified across all inputs, we use that mesh as the
1904 // mesh to broadcast to. Otherwise we fallback to default mesh.
1905 const MeshWithParallelDevice* broadcast_mesh =
1906 input_meshes.size() == 1
1907 ? mesh_to_device_map_[*input_meshes.begin()].get()
1908 : default_mesh_;
1909 if (!broadcast_mesh) {
1910 RETURN_STATUS(status, TF_INVALID_ARGUMENT,
1911 "No mesh has been registered to DTensor. Use copy_to_mesh to "
1912 "explicit specify a mesh instead.");
1913 }
1914 for (int not_on_device_input_index : not_on_device_input_indices) {
1915 TFE_TensorHandle* input = inputs[not_on_device_input_index];
1916 // DTensor creation should be explicit, with some exceptions for usability
1917 // (scalars/shapes/slice specs/etc.). Here we do some trivial validation to
1918 // enforce this rule.
1919 int num_dims = TFE_TensorHandleNumDims(input, status);
1920 if (TF_GetCode(status) != TF_OK) return;
1921 int64_t num_elements = TFE_TensorHandleNumElements(input, status);
1922 if (TF_GetCode(status) != TF_OK) return;
1923 TF_DataType dtype = TFE_TensorHandleDataType(input);
1924 const bool small_int_tensor = num_elements < kSmallTensorThreshold &&
1925 (dtype == TF_INT32 || dtype == TF_INT64);
1926 if (!(num_dims == 0 || dtype == TF_STRING || small_int_tensor)) {
1927 std::vector<int64_t> tensor_shape(TensorShapeAsVector(input, status));
1928 if (TF_GetCode(status) != TF_OK) return;
1929 RETURN_STATUS(
1930 status, TF_UNIMPLEMENTED,
1931 absl::StrCat(
1932 "The op/function ", operation_name,
1933 " got a regular tensor for input ", not_on_device_input_index,
1934 " (shape ", ShapeToDebugString(tensor_shape),
1935 ") but was expecting a DTensor. Currently only scalars and "
1936 "small integer/string tensors are auto-broadcast to "
1937 "DTensors. For other tensors, please use copy_to_mesh to "
1938 "make a DTensor explicitly; note that this may be slow if it "
1939 "happens frequently.")
1940 .c_str());
1941 }
1942 // Construct temporary TensorWithLayout objects for inputs that didn't
1943 // have any to start. These are owned by the `inputs_with_no_layout`
1944 // vector, whereas the input `TFE_TensorHandle`s maintain ownership for
1945 // inputs that already had layouts (and therefore had TensorWithLayout
1946 // objects).
1947 std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
1948 context, input, *broadcast_mesh, name_, status);
1949 if (TF_GetCode(status) != TF_OK) return;
1950 if (!ShouldFoldInputArgument(dtensor_operation.name,
1951 /*input_index=*/not_on_device_input_index)) {
1952 wrapper->reset_const_value();
1953 }
1954 typed_inputs[not_on_device_input_index] = wrapper.get();
1955 inputs_with_no_layout.emplace_back(wrapper.release());
1956 }
1957
1958 ExecuteRegularOperation(context, typed_inputs, dtensor_operation, attributes,
1959 num_outputs, outputs, status);
1960 }
1961
1962 void ExecuteOnDTensorDevice(const TFE_Op* original_op, int* num_outputs,
1963 TFE_TensorHandle** outputs, TF_Status* status,
1964 void* device_info) {
1965 DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1966 dev->Execute(original_op, num_outputs, outputs, status);
1967 }
1968
1969 void DeleteDTensorDevice(void* device_info) {
1970 delete static_cast<DTensorDevice*>(device_info);
1971 }
1972
1973 TFE_TensorHandle* CopyToDTensorDevice(TFE_Context* context,
1974 TFE_TensorHandle* tensor,
1975 TF_Status* status, void* device_info) {
1976 TF_SetStatus(status, TF_UNIMPLEMENTED,
1977 "Trying to copy a tensor on to a DTensor mesh without a layout "
1978 "(use the CopyToMesh op for now).");
1979 return nullptr;
1980 }
1981
1982 TFE_TensorHandle* CopyFromDTensorDevice(TFE_Context* context,
1983 TFE_TensorHandle* tensor,
1984 const char* target_device_name,
1985 TF_Status* status, void* device_info) {
1986 TensorWithLayout* typed_input = reinterpret_cast<TensorWithLayout*>(
1987 TFE_TensorHandleDevicePointer(tensor, status));
1988 if (!tensorflow::dtensor::Layout(typed_input->layout()).IsFullyReplicated()) {
1989 TF_SetStatus(status, TF_UNIMPLEMENTED,
1990 "Trying to copy a non-replicated DTensor is not supported.");
1991 return nullptr;
1992 }
1993 if (typed_input->tensor()->dtype() == TF_RESOURCE) {
1994 TF_SetStatus(status, TF_UNIMPLEMENTED,
1995 "Trying to copy a DTensor resource handle is not supported.");
1996 return nullptr;
1997 }
1998 DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1999 // Since operations are executed asynchronously, the operation which should
2000 // produce the tensor we're trying to copy off the DTensor device may be
2001 // canceled due to a failure on another device. If so, we want to report the
2002 // failure that caused the cancellation, not the cancellation itself. This
2003 // requires blocking waiting for other devices to flush their execution
2004 // queues.
2005 // Note that we only need to sync the threads on the parallel_device()
2006 // directly; a context-level sync might cause unintentional deadlocks when
2007 // grabbing locks on other threads.
2008 dev->AsyncWait(context, status);
2009 if (TF_GetCode(status) != TF_OK) return nullptr;
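// The DTensor is fully replicated (checked above), so every component holds
// the same value; returning the first local component is sufficient.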
2010 return TFE_TensorHandleCopySharingTensor(typed_input->get_tensor(0), status);
2011 }
2012
2013 void AllocateDTensorDevice(absl::string_view device_name,
2014 TFE_CustomDevice* device, void** device_info) {
2015 device->copy_tensor_to_device = &CopyToDTensorDevice;
2016 device->copy_tensor_from_device = &CopyFromDTensorDevice;
2017 device->delete_device = &DeleteDTensorDevice;
2018 device->execute = &ExecuteOnDTensorDevice;
2019 *device_info = new DTensorDevice(device_name);
2020 }
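// A minimal sketch of how this hook is typically wired up (the real
// registration happens in the DTensor Python bindings; the device name below
// is illustrative):
//
//   TFE_CustomDevice device;
//   void* device_info = nullptr;
//   AllocateDTensorDevice("/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                         &device, &device_info);
//   TFE_RegisterCustomDevice(context, device,
//                            "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//                            device_info, status);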
2021
2022 void AddMesh(const std::string& serialized_mesh, void* device_info,
2023 bool is_async, bool is_host_mesh, TF_Status* status) {
2024 auto mesh_config_or_status = Mesh::FromString(serialized_mesh);
2025 if (!mesh_config_or_status.ok()) {
2026 TF_SetStatus(status, TF_INTERNAL,
2027 absl::StrCat("Failed to parse mesh config. ",
2028 mesh_config_or_status.status().error_message())
2029 .c_str());
2030 return;
2031 }
2032 auto mesh_config = mesh_config_or_status.ValueOrDie();
2033 std::vector<std::string> underlying_devices;
2034 underlying_devices.insert(underlying_devices.end(),
2035 mesh_config.local_devices().begin(),
2036 mesh_config.local_devices().end());
2037 // DTensor uses a multi-client setup which doesn't use remote eager, so we can
2038 // enable eager async execution in the ParallelDevice.
2039 std::unique_ptr<tensorflow::parallel_device::ParallelDevice> parallel(
2040 new tensorflow::parallel_device::ParallelDevice(underlying_devices,
2041 is_async));
2042
2043 std::string composite_device_name;
2044 if (absl::StartsWith(mesh_config.name(), kPipelineMeshNamePrefix)) {
2045 composite_device_name = std::string(
2046 absl::StripPrefix(mesh_config.name(), kPipelineMeshNamePrefix));
2047 }
2048
2049 auto mesh = std::make_unique<MeshWithParallelDevice>(
2050 std::move(mesh_config), std::move(parallel), composite_device_name);
2051 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2052 device->AddMesh(std::move(mesh), is_host_mesh);
2053 }
2054
2055 void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
2056 void* device_info, TF_Status* status) {
2057 StatusOr<Layout> layout = Layout::FromString(serialized_layout);
2058 if (!layout.ok()) {
2059 RETURN_STATUS(status, TF_INTERNAL, layout.status().error_message().c_str());
2060 }
2061 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2062 device->SetDefaultLayout(layout.ValueOrDie());
2063 }
2064
2065 void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status) {
2066 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2067 device->ClearDefaultLayout();
2068 }
2069
2070 void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
2071 void* device_info, TF_Status* status) {
2072 StatusOr<Mesh> mesh = Mesh::FromString(serialized_mesh);
2073 if (!mesh.ok()) {
2074 RETURN_STATUS(status, TF_INTERNAL, mesh.status().error_message().c_str());
2075 }
2076 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2077 device->SetDefaultMesh(mesh.ValueOrDie());
2078 }
2079
2080 void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status) {
2081 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2082 device->ClearDefaultMesh();
2083 }
2084
2085 void SetSameShapePolicy(void* device_info, bool enabled) {
2086 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2087 device->SetSameShapePolicy(enabled);
2088 }
2089
2090 void SetTPUCoreIDs(const std::string& mesh_name,
2091 const std::vector<int>& tpu_core_ids, void* device_info,
2092 TF_Status* status) {
2093 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2094 RETURN_C_STATUS_IF_NOT_OK(device->SetTPUCoreIDs(mesh_name, tpu_core_ids),
2095 status);
2096 }
2097
2098 void ClearTPUCoreIDs(void* device_info) {
2099 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2100 device->ClearTPUCoreIDs();
2101 }
2102
2103 std::vector<std::vector<int>> TPUCoreIDsToLocations(
2104 TFE_Context* context, const std::vector<int>& tpu_core_ids,
2105 void* device_info) {
2106 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2107 return device->TPUCoreIDsToLocations(context, tpu_core_ids);
2108 }
2109
2110 std::vector<int> TPUCoreLocationsToIDs(
2111 TFE_Context* context,
2112 const std::vector<std::vector<int>>& tpu_core_locations,
2113 void* device_info) {
2114 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2115 return device->TPUCoreLocationsToIDs(context, tpu_core_locations);
2116 }
2117
2118 TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
2119 TFE_TensorHandle** inputs,
2120 const std::string& string_layout, void* device_info,
2121 TF_Status* status) {
2122 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2123 return device->Pack(context, num_inputs, inputs, string_layout, status);
2124 }
2125
2126 std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
2127 TFE_TensorHandle* input,
2128 void* device_info, TF_Status* status) {
2129 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2130 return device->Unpack(context, input, status);
2131 }
2132
2133 std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
2134 void* device_info, TF_Status* status) {
2135 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2136 return device->FetchLayout(context, input, status);
2137 }
2138
2139 TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
2140 TFE_TensorHandle** indices,
2141 TFE_TensorHandle** values,
2142 TFE_TensorHandle** shapes,
2143 const std::string& string_layout,
2144 void* device_info, TF_Status* status) {
2145 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2146 return device->SparsePack(context, num_inputs, indices, values, shapes,
2147 string_layout, status);
2148 }
2149
2150 bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
2151 void* device_info, TF_Status* status) {
2152 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2153 return device->IsSparseDTensor(context, input, status);
2154 }
2155
2156 std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
2157 TFE_Context* context, void* device_info, TF_Status* status) {
2158 DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2159 return device->GetFunctionCacheHitAndMissCount(context, status);
2160 }
2161 } // namespace dtensor
2162 } // namespace tensorflow
2163