xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/dtensor_device.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/cc/dtensor_device.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/base/attributes.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/match.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/string_view.h"
33 #include "absl/strings/strip.h"
34 #include "tensorflow/c/c_api_experimental.h"
35 #include "tensorflow/c/eager/c_api.h"
36 #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
37 #include "tensorflow/c/eager/tfe_context_internal.h"
38 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
39 #include "tensorflow/c/tf_datatype.h"
40 #include "tensorflow/c/tf_status.h"
41 #include "tensorflow/c/tf_status_helper.h"
42 #include "tensorflow/c/tf_tensor_internal.h"
43 #include "tensorflow/compiler/xla/status_macros.h"
44 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
45 #include "tensorflow/core/common_runtime/device_set.h"
46 #include "tensorflow/core/common_runtime/eager/context.h"
47 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
48 #include "tensorflow/core/common_runtime/graph_constructor.h"
49 #include "tensorflow/core/common_runtime/shape_refiner.h"
50 #include "tensorflow/core/framework/attr_value.pb.h"
51 #include "tensorflow/core/framework/function.h"
52 #include "tensorflow/core/framework/function.pb.h"
53 #include "tensorflow/core/framework/graph_to_functiondef.h"
54 #include "tensorflow/core/framework/node_def_builder.h"
55 #include "tensorflow/core/framework/node_def_util.h"
56 #include "tensorflow/core/framework/op.h"
57 #include "tensorflow/core/framework/tensor_shape.h"
58 #include "tensorflow/core/graph/algorithm.h"
59 #include "tensorflow/core/graph/graph.h"
60 #include "tensorflow/core/lib/strings/proto_serialization.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/fingerprint.h"
63 #include "tensorflow/core/platform/types.h"
64 #include "tensorflow/core/profiler/lib/traceme.h"
65 #include "tensorflow/core/util/dump_graph.h"
66 #include "tensorflow/dtensor/cc/constants.h"
67 #include "tensorflow/dtensor/cc/dstatus.h"
68 #include "tensorflow/dtensor/cc/dtensor_device_util.h"
69 #include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h"
70 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
71 #include "tensorflow/dtensor/cc/tensor_layout.h"
72 #include "tensorflow/dtensor/cc/tpu_system_interface.h"
73 #include "tensorflow/dtensor/proto/layout.pb.h"
74 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
75 #include "tensorflow/stream_executor/tpu/tpu_topology.h"
76 
77 namespace tensorflow {
78 namespace dtensor {
79 
80 // TODO(b/189332820): Replace this with a Partitioner stub swapped in by the
81 // Copybara workflow.
82 StatusOr<ExecutionFunctions> ABSL_ATTRIBUTE_WEAK PipeliningPartitionerRun(
83     const absl::flat_hash_map<std::string, const MeshWithParallelDevice*>*
84         device_name_to_mesh_device,
85     FunctionLibraryDefinition* flib_def, DTensorMlirPassRunner* pass_runner,
86     const FunctionDef& fdef, const NameAttrList& eager_attributes,
87     const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set,
88     int num_outputs) {
89   // The actual definition is in the pipelining package.
90   return errors::Unimplemented("DTensor pipelining is unavailable.");
91 }
92 
93 class DTensorDevice {
94  public:
95   explicit DTensorDevice(absl::string_view name)
96       : name_(name),
97         same_shape_policy_enabled_(false),
98         cancellation_manager_(std::make_unique<CancellationManager>()) {}
99 
100   void AddMesh(std::unique_ptr<MeshWithParallelDevice> mesh,
101                bool is_host_mesh) {
102     // TODO(b/168730933): Consider passing a cheaper int64_t mesh identifier.
103     if (is_host_mesh) {
104       std::string& tpu_host_mesh = Mesh::tpu_host_mesh();
105       const std::string new_tpu_host_mesh = mesh->mesh_config().ToString();
106       if (!tpu_host_mesh.empty()) {
107         // TODO(b/180046115): Add per-TPU-mesh host mesh bookkeeping.
108         LOG(WARNING)
109             << "A new TPU host mesh is overwriting the old TPU host mesh. The "
110                "old TPU mesh cannot be used in sea of donuts mode anymore.";
111       }
112       tpu_host_mesh.assign(new_tpu_host_mesh);
113     }
114     // For idempotency, don't register the same mesh twice.
115     if (!mesh_to_device_map_.insert({mesh->mesh_config(), std::move(mesh)})
116              .second)
117       return;
118     if (!default_mesh_) {
119       global_default_mesh_ = mesh_to_device_map_.begin()->second.get();
120       default_mesh_ = global_default_mesh_;
121     }
122   }
123 
124   // Returns the sub-meshes used for pipelining.
125   // The key is the name of a composite device.
126   StatusOr<absl::flat_hash_map<std::string, const MeshWithParallelDevice*>>
127   PipelineSubMeshes(TFE_Context* context) {
128     absl::flat_hash_map<std::string, const MeshWithParallelDevice*>
129         device_to_mesh;
130     for (const auto& pair : mesh_to_device_map_) {
131       TF_ASSIGN_OR_RETURN(CompositeDevice * device,
132                           pair.second->FindOrCreateCompositeDevice(context));
133       if (device != nullptr) {
134         device_to_mesh[pair.second->composite_device()->name()] =
135             pair.second.get();
136       }
137     }
138     return device_to_mesh;
139   }
140 
141   // Runs an operation on the DTensorDevice, ignoring the placement of the
142   // original op (TFE_OpGetDevice(original_op)).
143   //
144   // The original placement indicates whether the user explicitly placed the op
145   // on the DTensor device (vs. having it placed there because an input was
146   // placed there), but DTensor is doing type-based dispatch and so handles
147   // these cases identically at the moment.
148   void Execute(const TFE_Op* original_op, int* num_outputs,
149                TFE_TensorHandle** outputs, TF_Status* status);
150 
151   void SetDefaultLayout(Layout layout) { default_layout_.emplace(layout); }
152   void ClearDefaultLayout() { default_layout_.reset(); }
153   void SetDefaultMesh(Mesh mesh) {
154     default_mesh_ = mesh_to_device_map_.at(mesh).get();
155   }
156   void ClearDefaultMesh() { default_mesh_ = global_default_mesh_; }
157   void SetSameShapePolicy(bool enabled) {
158     same_shape_policy_enabled_ = enabled;
159   }
160 
161   Status SetTPUCoreIDs(const std::string& mesh_name,
162                        const std::vector<int>& tpu_core_ids) {
163     if (VLOG_IS_ON(1)) {
164       LOG(INFO) << "Setting TPU core IDs for "
165                 << (mesh_name.empty() ? "default mesh" : mesh_name) << ": ";
166       for (auto i : tpu_core_ids) {
167         LOG(INFO) << i;
168       }
169     }
170     // Setting the default mesh under an empty name repeatedly is fine; this
171     // happens when dtensor_initialize_tpu_system is called multiple times,
172     // especially in tests. All the set mappings should be the same anyway.
173     if (!mesh_name.empty() && Mesh::tpu_core_ids().count(mesh_name) > 0) {
174       return errors::AlreadyExists("Mesh name already in use: ", mesh_name);
175     }
176     Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(),
177                                            tpu_core_ids.end());
178     return OkStatus();
179   }
180 
181   void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); }
182 
183   std::vector<std::vector<int>> TPUCoreIDsToLocations(
184       TFE_Context* context, const std::vector<int>& tpu_core_ids) {
185     TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
186     if (tpu_system == nullptr) {
187       VLOG(1) << "Calling TPUCoreIDsToLocations on the default TPU system.";
188       std::vector<std::vector<int>> tpu_core_locations;
189       tpu_core_locations.reserve(tpu_core_ids.size());
190       tpu::TpuPlatformInterface* tpu_platform =
191           tpu::TpuPlatformInterface::GetRegisteredPlatform();
192       if (tpu_platform == nullptr) {
193         LOG(WARNING) << "No TPU platform is found.";
194         return {{}};
195       }
196       if (!tpu_platform->Initialized()) {
197         LOG(WARNING) << "TPU platform is not initialized.";
198         return {{}};
199       }
200       tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
201 
202       for (const int& tpu_core_id : tpu_core_ids) {
203         tpu::TpuCoreLocationExternal core =
204             tpu_topology.CoreForId(TpuCoreTypeEnum::kTensorCore, tpu_core_id);
205         tpu::TpuDimensionsExternal tpu_chip_location = core.chip_coordinates();
206         tpu_core_locations.push_back({tpu_chip_location.x, tpu_chip_location.y,
207                                       tpu_chip_location.z, core.index()});
208       }
209       return tpu_core_locations;
210     } else {
211       VLOG(1) << "Calling TPUCoreIDsToLocations on a preferred TPU system.";
212       return tpu_system->TPUCoreIDsToLocations(context, tpu_core_ids);
213     }
214   }
215 
216   std::vector<int> TPUCoreLocationsToIDs(
217       TFE_Context* context,
218       const std::vector<std::vector<int>>& tpu_core_locations) {
219     TpuSystemInterface* tpu_system = GetPreferredTpuSystem();
220     if (tpu_system == nullptr) {
221       VLOG(1) << "Calling TPUCoreLocationsToIDs on the default TPU system.";
222       std::vector<int> tpu_core_ids;
223       tpu_core_ids.reserve(tpu_core_locations.size());
224       tpu::TpuPlatformInterface* tpu_platform =
225           tpu::TpuPlatformInterface::GetRegisteredPlatform();
226       if (tpu_platform == nullptr) {
227         LOG(WARNING) << "No TPU platform is found.";
228         return {};
229       }
230       if (!tpu_platform->Initialized()) {
231         LOG(WARNING) << "TPU platform is not initialized.";
232         return {};
233       }
234       tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology();
235 
236       for (const std::vector<int>& tpu_core_location : tpu_core_locations) {
237         tpu::TpuCoreLocationExternal core = tpu_topology.Core(
238             TpuCoreTypeEnum::kTensorCore, tpu_core_location[0],
239             tpu_core_location[1], tpu_core_location[2], tpu_core_location[3]);
240         tpu_core_ids.push_back(core.Id());
241       }
242       return tpu_core_ids;
243     } else {
244       VLOG(1) << "Calling TPUCoreLocationsToIDs on a preferred TPU system.";
245       return tpu_system->TPUCoreLocationsToIDs(context, tpu_core_locations);
246     }
247   }
248 
249   // Waits for ops to finish in ALL meshes as we share the cancellation manager.
250   void AsyncWait(TFE_Context* context, TF_Status* status) {
251     std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> first_bad_status(
252         nullptr, TF_DeleteStatus);
253 
254     for (const auto& pair : mesh_to_device_map_) {
255       std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
256           TF_NewStatus(), TF_DeleteStatus);
257 
258       pair.second->parallel_device().AsyncWait(context,
259                                                async_wait_status.get());
260 
261       TF_Code error_code = TF_GetCode(async_wait_status.get());
262       if (error_code != TF_OK &&
263           (first_bad_status == nullptr ||
264            TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) {
265         first_bad_status.reset(TF_NewStatus());
266         TF_SetStatus(first_bad_status.get(), error_code,
267                      TF_Message(async_wait_status.get()));
268       }
269     }
270 
271     if (first_bad_status != nullptr) {
272       TF_SetStatus(status, TF_GetCode(first_bad_status.get()),
273                    TF_Message(first_bad_status.get()));
274     }
275 
276     // Reset the global function rendezvous, which otherwise stores a failure
277     // state.
278     tensorflow::unwrap(context)->ResetGlobalRendezvousForFunction();
279 
280     // Reset the cancellation manager on (potential) failure so we don't cancel
281     // future ops. This is only safe because we have just cleared pending async
282     // nodes, which may have had a reference to the cancellation manager.
283     cancellation_manager_ = std::make_unique<CancellationManager>();
284   }
285 
286   TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
287                          TFE_TensorHandle** inputs,
288                          const std::string& string_layout, TF_Status* status);
289 
290   std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
291                                         TFE_TensorHandle* input,
292                                         TF_Status* status);
293 
294   // Return the layout for the input tensor.
295   std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
296                           TF_Status* status);
297 
298   TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
299                                TFE_TensorHandle** indices,
300                                TFE_TensorHandle** values,
301                                TFE_TensorHandle** shapes,
302                                const std::string& string_layout,
303                                TF_Status* status);
304 
305   bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
306                        TF_Status* status);
307 
308   std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
309       TFE_Context* context, TF_Status* status) const;
310 
311  private:
312   // If the `operation_name` of an op indicates a custom DTensor op (e.g.
313   // CopyToMesh), then separately handle those custom ops instead of running
314   // default DTensor graph compilation.
315   void MaybeHandleDTensorCustomOps(
316       const char* operation_name, const int num_inputs,
317       const TFE_OpAttrs* attributes, TFE_Context* context,
318       TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
319       bool* is_custom_dtensor_op, TF_Status* status);
320 
321   // Copies non-dtensor eager tensor or DTensor to a mesh specified by
322   // `attributes`.
323   // Currently, only copy to replicated layout on target mesh is supported.
324   void CopyToMesh(TFE_Context* context, int num_inputs,
325                   TFE_TensorHandle** inputs, const TFE_OpAttrs* attributes,
326                   TFE_TensorHandle** outputs, int* num_outputs,
327                   TF_Status* status);
328 
329   // Update output layouts for eager ops based on same shape policy.
330   void UpdateOutputLayoutsWithSameShapePolicy(
331       const std::vector<PartialTensorShape>& global_output_shapes,
332       const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
333       tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
334       TF_Status* status);
335 
336   // Takes the description of an operation and makes a function out of it with
337   // the same signature, running DTensor MLIR passes. Registers that function
338   // with `context`. `translated_function_name` is set to the name of the
339   // function.
340   //
341   // The resulting function expects a device ID as its first input.
342   void LowerToSPMDFunction(TFE_Context* context,
343                            const std::vector<TensorWithLayout*>& inputs,
344                            const DTensorOperation& doperation,
345                            const TFE_OpAttrs* attributes, const int num_outputs,
346                            const ExecutionFunctions** execution_functions,
347                            TF_Status* status);
348 
349   // Execute a given function.
350   void ExecuteFunctionAndWait(
351       TFE_Context* context, const TranslatedFunction* function_ptr,
352       const MeshWithParallelDevice* parallel_device_mesh,
353       const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
354       const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status);
355 
356   // Implements `Execute` for operations which aren't special-cased in `Execute`.
357   void ExecuteRegularOperation(TFE_Context* context,
358                                const std::vector<TensorWithLayout*>& inputs,
359                                const DTensorOperation& doperation,
360                                const TFE_OpAttrs* attributes, int* num_outputs,
361                                TFE_TensorHandle** outputs, TF_Status* status);
362 
363   // Wraps a TensorWithLayout into a TFE_TensorHandle.
364   TFE_TensorHandle* MakeLayoutTensorHandle(TFE_Context* context,
365                                            std::unique_ptr<TensorWithLayout> t,
366                                            TF_Status* status);
367 
368   void RecordInShapeLayoutCache(const TensorWithLayout& tensor);
369 
370   // Returns whether a given mesh is a remote mesh.
371   bool is_remote_mesh(const Mesh& mesh) const;
372 
373   // The name of the device (the custom device)
374   std::string name_;
375   // Mesh configs with matching parallel devices.
376   //
377   // For now we just consider the first entry added to dtensor_device as the
378   // default mesh. Before we reach an agreement on this, we'll leave it as is.
379   absl::flat_hash_map<Mesh, std::unique_ptr<MeshWithParallelDevice>>
380       mesh_to_device_map_;
381   // TODO(hthu): Consider whether we want to preserve the default_mesh semantic.
382   // The current default mesh is consistent with default_layout_. If
383   // default_layout_ is not set, it equals global_default_mesh_.
384   const MeshWithParallelDevice* default_mesh_ = nullptr;
385   // The default mesh of a DTensorDevice, which cannot be modified once it is
386   // set.
387   const MeshWithParallelDevice* global_default_mesh_ = nullptr;
388   // If the user has specified a default output layout.
389   absl::optional<Layout> default_layout_;
390 
391   // Determines whether tensors with a shape previously associated with only one
392   // layout use that layout if nothing else can be inferred.
393   bool same_shape_policy_enabled_;
394 
395   DTensorMlirPassRunner pass_runner_;
396 
397   struct CachedLayout {
398     // The first layout seen with this shape
399     Layout layout;
400     // Whether the layout is unique for this shape
401     bool is_unique;
402   };
403   absl::flat_hash_map<int64_t, CachedLayout> shape_layout_cache_;
404 
405   FunctionManager function_manager_;
406 
407   // Records the function compilation cache hits and misses.
408   std::unordered_map<std::string, int> function_compilation_hits_and_misses_;
409 
410   // Coordinates cancelling ops across meshes on error. Must outlive any queued
411   // async op launches, so we only reset it after seeing a failure status.
412   std::unique_ptr<CancellationManager> cancellation_manager_;
413 
414   // Maps each function_mesh_fingerprint (based on the set of meshes involved)
415   // to the number of times the function has been executed. The
416   // function_mesh_fingerprint and the counter together are used for generating
417   // the step id, which is used for rendezvous creation.
418   absl::flat_hash_map<uint64, uint64> func_mesh_fingerprint_to_step_counter_;
419 };
420 
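// Fingerprints a shape by chaining FingerprintCat64 over its dimensions. The
// resulting value is the key for shape_layout_cache_, which backs the
// same-shape layout policy (see RecordInShapeLayoutCache below).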
421 int64_t FingerprintShape(const absl::Span<const int64_t> shape) {
422   int64_t fprint = 0;
423   for (int64_t dim : shape) {
424     fprint = FingerprintCat64(fprint, dim);
425   }
426   return fprint;
427 }
428 
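// Lazily builds (and caches) a parallel tensor holding, for each local device
// of this mesh, that device's ID as an int32 scalar. Also validates that the
// mesh's global device IDs are consecutive, which the MLIR passes rely on.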
429 parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs(
430     TFE_Context* context, TF_Status* status) const {
431   if (device_ids_tensor_ == nullptr) {
432     // Global device IDs sequentially increase.
433     //
434     // This is the assumption in the dtensor software stack. MLIR pass relies on
435     // this assumption to generate mesh coordinates for each core efficiently.
436     //
437     // The rule for setting local ids and the mapping from global ids to real
438     // physical core indices, e.g., on TPU, is unfortunately nontrivial. It is
439     // possible to use an identity mapping, but the collective operation
440     // performance is terrible in most cases.
441     //
442     // - For an ICI-connected TPU slice, see go/dtensor-device-assignment-summary
443     //   for a guide on creating efficient core assignments for peak
444     //   performance.
445     //
446     //   The global id to core assignment mapping is bridged by
447     //   `Mesh::tpu_core_ids()` and consumed by `UpdateTPUCompileMetadata`.
448     //
449     // - For DCN-connected topology, we need to map different sections of the
450     //   global ids to its real physical cores separately according to the
451     //   runtime requirements. For example, for a 4x32 mesh, in which the outer
452     //   dimension is connected via DCN and inner dimension is connected by ICI,
453     //   the device assignments for inner dimension should typically form its
454     //   own ring order (not the plain physical core index) in each sub-mesh, and
455     //   the outer dimension should be assigned according to the real physical
456     //   ring of DCN hosts.
457     //
458     // Note: In order to change this assumption, the MLIR pass needs adjustment.
459     // One possible approach is to take an N-D mapping vector for an N-D mesh and
460     // look up the coordinates in MLIR, by consulting the tensor layout as well,
461     // rather than calculating them on the fly.
462 
463     // LINT.IfChange
464     for (int64_t i = 0; i < mesh_config_.global_device_ids().size(); ++i) {
465       if (mesh_config_.global_device_ids()[i] - i !=
466           mesh_config_.global_device_ids()[0]) {
467         TF_SetStatus(
468             status, TF_INTERNAL,
469             absl::StrCat("Global device IDs should be consecutive: ",
470                          absl::StrJoin(mesh_config_.global_device_ids(), ", "))
471                 .c_str());
472         return nullptr;
473       }
474     }
475     // LINT.ThenChange(//tensorflow/dtensor/python/layout.py)
476 
477     // Local device IDs are a subset of global device IDs, arranged in device
478     // ordinal order.
479     std::vector<int32_t> ids;
480     for (int64_t id : mesh_config_.local_device_ids()) {
481       ids.push_back(id);
482     }
483     VLOG(1) << "Parallel device IDs: " << absl::StrJoin(ids, ", ");
484     device_ids_tensor_ =
485         parallel_device_->ScalarsFromSequence<int32_t>(ids, context, status);
486     if (TF_GetCode(status) != TF_OK) return nullptr;
487   }
488   return device_ids_tensor_.get();
489 }
490 
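// The callbacks below adapt a heap-allocated TensorWithLayout to the
// TFE_CustomDeviceTensorHandle interface; MakeLayoutTensorHandle wires them
// into handles returned by the DTensor device.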
491 int TensorWithLayoutNumDims(void* data, TF_Status* status) {
492   return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size();
493 }
494 
495 int64_t TensorWithLayoutDim(void* data, int dim_index, TF_Status* status) {
496   return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index];
497 }
498 
499 void TensorWithLayoutDeallocator(void* data) {
500   delete reinterpret_cast<TensorWithLayout*>(data);
501 }
502 
503 TF_Buffer* TensorWithLayoutSummarize(void* data, TF_Status* status) {
504   std::string summary =
505       reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue();
506   return TF_NewBufferFromString(summary.data(), summary.size());
507 }
508 
509 TFE_TensorHandle* DTensorDevice::MakeLayoutTensorHandle(
510     TFE_Context* context, std::unique_ptr<TensorWithLayout> t,
511     TF_Status* status) {
512   TF_DataType dtype = t->dtype();
513   TFE_CustomDeviceTensorHandleMethods handle_methods;
514   handle_methods.num_dims = &TensorWithLayoutNumDims;
515   handle_methods.dim = &TensorWithLayoutDim;
516   handle_methods.deallocator = &TensorWithLayoutDeallocator;
517   handle_methods.summarize = &TensorWithLayoutSummarize;
518   return TFE_NewCustomDeviceTensorHandle(context, name_.c_str(), dtype,
519                                          /*data=*/t.release(), handle_methods,
520                                          status);
521 }
522 
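// Records the association between a shape fingerprint and the layout it was
// seen with. A shape observed with more than one distinct layout is marked
// non-unique and is no longer used as a hint by the same-shape policy.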
523 void DTensorDevice::RecordInShapeLayoutCache(const TensorWithLayout& tensor) {
524   auto existing = shape_layout_cache_.insert(
525       {FingerprintShape(tensor.global_shape()),
526        CachedLayout{tensor.layout(), /*is_unique=*/true}});
527 
528   if (!existing.second) {
529     // There is an entry already; if the layout doesn't match we should record
530     // the fact that it's not unique.
531     if (tensor.layout() != existing.first->second.layout) {
532       existing.first->second.is_unique = false;
533     }
534   }
535 }
536 
537 bool DTensorDevice::is_remote_mesh(const Mesh& mesh) const {
538   // An empty mesh might be assigned to VarHandleOp during DTensor MLIR lowering
539   // pass. Decide whether the empty mesh is remote based on the current default
540   // mesh.
541   return mesh.is_remote() ||
542          (mesh.IsEmpty() && default_mesh_->mesh_config().is_remote());
543 }
544 
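// Round-trips the eager op's attributes through serialization to recover a
// NameAttrList that can be inspected generically.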
545 StatusOr<NameAttrList> FetchAttributes(const TFE_OpAttrs* attributes) {
546   // TODO(allenl): Should we just give up on the public C API to save on
547   // serialization/deserialization? We need all of the attributes and to treat
548   // them generically, which isn't going to be pleasant with typed attribute
549   // methods.
550   std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> serialized_attributes(
551       TF_NewBuffer(), TF_DeleteBuffer);
552 
553   TF_Status* status = TF_NewStatus();
554   TFE_OpAttrsSerialize(attributes, serialized_attributes.get(), status);
555   if (TF_GetCode(status) == TF_OK) {
556     TF_DeleteStatus(status);
557   } else {
558     Status failure_status = StatusFromTF_Status(status);
559     TF_DeleteStatus(status);
560     return failure_status;
561   }
562 
563   NameAttrList name_and_attrs;
564   if (!name_and_attrs.ParseFromArray(serialized_attributes->data,
565                                      serialized_attributes->length)) {
566     return tensorflow::errors::Unknown("Could not parse attributes");
567   }
568   return name_and_attrs;
569 }
570 
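// Parses the Layout stored under `attribute_name` (e.g. kQualifiedLayoutAttr)
// in the op's attributes. Note that the attribute is assumed to be present.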
571 StatusOr<Layout> FetchLayoutFromAttributes(const TFE_OpAttrs* attributes,
572                                            absl::string_view attribute_name) {
573   // Get attributes.
574   TF_ASSIGN_OR_RETURN(NameAttrList name_and_attrs, FetchAttributes(attributes));
575 
576   // Get layout string from attributes.
577   absl::string_view layout_str =
578       name_and_attrs.attr().find(std::string(attribute_name))->second.s();
579 
580   // This would probably be slow at the moment without caching.
581   // We should consider making this faster in the future.
582   return Layout::FromString(string(layout_str));
583 }
584 
585 std::string DTensorDevice::FetchLayout(TFE_Context* context,
586                                        TFE_TensorHandle* input,
587                                        TF_Status* status) {
588   VLOG(1) << "Checking layout...";
589   const char* input_device = TFE_TensorHandleDeviceName(input, status);
590   if (input_device != name_) {
591     TF_SetStatus(status, TF_INVALID_ARGUMENT,
592                  "FetchLayout expects a tensor placed on the layout device.");
593     return {};
594   }
595   TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
596       TFE_TensorHandleDevicePointer(input, status));
597   if (TF_GetCode(status) != TF_OK) return {};
598   return t->layout().ToString();
599 }
600 
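// Returns one tensor handle per component of the DTensor, each sharing the
// underlying per-device buffer. Unpacking a tensor on a remote mesh is not
// supported.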
601 std::vector<TFE_TensorHandle*> DTensorDevice::Unpack(TFE_Context* context,
602                                                      TFE_TensorHandle* input,
603                                                      TF_Status* status) {
604   std::vector<TFE_TensorHandle*> outputs;
605 
606   const char* input_device = TFE_TensorHandleDeviceName(input, status);
607   if (TF_GetCode(status) != TF_OK) return outputs;
608   if (input_device != name_) {
609     TF_SetStatus(
610         status, TF_INVALID_ARGUMENT,
611         absl::StrCat(
612             "DTensorUnpack expects a tensor placed on the DTensor device: ",
613             name_, ", but input was placed on device: ", input_device)
614             .c_str());
615     return outputs;
616   }
617   TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
618       TFE_TensorHandleDevicePointer(input, status));
619   if (TF_GetCode(status) != TF_OK) return outputs;
620 
621   if (is_remote_mesh(t->mesh().mesh_config())) {
622     TF_SetStatus(status, TF_UNIMPLEMENTED,
623                  "DTensorUnpack is not supported on a remote mesh.");
624     return outputs;
625   }
626   const int output_size = t->num_tensors();
627   outputs.resize(output_size);
628 
629   for (int output_index = 0; output_index < output_size; ++output_index) {
630     outputs[output_index] =
631         TFE_TensorHandleCopySharingTensor(t->get_tensor(output_index), status);
632     if (TF_GetCode(status) != TF_OK) {
633       return outputs;
634     }
635   }
636   return outputs;
637 }
638 
639 void DTensorDevice::MaybeHandleDTensorCustomOps(
640     const char* operation_name, const int num_inputs,
641     const TFE_OpAttrs* attributes, TFE_Context* context,
642     TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs,
643     bool* is_custom_dtensor_op, TF_Status* status) {
644   *is_custom_dtensor_op = true;
645   if (operation_name == std::string("_EagerConst")) {
646     // Op-by-op const has no obvious layout. DTensor skips an SPMD expansion and
647     // instead relies on copying the value when it is used later.
648     std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
649         TFE_NewOp(context, operation_name, status), TFE_DeleteOp);
650     if (TF_GetCode(status) != TF_OK) return;
651     for (int input_index = 0; input_index < num_inputs; ++input_index) {
652       TFE_OpAddInput(op.get(), inputs[input_index], status);
653       if (TF_GetCode(status) != TF_OK) return;
654     }
655     TFE_OpAddAttrs(op.get(), attributes);
656     TFE_Execute(op.get(), outputs, num_outputs, status);
657     return;
658   }
659   if (operation_name == std::string("CopyToMesh")) {
660     CopyToMesh(context, num_inputs, inputs, attributes, outputs, num_outputs,
661                status);
662     return;
663   }
664 
665   *is_custom_dtensor_op = false;
666 }
667 
668 void DTensorDevice::CopyToMesh(TFE_Context* context, int num_inputs,
669                                TFE_TensorHandle** inputs,
670                                const TFE_OpAttrs* attributes,
671                                TFE_TensorHandle** outputs, int* num_outputs,
672                                TF_Status* status) {
673   if (num_inputs != 1) {
674     RETURN_STATUS(status, TF_INVALID_ARGUMENT,
675                   "DTensor CopyToMesh requires exactly 1 input.");
676   }
677   if (*num_outputs < 1) {
678     RETURN_STATUS(status, TF_INTERNAL,
679                   "DTensor CopyToMesh must have output buffer to allocate at "
680                   "least 1 output.");
681   }
682 
683   // Assign layout.
684   StatusOr<Layout> target_layout_or =
685       FetchLayoutFromAttributes(attributes, kQualifiedLayoutAttr);
686   if (!target_layout_or.ok()) {
687     RETURN_STATUS(status, TF_INVALID_ARGUMENT,
688                   "DTensor CopyToMesh requires valid layout attribute for "
689                   "destination DTensor.");
690   }
691 
692   const Layout target_layout = *target_layout_or;
693   const Mesh& target_mesh = target_layout.mesh();
694 
695   // TODO(b/193443769): Support sharded layout for eager copy to mesh.
696   if (!target_layout.IsFullyReplicated()) {
697     RETURN_STATUS(status, TF_UNIMPLEMENTED,
698                   "Target layout of DTensor CopyToMesh must be replicated. "
699                   "Consider changing the target layout to replicated layout or "
700                   "file a bug to the DTensor team (b/193443769).");
701   }
702 
703   TFE_TensorHandle* input_tensor = inputs[0];
704 
705   // Check that if the input tensor is a DTensor, then its layout
706   // must be replicated.
707   const char* input_device = TFE_TensorHandleDeviceName(input_tensor, status);
708   if (TF_GetCode(status) != TF_OK) return;
709 
710   if (name_ == input_device) {
711     // Handle input which is on DTensor device already.
712     TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
713         TFE_TensorHandleDevicePointer(input_tensor, status));
714     if (TF_GetCode(status) != TF_OK) return;
715 
716     if (!t->layout().IsFullyReplicated())
717       RETURN_STATUS(status, TF_INVALID_ARGUMENT,
718                     "Input tensor to CopyToMesh must be replicated DTensor or "
719                     "normal eager Tensor.");
720 
721     // If input to CopyToMesh is a DTensor, we use the first local tensor as
722     // input tensor handle to invoke copy.
723     input_tensor = t->get_tensor(0);
724   }
725 
726   auto it = mesh_to_device_map_.find(target_mesh);
727   if (it == mesh_to_device_map_.end()) {
728     RETURN_STATUS(
729         status, TF_INTERNAL,
730         "DTensor CopyToMesh target mesh is not registered. Meshes should be "
731         "automatically registered. Please file a bug. (component id: 833864)");
732   }
733 
734   const MeshWithParallelDevice* target_parallel_mesh = it->second.get();
735 
736   // Broadcast non-dtensor value to dtensor.
737   std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
738       context, input_tensor, *target_parallel_mesh, name_, status);
739   if (TF_GetCode(status) != TF_OK) return;
740 
741   RecordInShapeLayoutCache(*wrapper);
742   *num_outputs = 1;
743   *outputs = MakeLayoutTensorHandle(context, std::move(wrapper), status);
744 }
745 
746 namespace {
747 
748 // Verifies that all components have the same dtype and shape.
749 // The component shape will be set upon success.
750 void VerifyPackTensorShapeAndDtype(
751     std::vector<parallel_device::TensorHandlePtr>& components,
752     std::vector<int64_t>* component_shape, TF_Status* status) {
753   TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
754   auto size = TFE_TensorHandleNumDims(components[0].get(), status);
755   if (TF_GetCode(status) != TF_OK) return;
756   component_shape->clear();
757   component_shape->reserve(size);
758   for (int i = 0; i < size; ++i) {
759     component_shape->push_back(
760         TFE_TensorHandleDim(components[0].get(), i, status));
761     if (TF_GetCode(status) != TF_OK) return;
762   }
763 
764   // Verify that the TensorHandle's shape and dtype match all of the component
765   // shapes and dtypes.
766   for (const auto& component : components) {
767     for (int i = 0; i < component_shape->size(); ++i) {
768       int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
769       if (TF_GetCode(status) != TF_OK) return;
770       if (tensor_dim != (*component_shape)[i]) {
771         TF_SetStatus(status, TF_UNIMPLEMENTED,
772                      "Components of a PackedTensor must currently all have "
773                      "the same shape");
774         return;
775       }
776       if (TFE_TensorHandleDataType(component.get()) != dtype) {
777         TF_SetStatus(status, TF_INTERNAL,
778                      "Components of a PackedTensor must all have "
779                      "the same dtype");
780         return;
781       }
782     }
783   }
784 }
785 
786 // Verifies that all TensorHandles have rank `expected_rank` and, if non-null, dtype `expected_dtype`.
787 void VerifyTensorRankAndDType(TFE_TensorHandle** tensors, int num_input,
788                               int expected_rank, TF_DataType* expected_dtype,
789                               TF_Status* status) {
790   for (int i = 0; i < num_input; ++i) {
791     auto actual_rank = TFE_TensorHandleNumDims(tensors[i], status);
792     if (TF_GetCode(status) != TF_OK)
793       RETURN_STATUS(status, TF_INTERNAL, "Error getting rank of tensor.");
794     if (actual_rank != expected_rank)
795       RETURN_STATUS(status, TF_INVALID_ARGUMENT,
796                     "Rank of tensor did not match the expected rank.");
797     if (expected_dtype != nullptr &&
798         TFE_TensorHandleDataType(tensors[i]) != *expected_dtype)
799       RETURN_STATUS(status, TF_INVALID_ARGUMENT,
800                     "Dtype of tensor did not match the expected dtype.");
801   }
802 }
803 
804 }  // namespace
805 
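// Packs `num_inputs` per-device tensors into a single DTensor with layout
// `string_layout`. On a remote mesh only a dummy (shape- and dtype-only)
// tensor is created; otherwise each input is copied to its corresponding
// local device and the components are validated to share a shape and dtype.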
806 TFE_TensorHandle* DTensorDevice::Pack(TFE_Context* context, int num_inputs,
807                                       TFE_TensorHandle** inputs,
808                                       const std::string& string_layout,
809                                       TF_Status* status) {
810   if (num_inputs < 1) {
811     TF_SetStatus(status, TF_INVALID_ARGUMENT,
812                  "DTensorPack requires 1 or more inputs");
813     return nullptr;
814   }
815   StatusOr<Layout> target_layout = Layout::FromString(string_layout);
816   if (!target_layout.ok()) {
817     TF_SetStatus(status, TF_INVALID_ARGUMENT,
818                  "Failed to parse layout from string layout");
819     return nullptr;
820   }
821   const Mesh& target_mesh = target_layout->mesh();
822   const MeshWithParallelDevice* target_parallel_device =
823       mesh_to_device_map_[target_mesh].get();
824   if (target_parallel_device == nullptr) {
825     TF_SetStatus(status, TF_INVALID_ARGUMENT,
826                  absl::StrCat("Required mesh : ", target_mesh.ToString(),
827                               " is not registered with DTensor ")
828                      .c_str());
829     return nullptr;
830   }
831 
832   std::unique_ptr<TensorWithLayout> packed_tensor;
833   if (is_remote_mesh(target_parallel_device->mesh_config())) {
834     // Create a dummy output for DTensorPack if inputs are on a remote mesh.
835     TF_DataType dtype = TFE_TensorHandleDataType(inputs[0]);
836     auto size = TFE_TensorHandleNumDims(inputs[0], status);
837     if (TF_GetCode(status) != TF_OK) return nullptr;
838     std::vector<int64_t> component_shape;
839     component_shape.reserve(size);
840     for (int i = 0; i < size; ++i) {
841       component_shape.push_back(TFE_TensorHandleDim(inputs[0], i, status));
842       if (TF_GetCode(status) != TF_OK) return nullptr;
843     }
844     packed_tensor = TensorWithLayout::Dummy(
845         component_shape, dtype, *target_parallel_device, *target_layout);
846 
847   } else {
848     auto local_devices = target_parallel_device->mesh_config().local_devices();
849 
850     if (num_inputs !=
851         target_parallel_device->parallel_device().num_underlying_devices()) {
852       TF_SetStatus(status, TF_INVALID_ARGUMENT,
853                    absl::StrCat("The dtensor device ", name_, " expected ",
854                                 local_devices.size(),
855                                 " inputs to DTensorPack, but got ", num_inputs)
856                        .c_str());
857       return nullptr;
858     }
859 
860     std::vector<parallel_device::TensorHandlePtr> components;
861     components.reserve(num_inputs);
862     for (int i = 0; i < num_inputs; ++i) {
863       TFE_TensorHandle* input = inputs[i];
864       const char* input_device = TFE_TensorHandleDeviceName(input, status);
865       if (TF_GetCode(status) != TF_OK) return nullptr;
866       if (name_ == input_device) {
867         TF_SetStatus(status, TF_INVALID_ARGUMENT,
868                      "Does not support packing a Tensor that is already on "
869                      "dtensor device");
870         return nullptr;
871       }
872       // If `input` is on the target device, this creates a new handle sharing
873       // the underlying data; otherwise, async copies are invoked.
874       components.emplace_back(TFE_TensorHandleCopyToDevice(
875           input, context, local_devices[i].c_str(), status));
876       if (TF_GetCode(status) != TF_OK) return nullptr;
877     }
878 
879     std::vector<int64_t> component_shape;
880     VerifyPackTensorShapeAndDtype(components, &component_shape, status);
881     if (TF_GetCode(status) != TF_OK) return nullptr;
882 
883     std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
884         parallel_device::ParallelTensor::FromTensorHandles(
885             target_parallel_device->parallel_device(), std::move(components),
886             status);
887     if (TF_GetCode(status) != TF_OK) return nullptr;
888 
889     if (target_layout->rank() != component_shape.size()) {
890       TF_SetStatus(
891           status, TF_INVALID_ARGUMENT,
892           absl::StrCat(
893               "Packed layout should have the same rank as the rank for each "
894               "component. The rank of each component is: ",
895               component_shape.size(),
896               ", while layout has rank: ", target_layout->rank(),
897               "\nLayout: ", target_layout->ToString(), "\n")
898               .c_str());
899       return nullptr;
900     }
901 
902     packed_tensor =
903         TensorWithLayout::Wrap(std::move(parallel_tensor),
904                                *target_parallel_device, *target_layout)
905             .ValueOrDie();
906   }
907 
908   RecordInShapeLayoutCache(*packed_tensor);
909   TFE_TensorHandle* output =
910       MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
911   if (TF_GetCode(status) != TF_OK) return nullptr;
912   return output;
913 }
914 
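// Sparse analogue of Pack: packs per-device (indices, values, dense_shape)
// triples into a single SparseTensorWithLayout after validating the ranks and
// dtypes of the inputs. The local shape is read from the first shape tensor.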
915 TFE_TensorHandle* DTensorDevice::SparsePack(
916     TFE_Context* context, int num_inputs, TFE_TensorHandle** indices,
917     TFE_TensorHandle** values, TFE_TensorHandle** shapes,
918     const std::string& string_layout, TF_Status* status) {
919   StatusOr<Layout> target_layout = Layout::FromString(string_layout);
920   if (!target_layout.ok()) {
921     TF_SetStatus(status, TF_INVALID_ARGUMENT,
922                  "Failed to parse layout from string layout");
923     return nullptr;
924   }
925   const Mesh& target_mesh = target_layout->mesh();
926   const MeshWithParallelDevice* target_parallel_device =
927       mesh_to_device_map_[target_mesh].get();
928   if (target_parallel_device == nullptr) {
929     TF_SetStatus(status, TF_INVALID_ARGUMENT,
930                  absl::StrCat("Required mesh : ", target_mesh.ToString(),
931                               " is not registered with DTensor ")
932                      .c_str());
933     return nullptr;
934   }
935 
936   TF_DataType tf_int64 = TF_INT64;
937   // Verify rank and dtype of shapes.
938   VerifyTensorRankAndDType(shapes, num_inputs, /*expected_rank=*/1,
939                            /*expected_dtype=*/&tf_int64, status);
940   if (TF_GetCode(status) != TF_OK) return nullptr;
941 
942   // Verify rank and dtype of indices.
943   VerifyTensorRankAndDType(indices, num_inputs, /*expected_rank=*/2,
944                            /*expected_dtype=*/&tf_int64, status);
945   if (TF_GetCode(status) != TF_OK) return nullptr;
946 
947   // Verify rank of values.
948   VerifyTensorRankAndDType(values, num_inputs, /*expected_rank=*/1,
949                            /*expected_dtype=*/nullptr, status);
950   if (TF_GetCode(status) != TF_OK) return nullptr;
951 
952   // Compute the local shape from a shape tensor.
953   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> shape_tensor(
954       TFE_TensorHandleResolve(shapes[0], status), TF_DeleteTensor);
955   if (TF_GetCode(status) != TF_OK) {
956     TF_SetStatus(
957         status, TF_GetCode(status),
958         absl::StrCat("Error resolving the tensor handle of shape tensor"
959                      ". Original message: ",
960                      TF_Message(status))
961             .c_str());
962     return nullptr;
963   }
964   int shape_tensor_size = TFE_TensorHandleDim(shapes[0], 0, status);
965   if (TF_GetCode(status) != TF_OK || shape_tensor_size <= 0) {
966     TF_SetStatus(status, TF_GetCode(status),
967                  absl::StrCat("Error computing the num dims of shape tensor",
968                               TF_Message(status))
969                      .c_str());
970     return nullptr;
971   }
972 
973   const int64_t* data =
974       static_cast<int64_t*>(TF_TensorData(shape_tensor.get()));
975   std::vector<int64_t> local_shape(data, data + shape_tensor_size);
976   if (local_shape.size() != target_layout->rank()) {
977     TF_SetStatus(
978         status, TF_INVALID_ARGUMENT,
979         absl::StrCat(
980             "Packed layout should have the same rank as the rank for each "
981             "component. The rank of each component is: ",
982             local_shape.size(),
983             ", while layout has rank: ", target_layout->rank(),
984             "\nLayout: ", target_layout->ToString(), "\n")
985             .c_str());
986     return nullptr;
987   }
988 
989   // Create the SparseTensorWithLayout.
990   std::unique_ptr<TensorWithLayout> packed_tensor;
991   if (is_remote_mesh(target_parallel_device->mesh_config())) {
992     // Create a dummy SparseTensorWithLayout.
993     packed_tensor = SparseTensorWithLayout::Dummy(
994         local_shape, *target_parallel_device, target_layout.ValueOrDie());
995   } else {
996     // Parse the indices, values, and dense_shape tensors and put them into
997     // parallel tensors, and then pack it into a single SparseTensorWithLayout.
998     auto local_devices = target_parallel_device->mesh_config().local_devices();
999 
1000     std::vector<parallel_device::TensorHandlePtr> indices_components;
1001     std::vector<parallel_device::TensorHandlePtr> values_components;
1002     std::vector<parallel_device::TensorHandlePtr> dense_shapes_components;
1003 
1004     // A small trick to keep the packing code cleaner: handle indices, values,
1005     // and shapes in the same loop.
1006     std::vector<std::vector<parallel_device::TensorHandlePtr>*> components{
1007         &indices_components, &values_components, &dense_shapes_components};
1008     std::vector<TFE_TensorHandle**> input_vectors{indices, values, shapes};
1009     for (int component_index = 0; component_index < 3; ++component_index) {
1010       components[component_index]->reserve(num_inputs);
1011       TFE_TensorHandle** inputs = input_vectors[component_index];
1012       for (int i = 0; i < num_inputs; ++i) {
1013         const char* input_device =
1014             TFE_TensorHandleDeviceName(inputs[i], status);
1015         if (TF_GetCode(status) != TF_OK) return nullptr;
1016         if (name_ == input_device) {
1017           TF_SetStatus(status, TF_INVALID_ARGUMENT,
1018                        "Does not support packing a Tensor that is already on "
1019                        "dtensor device.");
1020           return nullptr;
1021         }
1022 
1023         components[component_index]->emplace_back(TFE_TensorHandleCopyToDevice(
1024             inputs[i], context, local_devices[i].c_str(), status));
1025         if (TF_GetCode(status) != TF_OK) return nullptr;
1026       }
1027     }
1028     std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor =
1029         parallel_device::ParallelTensor::FromTensorHandles(
1030             target_parallel_device->parallel_device(),
1031             std::move(indices_components), status);
1032 
1033     std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor =
1034         parallel_device::ParallelTensor::FromTensorHandles(
1035             target_parallel_device->parallel_device(),
1036             std::move(values_components), status);
1037 
1038     std::unique_ptr<parallel_device::ParallelTensor>
1039         parallel_dense_shapes_tensor =
1040             parallel_device::ParallelTensor::FromTensorHandles(
1041                 target_parallel_device->parallel_device(),
1042                 std::move(dense_shapes_components), status);
1043 
1044     if (TF_GetCode(status) != TF_OK) return nullptr;
1045     packed_tensor =
1046         SparseTensorWithLayout::Wrap(std::move(parallel_indices_tensor),
1047                                      std::move(parallel_values_tensor),
1048                                      std::move(parallel_dense_shapes_tensor),
1049                                      *target_parallel_device,
1050                                      target_layout.ValueOrDie(), local_shape)
1051             .ValueOrDie();
1052   }
1053 
1054   RecordInShapeLayoutCache(*packed_tensor);
1055   TFE_TensorHandle* output =
1056       MakeLayoutTensorHandle(context, std::move(packed_tensor), status);
1057   if (TF_GetCode(status) != TF_OK) return nullptr;
1058   return output;
1059 }
1060 
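// Returns true iff `input` is a DTensor on this device whose underlying
// representation is sparse.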
1061 bool DTensorDevice::IsSparseDTensor(TFE_Context* context,
1062                                     TFE_TensorHandle* input,
1063                                     TF_Status* status) {
1064   const char* input_device = TFE_TensorHandleDeviceName(input, status);
1065   if (input_device != name_) {
1066     TF_SetStatus(
1067         status, TF_INVALID_ARGUMENT,
1068         "DTensorSparseUnpack expects a tensor placed on the DTensor device.");
1069     return false;
1070   }
1071   TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1072       TFE_TensorHandleDevicePointer(input, status));
1073   if (TF_GetCode(status) != TF_OK) return false;
1074   return t->tensor_type() == TensorType::kSparse;
1075 }
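// When the same-shape policy is enabled and the inputs are on at most one
// mesh, annotates each unset _Retval layout with the unique layout previously
// cached for an identical fully-defined shape, as a hint for layout
// propagation.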
1076 
1077 void DTensorDevice::UpdateOutputLayoutsWithSameShapePolicy(
1078     const std::vector<PartialTensorShape>& global_output_shapes,
1079     const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name,
1080     tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts,
1081     TF_Status* status) {
1082   if (!same_shape_policy_enabled_) return;
1083   // Simply do not hint if inputs span across multiple meshes.
1084   if (input_meshes.size() > 1) return;
1085 
1086   for (Node* node : graph->op_nodes()) {
1087     if (!node->IsRetval()) {
1088       continue;
1089     }
1090     int output_index;
1091     RETURN_C_STATUS_IF_NOT_OK(
1092         GetNodeAttr(node->attrs(), "index", &output_index), status);
1093     if (output_layouts->at(output_index)) {
1094       continue;
1095     }
1096 
1097     const auto& global_output_shape = global_output_shapes.at(output_index);
1098     const Layout* layout = nullptr;
1099     // TODO(b/180022708): This is useful information; we should be
1100     // able to hint to layout propagation without making it a hard
1101     // requirement.
1102     //
1103     // Special cases at the moment:
1104     // - Relayout needs an exemption.
1105     // - VarHandleOp does not need a hint. VarHandleOp has scalar shape so layout
1106     //   is trivial. On the other hand, the downstream system "thinks" a Variable
1107     //   has the same shape as the value it points to. So, providing a layout based on
1108     //   VarHandleOp (scalar) might confuse the downstream system.
1109     if (op_name != std::string("Relayout") &&
1110         op_name != std::string("VarHandleOp")) {
1111       // TODO(b/162009702): Support matching between partially-known shapes.
1112       if (global_output_shape.IsFullyDefined()) {
1113         gtl::InlinedVector<int64, 4> shape_vector(
1114             global_output_shape.dim_sizes());
1115         auto layout_iterator =
1116             shape_layout_cache_.find(FingerprintShape(shape_vector));
1117         if (layout_iterator != shape_layout_cache_.end() &&
1118             layout_iterator->second.is_unique) {
1119           // We have a cached layout for this shape. Send it to MLIR.
1120           layout = &layout_iterator->second.layout;
1121           VLOG(3) << op_name << ": found a cached layout for shape "
1122                   << global_output_shape.DebugString() << ": \""
1123                   << layout->ToString() << "\"";
1124           if (input_meshes.empty() &&
1125               layout->mesh() != default_mesh_->mesh_config()) {
1126             VLOG(3) << "But we can't infer an input mesh and cached layout: "
1127                     << "mesh \"" << (layout->mesh().ToString()) << " "
1128                     << "is different than the default mesh : \""
1129                     << default_mesh_->mesh_config().ToString() << "\"\n"
1130                     << "Not applying the cached layout.";
1131           } else if (!input_meshes.empty() &&
1132                      layout->mesh() != *input_meshes.begin()) {
1133             VLOG(3)
1134                 << "But the layout mesh is different than the executing mesh: "
1135                 << "\"" << (*input_meshes.begin()).ToString() << "\"\n"
1136                 << "Not applying the cached layout.";
1137           } else {
1138             (*output_layouts)[output_index] = layout;
1139             node->AddAttr(kDefaultLayoutAttr, layout->ToString());
1140           }
1141         } else if (layout_iterator == shape_layout_cache_.end()) {
1142           VLOG(3) << op_name << ": no cached layout found for shape "
1143                   << global_output_shape.DebugString();
1144         } else {
1145           VLOG(3) << op_name << ": found multiple layouts for shape "
1146                   << global_output_shape.DebugString();
1147         }
1148       } else {
1149         VLOG(3) << op_name
1150                 << ": not applying same-shape-same-layout due to "
1151                    "not-fully-known shape "
1152                 << global_output_shape.DebugString();
1153       }
1154     }
1155   }
1156 }
1157 
1158 std::unordered_map<std::string, int>
1159 DTensorDevice::GetFunctionCacheHitAndMissCount(TFE_Context* context,
1160                                                TF_Status* status) const {
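  // Note: context and status are unused; the hit/miss counters are tracked
  // locally on this DTensorDevice.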
1161   return function_compilation_hits_and_misses_;
1162 }
1163 
1164 // From `graph` containing computation for all meshes, extract/select
1165 // computation for mesh specified in `function`. Returned graph is a cloned
1166 // graph with ops only for single mesh execution.
1167 StatusOr<std::unique_ptr<Graph>> SelectGraphToExecute(
1168     const TranslatedFunction& function, const Graph& graph,
1169     std::string* stateful_partitioned_call_name) {
1170   auto new_graph = std::make_unique<Graph>(graph.flib_def());
1171   CopyGraph(graph, new_graph.get());
1172   std::vector<Node*> arg_nodes;
1173   std::vector<Node*> retval_nodes;
1174   for (Node* node : new_graph->nodes()) {
1175     if (node->IsArg()) arg_nodes.emplace_back(node);
1176     if (node->IsRetval()) retval_nodes.emplace_back(node);
1177   }
1178 
1179   // Remove irrelevant function calls.
1180   for (Node* node : new_graph->nodes()) {
1181     if (node->op_def().name() != "StatefulPartitionedCall") continue;
1182 
1183     if (node->name() != function.node_to_execute->name()) {
1184       // Remove function call that does not match mesh specification and all
1185       // output retval nodes connected to the function call node.
1186       std::queue<Node*> nodes_to_remove;
1187       nodes_to_remove.push(node);
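      // Breadth-first removal: delete the unmatched call node and every node
      // reachable from it through data edges (excluding the sink), dropping any
      // retvals encountered from retval_nodes.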
1188       while (!nodes_to_remove.empty()) {
1189         Node* n = nodes_to_remove.front();
1190         for (const Edge* out_edge : n->out_edges()) {
1191           if (out_edge->IsControlEdge()) continue;
1192           Node* out_node = out_edge->dst();
1193           if (!out_node->IsSink()) nodes_to_remove.push(out_node);
1194         }
1195         if (n->IsRetval()) {
1196           auto pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1197           TF_RET_CHECK(pos != retval_nodes.end());
1198           retval_nodes.erase(pos);
1199         }
1200         nodes_to_remove.pop();
1201         new_graph->RemoveNode(n);
1202       }
1203     }
1204   }
1205 
1206   *stateful_partitioned_call_name = function.node_to_execute->name();
1207   VLOG(1) << "Selected call " << *stateful_partitioned_call_name;
1208 
1209   // Remove unused arg nodes in graph.
1210   for (auto it = arg_nodes.begin(); it != arg_nodes.end(); it++) {
1211     Node* arg_node = *it;
1212     bool arg_unused = true;
1213     for (const Edge* e : arg_node->out_edges()) {
1214       if (e->dst()->IsOp()) {
1215         arg_unused = false;
1216       }
1217     }
1218     if (!arg_unused) continue;
1219 
1220     new_graph->RemoveNode(arg_node);
1221     arg_nodes.erase(it--);
1222   }
1223 
1224   // Reset index attributes for arg and retval nodes.
1225   for (Node* n : new_graph->nodes()) {
1226     // Reset each arg node's index attribute to its position within all the arg
1227     // nodes. The indices should simply increase from 0 to n - 1, where n
1228     // is the total number of arguments. Note that this definition of
1229     // the `index` attribute is different from the definition we set in
1230     // PrepareGraphForMLIR.
1231     // This attribute is needed for each arg node when converting a Graph to
1232     // a FunctionDef.
1233     if (n->IsArg()) {
1234       auto pos = std::find(arg_nodes.begin(), arg_nodes.end(), n);
1235       TF_RET_CHECK(pos != arg_nodes.end());
1236       const int new_index = std::distance(arg_nodes.begin(), pos);
1237       n->AddAttr("index", new_index);
1238     }
1239 
1240     // Reset retval nodes index attributes.
1241     if (n->IsRetval()) {
1242       auto retval_pos = std::find(retval_nodes.begin(), retval_nodes.end(), n);
1243       TF_RET_CHECK(retval_pos != retval_nodes.end());
1244       const int new_index = std::distance(retval_nodes.begin(), retval_pos);
1245       n->AddAttr("index", new_index);
1246     }
1247   }
1248 
1249   VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_to_execute_",
1250                                          *new_graph);
1251 
1252   return new_graph;
1253 }
1254 
1255 // Adds processed graph to run for each mesh computation in
1256 // `execution_functions` to function definition library.
1257 Status AddExecutionFunctionDefsToFunctionDefLibrary(
1258     const absl::flat_hash_set<Node*>& control_ret_nodes, TFE_Context* context,
1259     const Graph& graph, ExecutionFunctions* execution_functions) {
1260   // Note: We compare by node name instead of node pointer because node
1261   // addresses in the new graph differ from those in the original graph.
1262   absl::flat_hash_set<std::string> control_ret_names;
1263   for (auto* n : control_ret_nodes) {
1264     control_ret_names.emplace(n->name());
1265   }
1266   for (TranslatedFunction& function : execution_functions->function_list) {
1267     std::string selected_call_node_name;
1268     // TODO(bfontain): We should just try to call the functions directly rather
1269     // than wrap them.
1270     // Construct graph that executes only computation for `function`.
1271     TF_ASSIGN_OR_RETURN(
1272         std::unique_ptr<Graph> new_graph,
1273         SelectGraphToExecute(function, graph, &selected_call_node_name));
1274     VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_", *new_graph);
1275 
1276     // Add unique identifier based on the function we are executing to the
1277     // function/graph and convert graph to functiondef.
1278     NameAttrList func;
1279     TF_RETURN_IF_ERROR(
1280         GetNodeAttr(function.node_to_execute->attrs(), "f", &func));
1281 
1282     static std::atomic<int64_t> unique_function_number(0);
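    // A process-wide counter keeps translated function names unique so repeated
    // lowerings of the same function do not collide in the function library.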
1283     function.translated_function_name =
1284         absl::StrCat(func.name(), "_", unique_function_number.fetch_add(1));
1285     auto control_ret_node_names =
1286         [&control_ret_names, &selected_call_node_name](
1287             const Node* node) -> absl::optional<std::string> {
1288       // Add the stateful partitioned call node as a control return as we need
1289       // to process any control deps inside the inner function.
1290       if (control_ret_names.contains(node->name()) ||
1291           node->name() == selected_call_node_name) {
1292         return node->name();
1293       }
1294       return absl::nullopt;
1295     };
1296 
1297     tensorflow::FunctionDef to_run;
1298     TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
1299         *new_graph, function.translated_function_name, control_ret_node_names,
1300         &to_run));
1301 
1302     for (const auto& out : to_run.signature().output_arg()) {
1303       function.output_dtypes.emplace_back(static_cast<TF_DataType>(out.type()));
1304     }
1305 
1306     AddDTensorFunctionAttr(to_run);
1307     TF_RETURN_IF_ERROR(tensorflow::unwrap(context)->AddFunctionDef(to_run));
1308   }
1309 
1310   return OkStatus();
1311 }
1312 
1313 void DTensorDevice::LowerToSPMDFunction(
1314     TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1315     const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1316     const int num_outputs, const ExecutionFunctions** execution_functions,
1317     TF_Status* status) {
1318   profiler::TraceMe activity(
1319       [&] { return "DTensorDevice::LowerToSPMDFunction"; },
1320       profiler::TraceMeLevel::kInfo);
1321   FunctionLibraryDefinition* flib_def =
1322       tensorflow::unwrap(context)->FuncLibDef();
1323   auto graph(std::make_unique<tensorflow::Graph>(flib_def));
1324   NameAttrList eager_attributes;
1325   ASSIGN_OR_RETURN_C_STATUS(eager_attributes, FetchAttributes(attributes),
1326                             status);
1327 
1328   std::vector<PartialTensorShape> global_output_shapes;
1329   std::vector<const Layout*> output_layouts;
1330   const FunctionDef* function_def = doperation.function_def;
1331   if (!function_def) {
1332     // Output layouts of an eager op (e.g. fill) must be inferred before cache
1333     // key computation, since they might depend on the current DTensorDevice
1334     // state.
1335     Status s = PrepareGraphForMlir(
1336         function_manager_, inputs, doperation, *flib_def, eager_attributes,
1337         default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1338     RETURN_C_STATUS_IF_NOT_OK(s, status);
1339 
1340     // Find all meshes the inputs lie on.
1341     absl::flat_hash_set<Mesh> input_meshes;
1342     for (const TensorWithLayout* tensor : inputs) {
1343       if (!tensor->layout().mesh().IsEmpty()) {
1344         input_meshes.insert(tensor->layout().mesh());
1345       }
1346     }
1347     // Currently we only provide layout hints for op-by-op, since
1348     // they interact badly with layout propagation.
1349     UpdateOutputLayoutsWithSameShapePolicy(global_output_shapes, input_meshes,
1350                                            doperation.name, graph.get(),
1351                                            &output_layouts, status);
1352     if (TF_GetCode(status) != TF_OK) return;
1353   }
1354 
1355   std::pair<tensorflow::Fprint128, const ExecutionFunctions*>
1356       cache_key_and_func = function_manager_.GetCachedFunction(
1357           doperation, eager_attributes, inputs, output_layouts);
1358   *execution_functions = cache_key_and_func.second;
1359   if (*execution_functions != nullptr) {
1360     function_compilation_hits_and_misses_["hit"]++;
1361     return;
1362   } else if (function_def) {
1363     function_compilation_hits_and_misses_["miss"]++;
1364     LOG(INFO) << "DTensor cache key lookup missed for " << doperation.name
1365               << ". DTensor is (re-)computing its SPMD transformation.";
1366   }
1367 
1368   // The list includes remote devices when the coordination service is enabled.
1369   const auto device_list = tensorflow::unwrap(context)->ListAllTfDevices();
1370   DeviceSet device_set;
1371   for (const auto device : device_list) device_set.AddDevice(device);
1372 
1373   if (function_def) {
1374     ASSIGN_OR_RETURN_C_STATUS(auto device_name_to_mesh_device,
1375                               PipelineSubMeshes(context), status);
1376     const bool is_pipelining_function = !device_name_to_mesh_device.empty();
1377     // For a multi-mesh function for pipelining, we take a different execution
1378     // path. Call the partitioner to lower and partition the graph into multiple
1379     // sub functions to execute (one per sub mesh).
1380     if (is_pipelining_function) {
1381       ASSIGN_OR_RETURN_C_STATUS(
1382           ExecutionFunctions functions,
1383           PipeliningPartitionerRun(&device_name_to_mesh_device, flib_def,
1384                                    &pass_runner_, *doperation.function_def,
1385                                    eager_attributes, inputs, device_set,
1386                                    num_outputs),
1387           status);
1388       *execution_functions = function_manager_.AddCachedFunction(
1389           doperation, cache_key_and_func.first, std::move(functions));
1390       return;
1391     }
1392     // Output layouts of a function are inferred by MLIR lowering. They are
1393     // not necessary for cache key computation, so run PrepareGraphForMlir after
1394     // cache key computation to reduce the overheads of running the same
1395     // function multiple times.
1396     Status s = PrepareGraphForMlir(
1397         function_manager_, inputs, doperation, *flib_def, eager_attributes,
1398         default_layout_, graph.get(), &global_output_shapes, &output_layouts);
1399     RETURN_C_STATUS_IF_NOT_OK(s, status);
1400   }
1401 
1402   absl::flat_hash_set<Node*> control_ret_nodes;
1403   // Run DTensor MLIR passes that convert input graph to SPMD version.
1404   {
1405     profiler::TraceMe activity([&] { return "DTensorDevice::RunMLIRPasses"; },
1406                                profiler::TraceMeLevel::kInfo);
1407     RETURN_C_STATUS_IF_NOT_OK(
1408         pass_runner_.RunOnGraph(device_set, doperation.is_func(), flib_def,
1409                                 &graph, control_ret_nodes,
1410                                 cache_key_and_func.first),
1411         status);
1412   }
1413   VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph,
1414                                          flib_def);
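  // If the library defines the TPU embedding load function, insert the
  // corresponding embedding checkpoint-handling function into the graph.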
1415   if (flib_def->Contains(kLoadEmbeddingFn)) {
1416     Status s = InsertFunctionForTPUEmbeddingCheckpoint(
1417         status, graph.get(), inputs, kLoadEmbeddingFn);
1418     RETURN_C_STATUS_IF_NOT_OK(s, status);
1419   }
1420 
1421   // After MLIR transformations, exactly one StatefulPartitionedCall op is
1422   // returned for each mesh cluster in the computation. Identify all functions to
1423   // execute for each mesh and the relevant input and output information.
1424   ASSIGN_OR_RETURN_C_STATUS(
1425       ExecutionFunctions functions,
1426       IdentifyAllFunctionsToExecute(*graph.get(), global_output_shapes),
1427       status);
1428 
1429   // In order to ensure that all resource assign operations as well as side
1430   // effecting ops are executed, we add identity ops before function outputs
1431   // with control rets.
1432   RETURN_C_STATUS_IF_NOT_OK(MaybeInsertIdentityNodes(function_def, graph.get()),
1433                             status);
1434 
1435   VLOG(4) << tensorflow::DumpGraphToFile("after_post_processing_graph", *graph,
1436                                          flib_def);
1437 
1438   RETURN_C_STATUS_IF_NOT_OK(
1439       AddExecutionFunctionDefsToFunctionDefLibrary(control_ret_nodes, context,
1440                                                    *graph.get(), &functions),
1441       status);
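  // Each lowered function takes a single implicit device-id argument; record
  // that here so input-index mapping at execution time can account for it.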
1442   functions.num_device_ids = 1;
1443   if (function_def) {
1444     for (TranslatedFunction& function : functions.function_list) {
1445       functions.function_mesh_fingerprint =
1446           FingerprintCat64(functions.function_mesh_fingerprint,
1447                            function.function_mesh.GlobalFingerprint());
1448     }
1449   }
1450 
1451   *execution_functions = function_manager_.AddCachedFunction(
1452       doperation, cache_key_and_func.first, std::move(functions));
1453 }
1454 
1455 void DTensorDevice::ExecuteFunctionAndWait(
1456     TFE_Context* context, const TranslatedFunction* function_ptr,
1457     const MeshWithParallelDevice* parallel_device_mesh,
1458     const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
1459     const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status) {
1460   const std::string mesh_str = function_ptr->function_mesh.ToString();
1461   VLOG(4) << "Launching computation for mesh : " << mesh_str;
1462   parallel_device_mesh->parallel_device().StartExecute(
1463       context,
1464       /*inputs=*/parallel_inputs,
1465       /*operation_name=*/function_ptr->translated_function_name.c_str(),
1466       /*attributes=*/attributes,
1467       /*expected_max_outputs=*/function_ptr->local_output_shapes.size(),
1468       /*cancellation_manager=*/*cancellation_manager_,
1469       /*step_id=*/step_id);
1470 
1471   VLOG(4) << "Joining computation result from mesh : " << mesh_str;
1472   parallel_device_mesh->parallel_device().Join(
1473       function_ptr->local_output_shapes, status);
1474   VLOG(4) << "Joining status: " << TF_Message(status);
1475   if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_CANCELLED) {
1476     LOG(ERROR) << "Encountered error while executing function: "
1477                << function_ptr->translated_function_name
1478                << " for mesh : " << mesh_str
1479                << " / error : " << TF_Message(status);
1480   }
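  // Flush all pending asynchronous execution so that any failure from this
  // function surfaces now instead of on a later, unrelated op.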
1481 
1482   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1483       TF_NewStatus(), TF_DeleteStatus);
1484   AsyncWait(context, async_wait_status.get());
1485   TF_Code error_code = TF_GetCode(async_wait_status.get());
1486   if (error_code != TF_OK && error_code != TF_CANCELLED) {
1487     LOG(ERROR) << "Async status: " << TF_Message(async_wait_status.get());
1488   }
1489 }
1490 
1491 void DTensorDevice::ExecuteRegularOperation(
1492     TFE_Context* context, const std::vector<TensorWithLayout*>& inputs,
1493     const DTensorOperation& doperation, const TFE_OpAttrs* attributes,
1494     int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status) {
1495   const ExecutionFunctions* execution_functions = nullptr;
1496 
1497   LowerToSPMDFunction(context, inputs, doperation, attributes, *num_outputs,
1498                       &execution_functions, status);
1499   if (TF_GetCode(status) != TF_OK) return;
1500 
1501   // Update input layouts for resource arguments.
1502   for (const TranslatedFunction& function :
1503        execution_functions->function_list) {
1504     for (const auto& entry : function.resource_input_layouts) {
1505       // TODO(hthu): Add a TensorWithLayout in the inputs vector at location 0
1506       // for DeviceId. This is done as the first arg is always DeviceId, and it
1507       // isn't mapped to input Tensors.
1508       const int resource_index_to_update = entry.first - 1;
1509       inputs[resource_index_to_update]->UpdateLayout(entry.second, status);
1510       if (TF_GetCode(status) != TF_OK) {
1511         RETURN_STATUS(status, TF_GetCode(status),
1512                       absl::StrCat("Attempt to update layout input arg: ",
1513                                    resource_index_to_update,
1514                                    ". Original message: ", TF_Message(status))
1515                           .c_str());
1516       }
1517     }
1518   }
1519 
1520   int num_global_outputs = 0;
1521 
1522   // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1523   // object. Ideally we'd just use a fingerprinted int64_t as a unique
1524   // identifier for a mesh.
1525   std::map<std::string, const MeshWithParallelDevice*>
1526       function_name_and_mesh_mapping;
1527   absl::flat_hash_set<std::string> excluded_fn_names;
1528   std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr;
1529   for (const TranslatedFunction& function :
1530        execution_functions->function_list) {
1531     StatusOr<Mesh> maybe_converted_mesh = function.function_mesh;
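    // EPU meshes are executed via their CPU counterparts, so convert the mesh
    // device type before looking up the parallel device.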
1532     if (function.function_mesh.is_epu_mesh()) {
1533       maybe_converted_mesh = function.function_mesh.ToDeviceType("CPU");
1534     }
1535 
1536     if (!maybe_converted_mesh.ok()) {
1537       RETURN_STATUS(status, TF_INVALID_ARGUMENT,
1538                     absl::StrCat("Failed to convert mesh, get error: ",
1539                                  maybe_converted_mesh.status().error_message())
1540                         .c_str());
1541     }
1542     const Mesh& mesh = *maybe_converted_mesh;
1543     // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1544     // object. Ideally we'd just use a fingerprinted int64_t as a unique
1545     // identifier for a mesh.
1546     const MeshWithParallelDevice* parallel_device_mesh =
1547         mesh_to_device_map_.contains(mesh) ? mesh_to_device_map_[mesh].get()
1548                                            : default_mesh_;
1549     if (parallel_device_mesh == nullptr) {
1550       RETURN_STATUS(status, TF_INTERNAL,
1551                     "required mesh is not registered with DTensor device");
1552     }
1553     function_name_and_mesh_mapping[function.translated_function_name] =
1554         parallel_device_mesh;
1555 
1556     if (function.function_mesh.is_epu_mesh()) {
1557       if (epu_fn_ptr != nullptr) {
1558         RETURN_STATUS(status, TF_INTERNAL,
1559                       "There are more than one function defined on EPU mesh.");
1560       }
1561       epu_fn_ptr = std::make_unique<const TranslatedFunction>(function);
1562       excluded_fn_names.insert(function.translated_function_name);
1563     }
1564     if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) {
1565       if (load_embedding_ptr != nullptr) {
1566         RETURN_STATUS(status, TF_INTERNAL,
1567                       "There are more than one function defined on EPU mesh.");
1568       }
1569       load_embedding_ptr = std::make_unique<const TranslatedFunction>(function);
1570       excluded_fn_names.insert(function.translated_function_name);
1571     }
1572   }
1573 
1574   // Compute the step_id based on the function_mesh_fingerprint and the
1575   // corresponding function execution counter.
1576   uint64 function_mesh_fingerprint =
1577       execution_functions->function_mesh_fingerprint;
1578   if (func_mesh_fingerprint_to_step_counter_.contains(
1579           function_mesh_fingerprint)) {
1580     func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)++;
1581   } else {
1582     func_mesh_fingerprint_to_step_counter_.insert(
1583         {function_mesh_fingerprint, 0});
1584   }
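  // Combine the mesh fingerprint with its execution counter so each run of the
  // same function set gets a distinct step_id.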
1585   const uint64 step_id = FingerprintCat64(
1586       function_mesh_fingerprint,
1587       func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint));
1588 
1589   // Execute excluded functions in sequence.
1590   if (epu_fn_ptr != nullptr) {
1591     ExecuteFunctionAndWait(
1592         context,
1593         /*function_ptr=*/epu_fn_ptr.get(),
1594         /*parallel_device_mesh=*/
1595         function_name_and_mesh_mapping[epu_fn_ptr->translated_function_name],
1596         /*parallel_inputs=*/{}, /*step_id=*/step_id, /*attributes=*/attributes,
1597         /*status=*/status);
1598   }
1599 
1600   if (load_embedding_ptr != nullptr) {
1601     StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs =
1602         PrepareEmbeddingInputs(inputs);
1603     if (!parallel_inputs.ok()) {
1604       RETURN_STATUS(status, TF_INTERNAL,
1605                     parallel_inputs.status().error_message().c_str());
1606     }
1607     ExecuteFunctionAndWait(
1608         context,
1609         /*function_ptr=*/load_embedding_ptr.get(),
1610         /*parallel_device_mesh=*/
1611         function_name_and_mesh_mapping[load_embedding_ptr
1612                                            ->translated_function_name],
1613         /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id,
1614         /*attributes=*/attributes, /*status=*/status);
1615   }
1616 
1617   // Extract the global parallel inputs and flatten SparseTensors
1618   // into the three component tensors.
1619   std::vector<parallel_device::ParallelTensor*> global_parallel_inputs;
1620   std::vector<parallel_device::ParallelTensor*> global_parallel_sparse_inputs;
1621   absl::flat_hash_set<int> global_sparse_input_indices;
1622   for (auto input : inputs) {
1623     if (input->tensor_type() == TensorType::kSparse) {
1624       SparseTensorWithLayout* sparse_input =
1625           dynamic_cast<SparseTensorWithLayout*>(input);
1626       global_parallel_sparse_inputs.push_back(sparse_input->indices());
1627       global_parallel_sparse_inputs.push_back(sparse_input->dense_shapes());
1628       global_parallel_sparse_inputs.push_back(sparse_input->values());
1629     } else {
1630       global_parallel_inputs.push_back(input->tensor());
1631     }
1632   }
1633   // Insert SparseTensor components at the end: the MLIR handling of
1634   // SparseTensors places SparseTensor components at the end of the main
1635   // func arguments to keep a fixed ordering.
1636   global_parallel_inputs.insert(global_parallel_inputs.end(),
1637                                 global_parallel_sparse_inputs.begin(),
1638                                 global_parallel_sparse_inputs.end());
1639 
1640   // Execute all functions in parallel.
1641   for (const TranslatedFunction& function :
1642        execution_functions->function_list) {
1643     const Mesh& mesh = function.function_mesh;
1644     const std::string& translated_function_name =
1645         function.translated_function_name;
1646 
1647     num_global_outputs += function.local_output_shapes.size();
1648 
1649     if (is_remote_mesh(mesh) ||
1650         (excluded_fn_names.find(translated_function_name) !=
1651          excluded_fn_names.end())) {
1652       // Skip execution for a translated function that has a remote mesh or is
1653       // excluded.
1654       continue;
1655     }
1656 
1657     const MeshWithParallelDevice* parallel_device_mesh =
1658         function_name_and_mesh_mapping[translated_function_name];
1659 
1660     // Gather the local inputs for this function.
1661     std::vector<parallel_device::ParallelTensor*> parallel_inputs;
1662     parallel_inputs.reserve(inputs.size() + 1);
1663     auto input_mapping = function.input_index_map;
1664 
1665     // We sort here because, by this point, the function graph we are executing
1666     // is a reduced version of the main function that only includes the
1667     // StatefulPartitionedCall we are executing for this mesh.
1668     // Thus, the ordering is the same as the main function's ordering, which
1669     // is sorted in increasing order.
1670     std::sort(input_mapping.begin(), input_mapping.end());
1671 
1672     for (const int global_index : input_mapping) {
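    // Map the function's (sorted) global input indices onto parallel tensors:
    // indices below num_device_ids refer to the implicit device-id input, the
    // rest index into the flattened global inputs.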
1673       auto input_index = global_index - execution_functions->num_device_ids;
1674 
1675       if (global_index < execution_functions->num_device_ids) {
1676         parallel_inputs.push_back(
1677             parallel_device_mesh->DeviceIDs(context, status));
1678         if (TF_GetCode(status) != TF_OK) return;
1679       } else {
1680         parallel_inputs.push_back(global_parallel_inputs[input_index]);
1681       }
1682     }
1683 
1684     VLOG(4) << "Launching computation for mesh : " << mesh.ToString();
1685     parallel_device_mesh->parallel_device().StartExecute(
1686         context, parallel_inputs, translated_function_name.c_str(), attributes,
1687         /*expected_max_outputs=*/function.local_output_shapes.size(),
1688         *cancellation_manager_, /*step_id=*/step_id);
1689   }
1690 
1691   *num_outputs = num_global_outputs;
1692   std::vector<std::unique_ptr<TensorWithLayout>> typed_outputs;
1693   typed_outputs.resize(num_global_outputs);
1694 
1695   // Join all mesh computation together.
1696   // TODO(b/177932563): Expose cancel logic to handle failures.
1697   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> join_status(
1698       TF_NewStatus(), TF_DeleteStatus);
1699   for (const TranslatedFunction& function :
1700        execution_functions->function_list) {
1701     // Skip execution for a function when it's excluded.
1702     if (excluded_fn_names.contains(function.translated_function_name)) {
1703       continue;
1704     }
1705     const Mesh& mesh = function.function_mesh;
1706     // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh
1707     // object. Ideally we'd just use a fingerprinted int64_t as a unique
1708     // identifier for a mesh.
1709     const MeshWithParallelDevice* parallel_device_mesh =
1710         function_name_and_mesh_mapping[function.translated_function_name];
1711 
1712     std::vector<std::unique_ptr<TensorWithLayout>> output_with_layout;
1713     output_with_layout.reserve(function.output_index_map.size());
1714     if (is_remote_mesh(mesh)) {
1715       // Create dummy outputs on a remote mesh.
1716       for (int i = 0; i < function.output_index_map.size(); ++i) {
1717         const auto dim_sizes = function.local_output_shapes.at(i).dim_sizes();
1718         std::vector<int64_t> local_shape =
1719             std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end());
1720         TF_DataType dtype =
1721             static_cast<TF_DataType>(function.output_dtypes.at(i));
1722         auto remote_output =
1723             TensorWithLayout::Dummy(local_shape, dtype, *parallel_device_mesh,
1724                                     function.output_layouts[i]);
1725         output_with_layout.push_back(std::move(remote_output));
1726       }
1727     } else {
1728       VLOG(4) << "Joining computation result from mesh : " << mesh.ToString();
1729       auto result = parallel_device_mesh->parallel_device().Join(
1730           function.local_output_shapes, status);
1731       if (TF_GetCode(join_status.get()) != TF_OK &&
1732           // Preserve the first failure we see, but only if it is a real failure
1733           // and not a cancellation (which was probably triggered by the error
1734           // we want to propagate).
1735           (TF_GetCode(status) == TF_OK ||
1736            TF_GetCode(join_status.get()) != TF_CANCELLED)) {
1737         continue;
1738       }
1739       if (TF_GetCode(status) != TF_OK) {
1740         if (TF_GetCode(status) != TF_CANCELLED) {
1741           LOG(ERROR) << "Encountered error while executing function: "
1742                      << function.translated_function_name
1743                      << " for mesh : " << mesh.ToString()
1744                      << " / error : " << TF_Message(status);
1745         }
1746         TF_SetStatus(join_status.get(), TF_GetCode(status), TF_Message(status));
1747         continue;
1748       }
1749 
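      // Wrap each joined parallel output with the output layout produced by the
      // SPMD lowering.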
1750       for (int i = 0; i < result->size(); ++i) {
1751         ASSIGN_OR_RETURN_C_STATUS(
1752             auto local_output,
1753             TensorWithLayout::Wrap(std::move((*result)[i]),
1754                                    *parallel_device_mesh,
1755                                    function.output_layouts[i]),
1756             status);
1757         output_with_layout.push_back(std::move(local_output));
1758       }
1759     }
1760 
1761     for (int i = 0; i < function.output_index_map.size(); ++i) {
1762       // TODO(b/162744844): Generalize this pattern so that the extraction is
1763       // not special cased.
1764       if (function.shape_output_metadata.find(i) !=
1765           function.shape_output_metadata.end()) {
1766         output_with_layout[i]->set_input_layout_for_shape_op_result(
1767             function.shape_output_metadata.at(i));
1768       }
1769 
1770       RecordInShapeLayoutCache(*output_with_layout[i]);
1771       typed_outputs[function.output_index_map[i]] =
1772           std::move(output_with_layout[i]);
1773     }
1774   }
1775   if (TF_GetCode(join_status.get()) != TF_OK) {
1776     std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status(
1777         TF_NewStatus(), TF_DeleteStatus);
1778     AsyncWait(context, async_wait_status.get());
1779     TF_Code error_code = TF_GetCode(async_wait_status.get());
1780     if (error_code != TF_OK && error_code != TF_CANCELLED) {
1781       // Ignore the AsyncWait() status return since we already have a bad status
1782       // to propagate. We've just canceled a bunch of operations, so we expect
1783       // cancellation status returns. We'll log anything else just to be safe.
1784       LOG(ERROR) << "Error executing " << doperation.name << " "
1785                  << TF_Message(async_wait_status.get());
1786     }
1787 
1788     TF_SetStatus(status, TF_GetCode(join_status.get()),
1789                  TF_Message(join_status.get()));
1790     return;
1791   }
1792   if (VLOG_IS_ON(2)) {
1793     LOG(INFO) << "Executed " << doperation.name << ", got "
1794               << typed_outputs.size() << " outputs:";
1795     for (const std::unique_ptr<TensorWithLayout>& output : typed_outputs) {
1796       LOG(INFO) << "  " << output->DebugString();
1797     }
1798   }
1799   if (doperation.name == std::string("VarHandleOp")) {
1800     // For new variables, set the dereferenced shape/dtype so we can pass it in
1801     // as _handle_dtype and _handle_shape in the future.
1802     //
1803     // Note that VarHandleOps generated by `tf.Variable` objects are always run
1804     // eagerly, which is almost all of the op's usage in TF2. Theoretically a
1805     // user could run it in a tf.function via tf.raw_ops.VarHandleOp, return it
1806     // from that function, and add it as an input to another, and it would
1807     // currently be missing handle information.
1808     if (typed_outputs.size() != 1) {
1809       RETURN_STATUS(status, TF_INTERNAL,
1810                     "Expected one output from VarHandleOp");
1811     }
1812     NameAttrList name_and_attrs;
1813     ASSIGN_OR_RETURN_C_STATUS(name_and_attrs, FetchAttributes(attributes),
1814                               status);
1815 
1816     typed_outputs[0]->UpdateShapeAndDType(
1817         name_and_attrs.attr().at("shape").shape(),
1818         name_and_attrs.attr().at("dtype").type(), status);
1819     if (TF_GetCode(status) != TF_OK) return;
1820   }
1821 
1822   for (int i = 0; i < *num_outputs; ++i) {
1823     outputs[i] =
1824         MakeLayoutTensorHandle(context, std::move(typed_outputs[i]), status);
1825     if (TF_GetCode(status) != TF_OK) return;
1826   }
1827 }
1828 
1829 void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs,
1830                             TFE_TensorHandle** outputs, TF_Status* status) {
1831   TFE_Context* context = TFE_OpGetContext(original_op, status);
1832   if (TF_GetCode(status) != TF_OK) return;
1833   const char* operation_name = TFE_OpGetName(original_op, status);
1834   if (TF_GetCode(status) != TF_OK) return;
1835   const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
1836   int num_inputs = TFE_OpGetFlatInputCount(original_op, status);
1837   if (TF_GetCode(status) != TF_OK) return;
1838   std::vector<TFE_TensorHandle*> inputs_vector;
1839   inputs_vector.reserve(num_inputs);
1840   for (int input_index = 0; input_index < num_inputs; ++input_index) {
1841     TFE_TensorHandle* input =
1842         TFE_OpGetFlatInput(original_op, input_index, status);
1843     if (TF_GetCode(status) != TF_OK) return;
1844     inputs_vector.push_back(input);
1845   }
1846   TFE_TensorHandle** inputs = inputs_vector.data();
1847 
1848   DTensorOperation dtensor_operation{};
1849   dtensor_operation.name = operation_name;
1850   {
1851     dtensor_operation.function_def =
1852         tensorflow::unwrap(context)->FindFunctionDef(operation_name);
1853   }
1854 
1855   // First handle DTensor-specific virtual operations.
1856   bool is_op_handled = false;
1857   MaybeHandleDTensorCustomOps(operation_name, num_inputs, attributes, context,
1858                               inputs, num_outputs, outputs, &is_op_handled,
1859                               status);
1860   if (is_op_handled) return;
1861 
1862   // This isn't a special op, so we'll defer to TFE_Execute to actually execute
1863   // it, but we'll also run DTensor MLIR passes and propagate the layout.
1864   std::vector<TensorWithLayout*> typed_inputs;
1865   std::vector<std::unique_ptr<TensorWithLayout>> inputs_with_no_layout;
1866 
1867   // Record the meshes of all inputs that are already on the DTensor device.
1868   // If we can identify a single unique mesh, that mesh is used as the mesh
1869   // to broadcast non-DTensor inputs to.
1870   absl::flat_hash_set<Mesh> input_meshes;
1871   std::vector<int> not_on_device_input_indices;
1872 
1873   typed_inputs.resize(num_inputs);
1874   for (int j = 0; j < num_inputs; ++j) {
1875     TFE_TensorHandle* input = inputs[j];
1876     const char* input_device = TFE_TensorHandleDeviceName(input, status);
1877     if (TF_GetCode(status) != TF_OK) return;
1878     if (name_ != input_device) {
1879       not_on_device_input_indices.push_back(j);
1880       continue;
1881     }
1882     // Handle input which is on DTensor device already.
1883     TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>(
1884         TFE_TensorHandleDevicePointer(input, status));
1885     if (TF_GetCode(status) != TF_OK) return;
1886 
1887     // VarHandleOp runs on an empty mesh, which isn't registered with the device.
1888     if (!t->layout().mesh().IsEmpty()) {
1889       input_meshes.insert(t->layout().mesh());
1890     }
1891     // Remote mesh inputs cannot be read and evaluated.
1892     if (!is_remote_mesh(t->layout().mesh()) && !t->const_value().has_value()) {
1893       std::optional<NodeDef> const_value =
1894           ExtractSmallTensorValue(context, input, t->layout(), status);
1895       if (TF_GetCode(status) != TF_OK) return;
1896       if (const_value.has_value()) {
1897         t->set_const_value(const_value.value());
1898       }
1899     }
1900     typed_inputs[j] = t;
1901   }
1902 
1903   // If a unique mesh is identified across all inputs, we use that mesh as the
1904   // mesh to broadcast to. Otherwise we fall back to the default mesh.
1905   const MeshWithParallelDevice* broadcast_mesh =
1906       input_meshes.size() == 1
1907           ? mesh_to_device_map_[*input_meshes.begin()].get()
1908           : default_mesh_;
1909   if (!broadcast_mesh) {
1910     RETURN_STATUS(status, TF_INVALID_ARGUMENT,
1911                   "No mesh has been registered to DTensor. Use copy_to_mesh to "
1912                   "explicit specify a mesh instead.");
1913   }
1914   for (int not_on_device_input_index : not_on_device_input_indices) {
1915     TFE_TensorHandle* input = inputs[not_on_device_input_index];
1916     // DTensor creation should be explicit, with some exceptions for usability
1917     // (scalars/shapes/slice specs/etc.) Here we do some trivial validation to
1918     // enforce this rule.
1919     int num_dims = TFE_TensorHandleNumDims(input, status);
1920     if (TF_GetCode(status) != TF_OK) return;
1921     int64_t num_elements = TFE_TensorHandleNumElements(input, status);
1922     if (TF_GetCode(status) != TF_OK) return;
1923     TF_DataType dtype = TFE_TensorHandleDataType(input);
1924     const bool small_int_tensor = num_elements < kSmallTensorThreshold &&
1925                                   (dtype == TF_INT32 || dtype == TF_INT64);
1926     if (!(num_dims == 0 || dtype == TF_STRING || small_int_tensor)) {
1927       std::vector<int64_t> tensor_shape(TensorShapeAsVector(input, status));
1928       if (TF_GetCode(status) != TF_OK) return;
1929       RETURN_STATUS(
1930           status, TF_UNIMPLEMENTED,
1931           absl::StrCat(
1932               "The op/function ", operation_name,
1933               " got a regular tensor for input ", not_on_device_input_index,
1934               " (shape ", ShapeToDebugString(tensor_shape),
1935               ") but was expecting a DTensor. Currently only scalars and "
1936               "small integer/string tensors are auto-broadcast to "
1937               "DTensors. For other tensors, please use copy_to_mesh to "
1938               "make a DTensor explicitly; note that this may be slow if it "
1939               "happens frequently.")
1940               .c_str());
1941     }
1942     // Construct temporary TensorWithLayout objects for inputs that didn't
1943     // have any to start. These are owned by the `inputs_with_no_layout`
1944     // vector, whereas the input `TFE_TensorHandle`s maintain ownership for
1945     // inputs that already had layouts (and therefore had TensorWithLayout
1946     // objects).
1947     std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast(
1948         context, input, *broadcast_mesh, name_, status);
1949     if (TF_GetCode(status) != TF_OK) return;
1950     if (!ShouldFoldInputArgument(dtensor_operation.name,
1951                                  /*input_index=*/not_on_device_input_index)) {
1952       wrapper->reset_const_value();
1953     }
1954     typed_inputs[not_on_device_input_index] = wrapper.get();
1955     inputs_with_no_layout.emplace_back(wrapper.release());
1956   }
1957 
1958   ExecuteRegularOperation(context, typed_inputs, dtensor_operation, attributes,
1959                           num_outputs, outputs, status);
1960 }
1961 
1962 void ExecuteOnDTensorDevice(const TFE_Op* original_op, int* num_outputs,
1963                             TFE_TensorHandle** outputs, TF_Status* status,
1964                             void* device_info) {
1965   DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1966   dev->Execute(original_op, num_outputs, outputs, status);
1967 }
1968 
1969 void DeleteDTensorDevice(void* device_info) {
1970   delete static_cast<DTensorDevice*>(device_info);
1971 }
1972 
1973 TFE_TensorHandle* CopyToDTensorDevice(TFE_Context* context,
1974                                       TFE_TensorHandle* tensor,
1975                                       TF_Status* status, void* device_info) {
1976   TF_SetStatus(status, TF_UNIMPLEMENTED,
1977                "Trying to copy a tensor on to a DTensor mesh without a layout "
1978                "(use the CopyToMesh op for now).");
1979   return nullptr;
1980 }
1981 
1982 TFE_TensorHandle* CopyFromDTensorDevice(TFE_Context* context,
1983                                         TFE_TensorHandle* tensor,
1984                                         const char* target_device_name,
1985                                         TF_Status* status, void* device_info) {
1986   TensorWithLayout* typed_input = reinterpret_cast<TensorWithLayout*>(
1987       TFE_TensorHandleDevicePointer(tensor, status));
1988   if (!tensorflow::dtensor::Layout(typed_input->layout()).IsFullyReplicated()) {
1989     TF_SetStatus(status, TF_UNIMPLEMENTED,
1990                  "Trying to copy a non-replicated DTensor is not supported.");
1991     return nullptr;
1992   }
1993   if (typed_input->tensor()->dtype() == TF_RESOURCE) {
1994     TF_SetStatus(status, TF_UNIMPLEMENTED,
1995                  "Trying to copy a DTensor resource handle is not supported.");
1996     return nullptr;
1997   }
1998   DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info);
1999   // Since operations are executed asynchronously, the operation which should
2000   // produce the tensor we're trying to copy off the DTensor device may be
2001   // canceled due to a failure on another device. If so, we want to report the
2002   // failure that caused the cancellation, not the cancellation itself. This
2003   // requires blocking waiting for other devices to flush their execution
2004   // queues.
2005   // Note that we also only need to sync the threads on the parallel_device()
2006   // directly, or a context level sync might cause unintentional deadlocks when
2007   // grabbing locks on other threads.
2008   dev->AsyncWait(context, status);
2009   if (TF_GetCode(status) != TF_OK) return nullptr;
2010   return TFE_TensorHandleCopySharingTensor(typed_input->get_tensor(0), status);
2011 }
2012 
2013 void AllocateDTensorDevice(absl::string_view device_name,
2014                            TFE_CustomDevice* device, void** device_info) {
2015   device->copy_tensor_to_device = &CopyToDTensorDevice;
2016   device->copy_tensor_from_device = &CopyFromDTensorDevice;
2017   device->delete_device = &DeleteDTensorDevice;
2018   device->execute = &ExecuteOnDTensorDevice;
2019   *device_info = new DTensorDevice(device_name);
2020 }
2021 
AddMesh(const std::string & serialized_mesh,void * device_info,bool is_async,bool is_host_mesh,TF_Status * status)2022 void AddMesh(const std::string& serialized_mesh, void* device_info,
2023              bool is_async, bool is_host_mesh, TF_Status* status) {
2024   auto mesh_config_or_status = Mesh::FromString(serialized_mesh);
2025   if (!mesh_config_or_status.ok()) {
2026     TF_SetStatus(status, TF_INTERNAL,
2027                  absl::StrCat("Failed to parse mesh config. ",
2028                               mesh_config_or_status.status().error_message())
2029                      .c_str());
2030     return;
2031   }
2032   auto mesh_config = mesh_config_or_status.ValueOrDie();
2033   std::vector<std::string> underlying_devices;
2034   underlying_devices.insert(underlying_devices.end(),
2035                             mesh_config.local_devices().begin(),
2036                             mesh_config.local_devices().end());
2037   // DTensor uses a multi-client setup which doesn't use remote eager, so we can
2038   // enable eager async execution in ParallelDevice.
2039   std::unique_ptr<tensorflow::parallel_device::ParallelDevice> parallel(
2040       new tensorflow::parallel_device::ParallelDevice(underlying_devices,
2041                                                       is_async));
2042 
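  // Pipelining sub-meshes get a composite device name derived from the mesh
  // name (with the pipeline prefix stripped); other meshes leave it empty.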
2043   std::string composite_device_name;
2044   if (absl::StartsWith(mesh_config.name(), kPipelineMeshNamePrefix)) {
2045     composite_device_name = std::string(
2046         absl::StripPrefix(mesh_config.name(), kPipelineMeshNamePrefix));
2047   }
2048 
2049   auto mesh = std::make_unique<MeshWithParallelDevice>(
2050       std::move(mesh_config), std::move(parallel), composite_device_name);
2051   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2052   device->AddMesh(std::move(mesh), is_host_mesh);
2053 }
2054 
2055 void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
2056                                   void* device_info, TF_Status* status) {
2057   StatusOr<Layout> layout = Layout::FromString(serialized_layout);
2058   if (!layout.ok()) {
2059     RETURN_STATUS(status, TF_INTERNAL, layout.status().error_message().c_str());
2060   }
2061   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2062   device->SetDefaultLayout(layout.ValueOrDie());
2063 }
2064 
2065 void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status) {
2066   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2067   device->ClearDefaultLayout();
2068 }
2069 
2070 void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
2071                                 void* device_info, TF_Status* status) {
2072   StatusOr<Mesh> mesh = Mesh::FromString(serialized_mesh);
2073   if (!mesh.ok()) {
2074     RETURN_STATUS(status, TF_INTERNAL, mesh.status().error_message().c_str());
2075   }
2076   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2077   device->SetDefaultMesh(mesh.ValueOrDie());
2078 }
2079 
2080 void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status) {
2081   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2082   device->ClearDefaultMesh();
2083 }
2084 
2085 void SetSameShapePolicy(void* device_info, bool enabled) {
2086   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2087   device->SetSameShapePolicy(enabled);
2088 }
2089 
SetTPUCoreIDs(const std::string & mesh_name,const std::vector<int> & tpu_core_ids,void * device_info,TF_Status * status)2090 void SetTPUCoreIDs(const std::string& mesh_name,
2091                    const std::vector<int>& tpu_core_ids, void* device_info,
2092                    TF_Status* status) {
2093   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2094   RETURN_C_STATUS_IF_NOT_OK(device->SetTPUCoreIDs(mesh_name, tpu_core_ids),
2095                             status);
2096 }
2097 
2098 void ClearTPUCoreIDs(void* device_info) {
2099   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2100   device->ClearTPUCoreIDs();
2101 }
2102 
2103 std::vector<std::vector<int>> TPUCoreIDsToLocations(
2104     TFE_Context* context, const std::vector<int>& tpu_core_ids,
2105     void* device_info) {
2106   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2107   return device->TPUCoreIDsToLocations(context, tpu_core_ids);
2108 }
2109 
2110 std::vector<int> TPUCoreLocationsToIDs(
2111     TFE_Context* context,
2112     const std::vector<std::vector<int>>& tpu_core_locations,
2113     void* device_info) {
2114   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2115   return device->TPUCoreLocationsToIDs(context, tpu_core_locations);
2116 }
2117 
2118 TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
2119                        TFE_TensorHandle** inputs,
2120                        const std::string& string_layout, void* device_info,
2121                        TF_Status* status) {
2122   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2123   return device->Pack(context, num_inputs, inputs, string_layout, status);
2124 }
2125 
2126 std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
2127                                       TFE_TensorHandle* input,
2128                                       void* device_info, TF_Status* status) {
2129   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2130   return device->Unpack(context, input, status);
2131 }
2132 
2133 std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
2134                         void* device_info, TF_Status* status) {
2135   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2136   return device->FetchLayout(context, input, status);
2137 }
2138 
2139 TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
2140                              TFE_TensorHandle** indices,
2141                              TFE_TensorHandle** values,
2142                              TFE_TensorHandle** shapes,
2143                              const std::string& string_layout,
2144                              void* device_info, TF_Status* status) {
2145   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2146   return device->SparsePack(context, num_inputs, indices, values, shapes,
2147                             string_layout, status);
2148 }
2149 
2150 bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
2151                      void* device_info, TF_Status* status) {
2152   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2153   return device->IsSparseDTensor(context, input, status);
2154 }
2155 
2156 std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
2157     TFE_Context* context, void* device_info, TF_Status* status) {
2158   DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info);
2159   return device->GetFunctionCacheHitAndMissCount(context, status);
2160 }
2161 }  // namespace dtensor
2162 }  // namespace tensorflow
2163