xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/tensor_layout.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/cc/tensor_layout.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <map>
21 #include <memory>
22 #include <numeric>
23 #include <set>
24 #include <string>
25 #include <string_view>
26 #include <utility>
27 #include <vector>
28 
29 #include "absl/container/inlined_vector.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "absl/strings/string_view.h"
34 #include "absl/types/optional.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/lib/math/math_util.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/fingerprint.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/statusor.h"
41 #include "tensorflow/core/util/device_name_utils.h"
42 #include "tensorflow/dtensor/cc/dstatus.h"
43 #include "tensorflow/dtensor/proto/layout.pb.h"
44 
45 namespace tensorflow {
46 namespace dtensor {
47 
48 constexpr const char* Layout::kUnshardedDim;
49 constexpr const char* Layout::kAny;
50 constexpr const char* Layout::kEmptyLayoutString;
51 constexpr const char* Layout::kMatch;
52 constexpr const char* Mesh::kEmptyMeshString;
53 
54 namespace {
55 // Obtain all possible forms of indexing a mesh.
56 //
57 // e.g. given a mesh with dimensions [x=2, y=3], returns {
58 //   [0, 0], [0, 1], [0, 2],
59 //   [1, 0], [1, 1], [1, 2]
60 // }
ComputeDeviceLocations(const Mesh * mesh)61 inline std::vector<DeviceLocation> ComputeDeviceLocations(const Mesh* mesh) {
62   std::vector<DeviceLocation> mesh_locs(mesh->size());
63   for (size_t i = 0; i < mesh->size(); ++i)
64     mesh_locs[i] = *(mesh->device_location(i));
65   return mesh_locs;
66 }
67 }  // namespace
68 
69 namespace {
70 // Expands a ShardVector into the size defined in new_num_shards_per_dim.
71 //
72 // For example, the inputs:
73 //    - shard_vec: shards = [(1,1)] num_shards_per_dim = [1,1]
74 //    - new_num_shards_per_dim = [2,2]
75 //
76 // Would lead to:
77 // shard_vec: shards = [(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim = [2,2]
78 //
79 // This is used to check whether two ShardVectors contain the same information
80 // while having different number of shards per dimension. The two ShardVectors
81 // above are an example of this.
ExpandShardVector(const ShardVector & shard_vec,const std::vector<int> & new_num_shards_per_dim)82 ShardVector ExpandShardVector(const ShardVector& shard_vec,
83                               const std::vector<int>& new_num_shards_per_dim) {
84   if (shard_vec.shards.empty()) return shard_vec;
85 
86   // Takes a single shard and expands it into multiple shards.
87   auto ExpandShard = [shard_vec, new_num_shards_per_dim](
88                          const Shard& shard,
89                          int dim_ind) -> std::vector<Shard> {
90     int original_dim_size = shard_vec.num_shards_per_dim[dim_ind];
91     int new_dim_size = new_num_shards_per_dim[dim_ind];
92     int size_ratio = new_dim_size / original_dim_size;
93 
94     std::vector<Shard> expanded_shards;
95     expanded_shards.reserve(size_ratio);
96     for (int i = 0; i < size_ratio; ++i) {
97       int original_coord = shard[dim_ind];
98       int shifted_coord = (original_coord - 1) * size_ratio + 1 + i;
99       // Copy original shard, then modify it.
100       Shard new_shard = shard;
101       new_shard[dim_ind] = shifted_coord;
102       expanded_shards.push_back(new_shard);
103     }
104     return expanded_shards;
105   };
106   // Iterates over the dimensions of the shard, expanding at each
107   // dimension.
108   std::vector<Shard> total_expanded_shards = shard_vec.shards;
109   for (int dim_ind = 0; dim_ind < new_num_shards_per_dim.size(); ++dim_ind) {
110     std::vector<Shard> dim_expanded_shards;
111     for (const auto& shard : total_expanded_shards) {
112       std::vector<Shard> expanded_shards = ExpandShard(shard, dim_ind);
113       // Concatenate newly created shards.
114       dim_expanded_shards.insert(dim_expanded_shards.end(),
115                                  expanded_shards.begin(),
116                                  expanded_shards.end());
117     }
118     // Copy newly created shards and delete old ones.
119     total_expanded_shards = dim_expanded_shards;
120   }
121   std::sort(total_expanded_shards.begin(), total_expanded_shards.end());
122   ShardVector expanded_shard_vec;
123   expanded_shard_vec.shards = total_expanded_shards;
124   expanded_shard_vec.num_shards_per_dim = new_num_shards_per_dim;
125   return expanded_shard_vec;
126 }
127 }  // namespace
128 
operator ==(const ShardVector & other) const129 bool ShardVector::operator==(const ShardVector& other) const {
130   // Check same number of shards.
131   if (this->shards.empty() && other.shards.empty()) return true;
132   if (this->shards.empty() || other.shards.empty()) return false;
133 
134   // Check number of shard dimensions match.
135   if (this->num_shards_per_dim.size() != other.num_shards_per_dim.size())
136     return false;
137 
138   // Compute lowest common multiple for each of the shard dimensions.
139   Shard first_shard_this = this->shards[0];
140   Shard first_shard_other = other.shards[0];
141   std::vector<int> new_sizes;
142   for (size_t i = 0; i < first_shard_this.size(); ++i) {
143     int lcm = this->num_shards_per_dim[i] * other.num_shards_per_dim[i] /
144               MathUtil::GCD(static_cast<unsigned>(this->num_shards_per_dim[i]),
145                             static_cast<unsigned>(other.num_shards_per_dim[i]));
146     new_sizes.push_back(lcm);
147   }
148 
149   // Expand and compare.
150   return ExpandShardVector(*this, new_sizes).shards ==
151          ExpandShardVector(other, new_sizes).shards;
152 }
153 
ToString() const154 std::string ShardVector::ToString() const {
155   std::string string = "shards:[";
156   // Convert each Shard into string.
157   std::vector<std::string> shard_strs;
158   shard_strs.reserve(shards.size());
159   for (const Shard& shard : shards)
160     shard_strs.push_back("(" + absl::StrJoin(shard, ",") + ")");
161   // Join shards, and append dimensions.
162   absl::StrAppend(&string, absl::StrJoin(shard_strs, ","));
163   absl::StrAppend(&string, "] num_shards_per_dim:(");
164   absl::StrAppend(&string, absl::StrJoin(num_shards_per_dim, ",") + ")");
165   return string;
166 }
167 
ContainsShard(const Shard & shard) const168 bool ShardVector::ContainsShard(const Shard& shard) const {
169   for (const auto& shard_in_vec : shards)
170     if (shard_in_vec == shard) return true;
171   return false;
172 }
173 
174 // static
tpu_core_ids()175 std::map<std::string, std::vector<int>>& Mesh::tpu_core_ids() {
176   static auto tpu_core_ids = new std::map<std::string, std::vector<int>>();
177   return *tpu_core_ids;
178 }
179 
180 // static
tpu_host_mesh()181 std::string& Mesh::tpu_host_mesh() {
182   static auto tpu_host_mesh = new std::string;
183   return *tpu_host_mesh;
184 }
185 
186 // static
ParseFromProto(const MeshProto & proto)187 StatusOr<Mesh> Mesh::ParseFromProto(const MeshProto& proto) {
188   Mesh mesh;
189   mesh.name_ = proto.name();
190 
191   for (const auto& device : proto.local_devices()) {
192     mesh.local_devices_.push_back(device);
193   }
194 
195   // Define local device ids.
196   for (const auto& device_id : proto.local_device_ids()) {
197     mesh.local_device_ids_.push_back(device_id);
198   }
199 
200   for (const auto& device_id : proto.global_device_ids()) {
201     mesh.global_device_ids_.push_back(device_id);
202   }
203 
204   for (const auto& device : proto.global_devices()) {
205     mesh.global_devices_.push_back(device);
206   }
207 
208   // Assign Mesh Dimensions.
209   mesh.mesh_dims_.resize(proto.mesh_dimensions_size());
210   for (int i = 0; i < proto.mesh_dimensions_size(); ++i) {
211     const MeshDimensionProto& dim = proto.mesh_dimensions(i);
212     mesh.mesh_dims_[i].name = dim.name();
213     mesh.mesh_dims_[i].size = dim.size();
214   }
215 
216   // Check invariants.
217   int64 mesh_size = mesh.size();
218   int num_devices = proto.global_device_ids_size();
219   if (mesh_size > 0 && mesh_size != num_devices) {
220     TF_RETURN_WITH_CONTEXT(
221         errors::InvalidArgument("Number of devices ", num_devices,
222                                 " not matching mesh size ", mesh_size));
223   }
224   return mesh;
225 }
226 
227 // static
GetAbstractMesh(const std::string & name,const std::vector<MeshDimension> & mesh_dims)228 StatusOr<Mesh> Mesh::GetAbstractMesh(
229     const std::string& name, const std::vector<MeshDimension>& mesh_dims) {
230   Mesh mesh;
231   mesh.name_ = name;
232   mesh.mesh_dims_ = mesh_dims;
233 
234   // Check no repeated mesh dimension names.
235   std::set<std::string> dims_set;
236   for (const MeshDimension& dim : mesh.dims()) {
237     if (dims_set.find(dim.name) != dims_set.end())
238       TF_RETURN_WITH_CONTEXT(
239           errors::InvalidArgument("repeated mesh dimension"));
240     if (dim.name == Layout::kAny || dim.name == Layout::kMatch ||
241         dim.name == Layout::kUnshardedDim)
242       TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh dimension name ",
243                                                      dim.name, " is reserved"));
244     dims_set.insert(dim.name);
245   }
246 
247   return mesh;
248 }
249 
250 // static
GetMesh(const std::string & name,const std::vector<MeshDimension> & mesh_dims,const std::vector<std::int64_t> & global_device_ids,const std::vector<std::int64_t> & local_device_ids,const std::vector<std::string> & local_devices,const std::vector<std::string> & global_devices)251 StatusOr<Mesh> Mesh::GetMesh(const std::string& name,
252                              const std::vector<MeshDimension>& mesh_dims,
253                              const std::vector<std::int64_t>& global_device_ids,
254                              const std::vector<std::int64_t>& local_device_ids,
255                              const std::vector<std::string>& local_devices,
256                              const std::vector<std::string>& global_devices) {
257   TF_ASSIGN_OR_RETURN(Mesh mesh, GetAbstractMesh(name, mesh_dims));
258   mesh.global_device_ids_ = global_device_ids;
259   mesh.local_device_ids_ = local_device_ids;
260   mesh.local_devices_ = local_devices;
261   mesh.global_devices_ = global_devices;
262 
263   // Check number of devices matches conditions.
264   size_t global_n = mesh.global_device_ids_.size();
265   size_t local_n = mesh.local_device_ids_.size();
266   size_t dev_n = mesh.local_devices_.size();
267 
268   if (!(global_n >= local_n && dev_n == local_n))
269     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
270         "number of global_device_ids ", std::to_string(global_n),
271         " local_devices ids ", std::to_string(local_n), " and local devices ",
272         std::to_string(dev_n), "not meeting requirements"));
273 
274   // If empty device list, return empty mesh.
275   if (global_n == 0) return Mesh::Empty();
276 
277   if (local_n && !(global_n % local_n == 0))
278     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
279         "Uneven local clusters with global_ids ", std::to_string(global_n),
280         " and local_devices ids ", std::to_string(local_n)));
281 
282   // Check mesh size matches number of devices.
283   if (mesh.size() != global_n)
284     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh size doesn't match",
285                                                    "number of devices"));
286 
287   // Check local device invariants.
288   TF_ASSIGN_OR_RETURN(const auto& parsed_devs, mesh.ParsedDevices());
289   std::set<std::string> types_set;
290   for (const DeviceNameUtils::ParsedName& dev : parsed_devs) {
291     if (!dev.has_job || !dev.has_task || !dev.has_type)
292       return errors::InvalidArgument(
293           "Failed to either identify host or device type");
294     types_set.insert(dev.type);
295     if (types_set.size() > 1)
296       return errors::InvalidArgument(
297           "More than one device type per mesh not supported. Found ",
298           types_set.size());
299   }
300 
301   return mesh;
302 }
303 
dim_size(absl::string_view name) const304 StatusOr<int64_t> Mesh::dim_size(absl::string_view name) const {
305   for (const auto& mesh_dim : dims()) {
306     if (name == mesh_dim.name) {
307       return mesh_dim.size;
308     }
309   }
310 
311   std::vector<std::string> dim_names;
312   for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name);
313 
314   return errors::NotFound(
315       "Dimension ", name, " does not exist in mesh.",
316       "Available dimensions: ", absl::StrJoin(dim_names, ","));
317 }
318 
dim_sizes() const319 std::vector<int64_t> Mesh::dim_sizes() const {
320   std::vector<int64_t> dim_sizes;
321   if (mesh_dims_.empty()) return dim_sizes;
322   for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size);
323   return dim_sizes;
324 }
325 
operator ==(const Mesh & b) const326 bool Mesh::operator==(const Mesh& b) const {
327   return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
328 }
329 
IsEmpty() const330 bool Mesh::IsEmpty() const { return global_device_ids_.empty(); }
331 
ParsedDevices() const332 StatusOr<const std::vector<DeviceNameUtils::ParsedName>> Mesh::ParsedDevices()
333     const {
334   std::vector<DeviceNameUtils::ParsedName> parsed_devices(
335       local_devices_.size());
336   for (std::size_t i = 0; i < local_devices_.size(); ++i)
337     if (!DeviceNameUtils::ParseFullOrLocalName(
338             absl::string_view(local_devices_[i]), &parsed_devices[i]))
339       return errors::InvalidArgument("Failed to parse local_devices");
340 
341   return parsed_devices;
342 }
343 
ToDeviceType(const std::string & device_type) const344 StatusOr<Mesh> Mesh::ToDeviceType(const std::string& device_type) const {
345   std::vector<std::string> to_local_devices;
346   DeviceNameUtils::ParsedName parsed_dev;
347   for (const std::string& local_dev : local_devices_) {
348     if (!DeviceNameUtils::ParseFullOrLocalName(absl::string_view(local_dev),
349                                                &parsed_dev)) {
350       return errors::InvalidArgument("Failed to parse local devices");
351     }
352     // Converted mesh using full task name with job, replica and task ids.
353     to_local_devices.push_back(
354         DeviceNameUtils::FullName(parsed_dev.job, parsed_dev.replica,
355                                   parsed_dev.task, device_type, parsed_dev.id));
356     parsed_dev.Clear();
357   }
358   return GetMesh(name_, mesh_dims_, global_device_ids_, local_device_ids_,
359                  to_local_devices, /*global_devices=*/{});
360 }
361 
362 namespace {
HostFromParsedDev(const DeviceNameUtils::ParsedName & dev)363 std::string HostFromParsedDev(const DeviceNameUtils::ParsedName& dev) {
364   return "/job:" + dev.job + "/task:" + std::to_string(dev.task);
365 }
366 }  //  namespace
367 
hosts() const368 std::vector<std::string> Mesh::hosts() const {
369   std::vector<std::string> host_list;
370   if (IsEmpty()) return host_list;
371 
372   const auto parsed_devices = ParsedDevices().ValueOrDie();
373   for (const DeviceNameUtils::ParsedName& dev : parsed_devices) {
374     std::string host = HostFromParsedDev(dev);
375     if (std::find(host_list.begin(), host_list.end(), host) == host_list.end())
376       host_list.push_back(host);
377   }
378   return host_list;
379 }
380 
device_type() const381 std::string Mesh::device_type() const {
382   if (IsEmpty()) return std::string();
383   std::string device;
384   if (!global_devices_.empty()) {
385     device = global_devices_[0];
386   } else {
387     device = local_devices_[0];
388   }
389   DeviceNameUtils::ParsedName dev;
390   DeviceNameUtils::ParseFullOrLocalName(device, &dev);
391   return dev.type;
392 }
393 
IsMeshDim(const std::string & dim_name) const394 bool Mesh::IsMeshDim(const std::string& dim_name) const {
395   for (const auto& mesh_dim : dims())
396     if (dim_name == mesh_dim.name) return true;
397   return false;
398 }
399 
GetMeshDimIndexWithName(const std::string & mesh_name) const400 int Mesh::GetMeshDimIndexWithName(const std::string& mesh_name) const {
401   int mesh_index = -1;
402   for (int i = 0; i < dims().size(); ++i) {
403     const auto mesh_dim = dim(i);
404     if (mesh_dim.name == mesh_name) mesh_index = i;
405   }
406   assert(mesh_index >= 0);
407   return mesh_index;
408 }
409 
rank() const410 int64 Mesh::rank() const { return mesh_dims_.size(); }
411 
size() const412 int64 Mesh::size() const {
413   if (mesh_dims_.empty()) return 0;
414 
415   int64 size = 1;
416   for (const MeshDimension& dim : mesh_dims_) size *= dim.size;
417   return size;
418 }
419 
Empty()420 Mesh Mesh::Empty() { return Mesh(); }
421 
ToProto() const422 MeshProto Mesh::ToProto() const {
423   MeshProto mesh_proto;
424   mesh_proto.set_name(name());
425 
426   for (const auto& d : local_devices_) {
427     mesh_proto.add_local_devices(d);
428   }
429 
430   for (const auto& i : local_device_ids_) {
431     mesh_proto.add_local_device_ids(i);
432   }
433 
434   for (const auto& i : global_device_ids_) {
435     mesh_proto.add_global_device_ids(i);
436   }
437 
438   for (const auto& dim : mesh_dims_) {
439     MeshDimensionProto* mesh_dim_proto = mesh_proto.add_mesh_dimensions();
440     mesh_dim_proto->set_name(dim.name);
441     mesh_dim_proto->set_size(dim.size);
442   }
443 
444   for (const auto& d : global_devices_) {
445     mesh_proto.add_global_devices(d);
446   }
447   return mesh_proto;
448 }
449 
ToString() const450 std::string Mesh::ToString() const {
451   if (Mesh::IsEmpty()) return kEmptyMeshString;
452 
453   // We use "|" to separate name, mesh dimensions and devices.
454   std::string mesh_str = absl::StrCat(Mesh::name(), "|");
455 
456   // Add mesh dimensions
457   absl::InlinedVector<std::string, 4> mesh_dim_lst;
458   for (const auto& dim : mesh_dims_)
459     mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
460   mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";
461 
462   // Add flattened list of global device ids
463   mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";
464 
465   // Add flattened list of local device ids
466   mesh_str += absl::StrJoin(local_device_ids_, ",") + "|";
467 
468   // Add flattened list of local devices
469   mesh_str += absl::StrJoin(local_devices_, ",");
470 
471   if (!global_devices_.empty()) {
472     // Add flattened list of global devices
473     mesh_str += "|";
474     mesh_str += absl::StrJoin(global_devices_, ",");
475   }
476   return mesh_str;
477 }
478 
GlobalFingerprint() const479 uint64 Mesh::GlobalFingerprint() const {
480   if (Mesh::IsEmpty()) return Fingerprint64(kEmptyMeshString);
481 
482   std::string mesh_str;
483   // Add mesh dimensions
484   absl::InlinedVector<std::string, 4> mesh_dim_lst;
485   for (const auto& dim : mesh_dims_)
486     mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
487   mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";
488 
489   // Ignore local_device_ids_, local_devices and name which might be not global
490   // unique.
491   // Add flattened list of global device ids
492   mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";
493 
494   if (!global_devices_.empty()) {
495     // Add flattened list of global devices
496     mesh_str += "|";
497     mesh_str += absl::StrJoin(global_devices_, ",");
498   }
499   // mesh dims | global device ids (| global devices)
500   return Fingerprint64(mesh_str);
501 }
502 
503 namespace {
StrToMeshDimension(const std::string & str)504 MeshDimension StrToMeshDimension(const std::string& str) {
505   MeshDimension mesh_dim;
506   if (str.empty()) return mesh_dim;
507 
508   std::vector<std::string> mesh_dim_parts = absl::StrSplit(str, '=');
509 
510   mesh_dim.name = mesh_dim_parts[0];
511   mesh_dim.size = std::stoi(mesh_dim_parts[1]);
512   return mesh_dim;
513 }
514 
GenerateMeshDevicesForTests(const std::string & name,const std::vector<MeshDimension> & mesh_dims,const std::string & mesh_gen_instruction)515 StatusOr<Mesh> GenerateMeshDevicesForTests(
516     const std::string& name, const std::vector<MeshDimension>& mesh_dims,
517     const std::string& mesh_gen_instruction) {
518   // Parse mesh generation instruction.
519   std::vector<std::string> instruction_parts =
520       absl::StrSplit(mesh_gen_instruction, '*');
521   if (instruction_parts.size() != 2)
522     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
523         "Expected a * in mesh_gen_instructions but found ",
524         mesh_gen_instruction));
525   std::string device_type = instruction_parts[1];
526 
527   // Get Mesh Size.
528   int64 mesh_size = 0;
529   if (!mesh_dims.empty()) {
530     mesh_size = 1;
531     for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size;
532   }
533 
534   // Generate device ids.
535   std::vector<int64_t> global_device_ids;
536   std::vector<int64_t> local_device_ids;
537   std::vector<std::string> local_devices;
538   for (std::size_t i = 0; i < mesh_size; ++i) {
539     global_device_ids.push_back(i);
540     local_device_ids.push_back(i);
541     local_devices.push_back("/job:localhost/task:0/device:" + device_type +
542                             ":" + std::to_string(i));
543   }
544 
545   TF_ASSIGN_OR_RETURN(
546       Mesh mesh,
547       Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
548                     local_devices, /*global_devices=*/{}));
549   return mesh;
550 }
551 }  // namespace
552 
553 // static
FromString(const std::string & str)554 StatusOr<Mesh> Mesh::FromString(const std::string& str) {
555   if (str == kEmptyMeshString) return Mesh::Empty();
556 
557   std::vector<std::string> mesh_parts = absl::StrSplit(str, '|');
558 
559   // Check formatting error.
560   if (mesh_parts.size() != 3 && mesh_parts.size() != 5 &&
561       mesh_parts.size() != 6)
562     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
563         "Expected either 5, 6 or 3 mesh parts but found", mesh_parts.size()));
564 
565   // Populate mesh.
566   std::string name = mesh_parts[0];
567 
568   // Add mesh dimensions.
569   std::vector<MeshDimension> mesh_dims;
570   if (!mesh_parts[1].empty()) {
571     std::vector<std::string> mesh_dim_strs = absl::StrSplit(mesh_parts[1], ',');
572     mesh_dims.reserve(mesh_dim_strs.size());
573     for (const std::string& mesh_dim_str : mesh_dim_strs)
574       mesh_dims.push_back(StrToMeshDimension(mesh_dim_str));
575   }
576 
577   // Check if mesh is set to be autogenerated.
578   if (mesh_parts.size() == 3)
579     return GenerateMeshDevicesForTests(name, mesh_dims, mesh_parts[2]);
580 
581   // Add global device ids list.
582   std::vector<int64_t> global_device_ids;
583   if (!mesh_parts[2].empty()) {
584     std::vector<std::string> global_device_ids_strs =
585         absl::StrSplit(mesh_parts[2], ',');
586 
587     global_device_ids.reserve(global_device_ids_strs.size());
588     for (const std::string& id : global_device_ids_strs)
589       global_device_ids.push_back(std::stoi(id));
590   }
591 
592   // Add local device ids list.
593   std::vector<int64_t> local_device_ids;
594   if (!mesh_parts[3].empty()) {
595     std::vector<std::string> local_device_ids_strs =
596         absl::StrSplit(mesh_parts[3], ',');
597 
598     local_device_ids.reserve(local_device_ids_strs.size());
599     for (const std::string& id : local_device_ids_strs)
600       local_device_ids.push_back(std::stoi(id));
601   }
602   // Add local devices.
603   std::vector<std::string> local_devices;
604   if (!mesh_parts[4].empty())
605     local_devices = absl::StrSplit(mesh_parts[4], ',');
606 
607   std::vector<std::string> global_devices;
608   if (mesh_parts.size() == 6) {
609     // Add global devices.
610     if (!mesh_parts[5].empty())
611       global_devices = absl::StrSplit(mesh_parts[5], ',');
612   }
613 
614   TF_ASSIGN_OR_RETURN(
615       Mesh mesh,
616       Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
617                     local_devices, global_devices));
618   return mesh;
619 }
620 
num_devices() const621 int64 Mesh::num_devices() const { return global_device_ids_.size(); }
622 
device_location(int offset) const623 StatusOr<const DeviceLocation> Mesh::device_location(int offset) const {
624   if (offset < 0 || offset > size() - 1)
625     return errors::InvalidArgument(
626         "Mesh offset cannot be negative or exceed Mesh's size. Offset size:",
627         offset, " and Mesh size:", size());
628 
629   DeviceLocation dev_loc;
630   std::vector<int64> mesh_dim_lengths = dim_sizes();
631   int64 i = mesh_dim_lengths.size() - 1;
632   while (i >= 0) {
633     dev_loc.insert(dev_loc.begin(), offset % mesh_dim_lengths[i]);
634     offset /= mesh_dim_lengths[i];
635     --i;
636   }
637   return dev_loc;
638 }
639 
GetFlattenedCoordinate(const DeviceLocation & loc) const640 int64 Mesh::GetFlattenedCoordinate(const DeviceLocation& loc) const {
641   const std::vector<int64> mesh_dim_sizes = dim_sizes();
642   int64 i = mesh_dim_sizes.size() - 1;
643   int64 acc = 1;
644   int64 device_pos = 0;
645   while (i >= 0) {
646     device_pos += loc[i] * acc;
647     acc *= mesh_dim_sizes[i];
648     --i;
649   }
650   return device_pos;
651 }
652 
idx_for_dim(absl::string_view dim_name) const653 StatusOr<int32> Mesh::idx_for_dim(absl::string_view dim_name) const {
654   for (int i = 0; i < mesh_dims_.size(); ++i) {
655     if (mesh_dims_[i].name == dim_name) return i;
656   }
657   return errors::InvalidArgument("dim name :", dim_name,
658                                  " does not exist on mesh : ", ToString());
659 }
660 
GetLayout(const std::vector<std::string> & sharding_spec_strs,const Mesh & mesh)661 StatusOr<Layout> Layout::GetLayout(
662     const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh) {
663   // Re-format sharding specs.
664   std::vector<ShardingSpec> sharding_specs;
665   sharding_specs.reserve(sharding_spec_strs.size());
666   for (const std::string& spec_str : sharding_spec_strs) {
667     ShardingSpec spec;
668     spec.set_sharding_spec(spec_str);
669     sharding_specs.push_back(spec);
670   }
671   return GetLayout(sharding_specs, mesh);
672 }
673 
GetLayout(const std::vector<ShardingSpec> & sharding_specs,const Mesh & mesh)674 StatusOr<Layout> Layout::GetLayout(
675     const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh) {
676   Layout layout;
677   // Append mesh, then check sharding_specs are legal.
678   layout.mesh_ = mesh;
679 
680   // Check sharding_specs are either mesh dimension or special value.
681   for (const auto& dim : sharding_specs) {
682     const std::string& sharding_spec = dim.sharding_spec();
683     if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny ||
684           sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) ||
685           sharding_spec == "scalar"))
686       TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
687           "sharding spec (", sharding_spec,
688           ") refers to mesh dimension not contained in mesh ",
689           mesh.ToString()));
690   }
691   // Check same tensor dimensions not sharded over same mesh dimension twice.
692   std::set<std::string> dims_set;
693   for (const auto& dim : sharding_specs) {
694     const std::string& sharding_spec = dim.sharding_spec();
695     if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue;
696     // If scalar, delete all sharding specs.
697     if (sharding_spec == "scalar") {
698       if (sharding_specs.size() > 1)
699         TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
700             "A scalar sharding_spec can only be used as a single sharding_spec "
701             "instruction, not as part of list of sharding_specs as attempted "
702             "here with ",
703             sharding_specs.size(), " sharding_specs"))
704       // Return layout with empty spec to represent scalar behavior.
705       return layout;
706     }
707     if (dims_set.find(sharding_spec) != dims_set.end())
708       TF_RETURN_WITH_CONTEXT(
709           errors::InvalidArgument("Attempted to shard two or more tensor "
710                                   "dimensions over mesh dimension ",
711                                   sharding_spec))
712     dims_set.insert(sharding_spec);
713   }
714   // After checking sharding_specs are legal, append and return layout.
715   layout.sharding_specs_ = sharding_specs;
716   return layout;
717 }
718 
Empty()719 Layout Layout::Empty() {
720   Layout result;
721   return result;
722 }
723 
IsEmpty() const724 bool Layout::IsEmpty() const { return mesh_.IsEmpty(); }
725 
726 namespace {
ReducedAbstractMesh(const Layout * layout)727 Mesh ReducedAbstractMesh(const Layout* layout) {
728   const std::vector<std::string>& shard_spec_strs =
729       layout->sharding_spec_strs();
730   std::vector<MeshDimension> reduced_mesh_dims;
731   reduced_mesh_dims.reserve(layout->mesh().dims().size());
732   for (const MeshDimension& mesh_dim : layout->mesh().dims()) {
733     bool IsMeshDimInShardingSpecs =
734         std::find(shard_spec_strs.begin(), shard_spec_strs.end(),
735                   mesh_dim.name) != shard_spec_strs.end();
736     // If dimension not in sharding_spec, flip size to 1.
737     MeshDimension reduced_dim =
738         IsMeshDimInShardingSpecs ? mesh_dim : MeshDimension(mesh_dim.name, 1);
739     reduced_mesh_dims.push_back(reduced_dim);
740   }
741   return Mesh::GetAbstractMesh("", reduced_mesh_dims).ValueOrDie();
742 }
743 
744 }  // namespace
745 
ReducedMesh() const746 Mesh Layout::ReducedMesh() const {
747   // Set replicated mesh dimensions to size 1, and create reduced abstract mesh.
748   Mesh reduced_mesh = ReducedAbstractMesh(this);
749 
750   // Populate reduced mesh with global devices from original mesh.
751   std::vector<int64_t> reduced_global_device_ids;
752   std::vector<std::string> reduced_global_devs;
753   for (const DeviceLocation& loc : ComputeDeviceLocations(&reduced_mesh)) {
754     int64 pos = mesh().GetFlattenedCoordinate(loc);
755     reduced_global_device_ids.push_back(mesh().global_device_ids().at(pos));
756     if (!mesh().global_devices().empty()) {
757       reduced_global_devs.push_back(mesh().global_devices().at(pos));
758     }
759   }
760 
761   // Track the set of global device IDs in the abstract mesh.
762   std::set<int64_t> reduced_global_device_ids_set(
763       reduced_global_device_ids.begin(), reduced_global_device_ids.end());
764 
765   // Populate reduced mesh with local devices in the same order as the original
766   // mesh.
767   std::vector<int64_t> reduced_local_device_ids;
768   std::vector<std::string> reduced_local_devs;
769   for (size_t i = 0; i < mesh().local_device_ids().size(); ++i) {
770     int64_t device_id = mesh().local_device_ids().at(i);
771     if (reduced_global_device_ids_set.find(device_id) !=
772         reduced_global_device_ids_set.end()) {
773       reduced_local_device_ids.push_back(device_id);
774       reduced_local_devs.push_back(mesh().local_devices().at(i));
775     }
776   }
777 
778   return Mesh::GetMesh(reduced_mesh.name(), reduced_mesh.dims(),
779                        reduced_global_device_ids, reduced_local_device_ids,
780                        reduced_local_devs, reduced_global_devs)
781       .ValueOrDie();
782 }
783 
784 namespace {
ReducedLayout(const Layout * layout)785 Layout ReducedLayout(const Layout* layout) {
786   // Change format sharding specs.
787   std::vector<ShardingSpec> shard_specs(layout->sharding_specs().size());
788   for (size_t i = 0; i < shard_specs.size(); ++i)
789     shard_specs[i] = layout->dim(i);
790   // Retrieve layout.
791   return Layout::GetLayout(shard_specs, layout->ReducedMesh()).ValueOrDie();
792 }
793 
794 // Returns index of the given mesh dimension or mesh dim size if not found.
IndexOfMeshDimension(const Mesh & mesh,const std::string & dim_name)795 StatusOr<int> IndexOfMeshDimension(const Mesh& mesh,
796                                    const std::string& dim_name) {
797   for (size_t i = 0; i < mesh.dims().size(); ++i)
798     if (dim_name == mesh.dims()[i].name) return i;
799   return errors::InvalidArgument("Mesh dimension not found");
800 }
801 }  // namespace
802 
GetShardVector() const803 ShardVector Layout::GetShardVector() const {
804   // Change format sharding specs.
805   std::vector<ShardingSpec> shard_specs(sharding_specs().size());
806   for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i);
807   // Obtain a shard position (i.e. sharded section of a tensor) from a mesh
808   // location, using the sharding specs.
809   auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard {
810     Shard shard;
811     for (size_t i = 0; i < shard_specs.size(); ++i) {
812       // If unsharded, there is only one shard, that is 1.
813       std::string spec = shard_specs[i].sharding_spec();
814       if (spec == Layout::kUnshardedDim) {
815         shard.push_back(1);
816       } else {
817         int mesh_index =
818             IndexOfMeshDimension(mesh(), sharding_spec(i)).ValueOrDie();
819         int shard_number = loc[mesh_index] + 1;
820         shard.push_back(shard_number);
821       }
822     }
823     return shard;
824   };
825   // Obtain dims of shard vector.
826   auto ShardVectorDims = [&]() -> std::vector<int> {
827     std::vector<int> num_shards_per_dim(shard_specs.size());
828     for (size_t i = 0; i < sharding_specs().size(); ++i) {
829       ShardingSpec spec = sharding_specs()[i];
830       if (Layout::IsShardedSpec(spec)) {
831         StatusOr<int64> dim_size = mesh().dim_size(spec.sharding_spec());
832         num_shards_per_dim[i] = dim_size.ValueOrDie();
833       } else {
834         num_shards_per_dim[i] = 1;
835       }
836     }
837     return num_shards_per_dim;
838   };
839   // Compute mesh locations and obtain shards from them.
840   ShardVector shard_vec;
841   for (const DeviceLocation& mesh_loc : ComputeDeviceLocations(&mesh()))
842     shard_vec.shards.push_back(GetShardFromDeviceLocation(mesh_loc));
843   // Calculate dims.
844   shard_vec.num_shards_per_dim = ShardVectorDims();
845   return shard_vec;
846 }
847 
HostShardMap() const848 std::map<std::string, ShardVector> Layout::HostShardMap() const {
849   Layout reduced_layout = ReducedLayout(this);
850   Mesh reduced_mesh = reduced_layout.mesh();
851   using HostName = std::string;
852 
853   // Build a map: {Host : Shards}
854   std::map<HostName, ShardVector> host_shards_map;
855   ShardVector shard_vec_in_red_layout = reduced_layout.GetShardVector();
856 
857   const auto parsed_devs = reduced_mesh.ParsedDevices().ValueOrDie();
858   for (size_t i = 0; i < parsed_devs.size(); ++i) {
859     HostName host = HostFromParsedDev(parsed_devs[i]);
860     Shard shard_in_device = shard_vec_in_red_layout.shards[i];
861 
862     // Check if host in hashtable and append shard.
863     auto it = host_shards_map.find(host);
864     if (it == host_shards_map.end()) {
865       ShardVector shard_vec_in_host;
866       shard_vec_in_host.shards.push_back(shard_in_device);
867       shard_vec_in_host.num_shards_per_dim =
868           shard_vec_in_red_layout.num_shards_per_dim;
869       host_shards_map.insert(
870           std::pair<HostName, ShardVector>(host, shard_vec_in_host));
871     } else {
872       bool isShardInShardVector = it->second.ContainsShard(shard_in_device);
873       if (!isShardInShardVector) {
874         it->second.shards.push_back(shard_in_device);
875       }
876     }
877   }
878   // Sort shards inside each host.
879   for (auto it = host_shards_map.begin(); it != host_shards_map.end(); ++it) {
880     std::sort(it->second.shards.begin(), it->second.shards.end());
881   }
882   return host_shards_map;
883 }
884 
// Returns the sharding spec string for tensor dimension `idx`.
// Precondition: 0 <= idx < rank(); no bounds checking is performed.
const std::string& Layout::sharding_spec(int idx) const {
  return sharding_specs_[idx].sharding_spec();
}
888 
num_shards() const889 std::vector<int32> Layout::num_shards() const {
890   std::vector<int32> num_shards;
891   num_shards.reserve(sharding_specs_.size());
892   for (const auto& sharding_spec : sharding_specs_) {
893     num_shards.push_back(num_shards_for_dim(sharding_spec));
894   }
895   return num_shards;
896 }
897 
// Number of shards along the single tensor dimension described by `dim`:
// 1 for unsharded dims, otherwise the size of the sharding mesh dimension.
// NOTE(review): for kMatch this returns -1, which wraps to SIZE_MAX because
// the return type is size_t; callers that store the result in int32 (e.g.
// num_shards()) observe -1 again after truncation — confirm this round-trip
// is relied upon before changing the type.
size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const {
  absl::string_view name = dim.sharding_spec();
  if (name == Layout::kUnshardedDim) return 1;
  if (name == Layout::kMatch) return -1;

  return mesh().dim_size(name).ValueOrDie();
}
905 
IsFullyReplicated() const906 bool Layout::IsFullyReplicated() const {
907   for (const auto& sharding_spec : sharding_specs_) {
908     if (num_shards_for_dim(sharding_spec) > 1) {
909       return false;
910     }
911   }
912   return true;
913 }
914 
IsLastDimReplicated() const915 bool Layout::IsLastDimReplicated() const {
916   return (sharding_specs_.empty()) ||
917          (num_shards_for_dim(sharding_specs_.back()) == 1);
918 }
919 
IsBatchParallel() const920 bool Layout::IsBatchParallel() const {
921   if (sharding_specs_.empty()) {
922     return true;
923   }
924 
925   for (int i = 1; i < sharding_specs_.size(); ++i) {
926     const auto& dim = sharding_specs_[i];
927     if (num_shards_for_dim(dim) != 1) {
928       return false;
929     }
930   }
931   return true;
932 }
933 
934 // TODO(samuelslee) Replace this with the IsBatchParallel() everywhere
IsBatchParallel(int non_batch_rank) const935 bool Layout::IsBatchParallel(int non_batch_rank) const {
936   if (sharding_specs_.empty()) return true;
937   for (int i = rank() - non_batch_rank; i < rank(); ++i) {
938     if (num_shards_for_dim(sharding_specs_[i]) != 1) return false;
939   }
940   return true;
941 }
942 
ToProto() const943 LayoutProto Layout::ToProto() const {
944   LayoutProto proto;
945   *proto.mutable_mesh_config() = mesh_.ToProto();
946   for (const auto& dim : sharding_specs_) {
947     *proto.add_sharding_specs() = dim;
948   }
949   return proto;
950 }
951 
IsEquivalent(const Layout & b) const952 bool Layout::IsEquivalent(const Layout& b) const {
953   if (this->rank() != b.rank()) return false;
954   if (this->mesh() != b.mesh()) return false;
955   for (int i = 0; i < this->rank(); ++i) {
956     if (this->sharding_specs_[i].sharding_spec() !=
957         b.sharding_specs_[i].sharding_spec()) {
958       if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) ||
959           (b.num_shards_for_dim(b.sharding_specs_[i]) != 1))
960         return false;
961     }
962   }
963   return true;
964 }
965 
// Strict equality: compares the serialized protos field by field. Unlike
// IsEquivalent, this distinguishes differently-named size-1 shardings.
bool Layout::operator==(const Layout& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}
969 
GlobalShapeFromLocalShape(const std::vector<int64_t> & local_shape) const970 std::vector<int64_t> Layout::GlobalShapeFromLocalShape(
971     const std::vector<int64_t>& local_shape) const {
972   if (IsFullyReplicated()) {
973     return local_shape;
974   }
975   std::vector<int64_t> global_shape;
976   global_shape.reserve(sharding_specs().size());
977   for (int i = 0; i < sharding_specs().size(); ++i) {
978     int64_t l_shape = local_shape.empty() ? 1 : local_shape[i];
979     int64_t dim_shards = num_shards()[i];
980     global_shape.emplace_back(l_shape * dim_shards);
981   }
982   return global_shape;
983 }
984 
LocalShapeFromGlobalShape(absl::Span<const int64_t> global_shape) const985 std::vector<int64_t> Layout::LocalShapeFromGlobalShape(
986     absl::Span<const int64_t> global_shape) const {
987   if (IsFullyReplicated()) {
988     return std::vector<int64_t>(global_shape.begin(), global_shape.end());
989   }
990   std::vector<int32> shards = num_shards();
991   std::vector<int64_t> local_shape;
992   for (int i = 0; i < sharding_specs().size(); ++i) {
993     int64_t dim_shards = shards[i];
994     // TODO(hthu): Shape might not be always divisible.
995     local_shape.emplace_back(global_shape[i] / dim_shards);
996   }
997   return local_shape;
998 }
999 
PartialTensorShape Layout::LocalShapeFromGlobalShape(
    const PartialTensorShape& global_shape) const {
  // Unknown-rank shapes (dims() == -1) and replicated layouts pass through.
  if (IsFullyReplicated() || global_shape.dims() == -1) {
    return global_shape;
  }
  const std::vector<int32> shards = num_shards();
  PartialTensorShape local_shape({});
  for (int i = 0; i < sharding_specs().size(); ++i) {
    const int64_t dim_size = global_shape.dim_size(i);
    // Unknown dimensions (-1) stay unknown; known ones are divided by the
    // shard count.
    local_shape.AddDim(dim_size == -1 ? -1 : dim_size / shards[i]);
  }
  return local_shape;
}
1013 
FromProto(const LayoutProto & proto)1014 StatusOr<Layout> Layout::FromProto(const LayoutProto& proto) {
1015   Layout layout;
1016   for (const auto& spec : proto.sharding_specs())
1017     layout.sharding_specs_.push_back(spec);
1018 
1019   TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config()));
1020   layout.mesh_ = std::move(mesh);
1021 
1022   return GetLayout(layout.sharding_specs_, layout.mesh_);
1023 }
1024 
ReplicatedOnMesh(const Mesh & mesh,int rank)1025 Layout Layout::ReplicatedOnMesh(const Mesh& mesh, int rank) {
1026   std::vector<std::string> specs(rank, kUnshardedDim);
1027   return Layout::GetLayout(specs, mesh).ValueOrDie();
1028 }
1029 
AnyOnMesh(const Mesh & mesh,int rank)1030 Layout Layout::AnyOnMesh(const Mesh& mesh, int rank) {
1031   std::vector<std::string> specs(rank, kAny);
1032   return Layout::GetLayout(specs, mesh).ValueOrDie();
1033 }
1034 
Transposed2D(const Layout & layout)1035 StatusOr<Layout> Layout::Transposed2D(const Layout& layout) {
1036   if (layout.rank() < 2) {
1037     return errors::InvalidArgument("Transposed2D requires rank to be >= 2");
1038   }
1039   std::vector<std::string> transposed_specs = layout.sharding_spec_strs();
1040   std::iter_swap(transposed_specs.end() - 2, transposed_specs.end() - 1);
1041   return Layout::GetLayout(transposed_specs, layout.mesh()).ValueOrDie();
1042 }
1043 
1044 // static
FromString(std::string layout_str)1045 StatusOr<Layout> Layout::FromString(std::string layout_str) {
1046   if (layout_str == kEmptyLayoutString) return Layout::Empty();
1047 
1048   // Print sharding specs.
1049   std::vector<absl::string_view> layout_parts = absl::StrSplit(layout_str, ' ');
1050   // Check formatting error.
1051   if (layout_parts.size() != 2) {
1052     TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
1053         "Expected 2 items but found ", layout_parts.size(), layout_parts[0]));
1054   }
1055   // Substract prefixes.
1056   absl::string_view sharding_spec_str = layout_parts[0];
1057   absl::ConsumePrefix(&sharding_spec_str, "sharding_specs:");
1058 
1059   absl::string_view mesh_str = layout_parts[1];
1060   absl::ConsumePrefix(&mesh_str, "mesh:");
1061 
1062   // Add sharding specs.
1063   std::vector<std::string> sharding_spec_strs =
1064       absl::StrSplit(sharding_spec_str, ',');
1065   sharding_spec_strs.pop_back();
1066 
1067   // Add mesh.
1068   TF_ASSIGN_OR_RETURN(Mesh mesh, Mesh::FromString(string(mesh_str)));
1069   // Try to create layout.
1070   TF_ASSIGN_OR_RETURN(Layout layout,
1071                       Layout::GetLayout(sharding_spec_strs, mesh));
1072   return layout;
1073 }
1074 
sharding_spec_strs() const1075 std::vector<std::string> Layout::sharding_spec_strs() const {
1076   std::vector<std::string> sharding_spec_strs(sharding_specs().size());
1077   for (size_t i = 0; i < sharding_specs().size(); ++i)
1078     sharding_spec_strs[i] = sharding_spec(i);
1079   return sharding_spec_strs;
1080 }
1081 
ToString() const1082 std::string Layout::ToString() const {
1083   if (Layout::IsEmpty()) return kEmptyLayoutString;
1084 
1085   std::string layout_str = "sharding_specs:";
1086   // Print sharding specs.
1087   for (const ShardingSpec& dim : sharding_specs_) {
1088     std::string dim_name = dim.sharding_spec();
1089     absl::StrAppend(&layout_str, dim_name + ",");
1090   }
1091   // Append mesh.
1092   absl::StrAppend(&layout_str, " mesh:", mesh_.ToString());
1093   return layout_str;
1094 }
1095 
GetLayoutWithReducedDims(const absl::flat_hash_set<int> & reduced_dims,bool keep_dims) const1096 Layout Layout::GetLayoutWithReducedDims(
1097     const absl::flat_hash_set<int>& reduced_dims, bool keep_dims) const {
1098   dtensor::LayoutProto output_layout;
1099   *output_layout.mutable_mesh_config() = mesh().ToProto();
1100 
1101   for (int i = 0; i < rank(); ++i) {
1102     // reduced_dims may contain negative values.
1103     if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) {
1104       *output_layout.add_sharding_specs() = dim(i);
1105     } else if (keep_dims) {
1106       auto* replicated_dim = output_layout.add_sharding_specs();
1107       replicated_dim->set_sharding_spec(kUnshardedDim);
1108     }
1109   }
1110   return Layout::FromProto(output_layout).ValueOrDie();
1111 }
1112 
Truncate(int64 split_point,bool end) const1113 Layout Layout::Truncate(int64 split_point, bool end) const {
1114   if ((split_point == 0 && end) || (split_point == rank() && !end))
1115     return *this;
1116 
1117   dtensor::LayoutProto output_layout;
1118   *output_layout.mutable_mesh_config() = mesh().ToProto();
1119 
1120   if (end) {
1121     for (int i = split_point; i < rank(); ++i)
1122       *output_layout.add_sharding_specs() = dim(i);
1123   } else {
1124     for (int i = 0; i < split_point; ++i)
1125       *output_layout.add_sharding_specs() = dim(i);
1126   }
1127   return Layout::FromProto(output_layout).ValueOrDie();
1128 }
1129 
1130 namespace {
1131 // Adds unsharded sharding specs to layout.
PadLayout(const int64 rank,const bool is_padding_before,const Layout & layout)1132 Layout PadLayout(const int64 rank, const bool is_padding_before,
1133                  const Layout& layout) {
1134   if (rank <= layout.rank()) return layout;
1135 
1136   // Create list of padding sharding specs.
1137   const int n = rank - layout.rank();
1138   std::vector<ShardingSpec> new_specs(n);
1139   for (int i = 0; i < n; ++i)
1140     new_specs[i].set_sharding_spec(Layout::kUnshardedDim);
1141 
1142   // Define concatenation point of layout specs.
1143   auto concat_point = is_padding_before ? new_specs.end() : new_specs.begin();
1144 
1145   // Concatenate old layout specs and new unsharded specs.
1146   new_specs.insert(concat_point, layout.sharding_specs().begin(),
1147                    layout.sharding_specs().end());
1148   return Layout::GetLayout(new_specs, layout.mesh()).ValueOrDie();
1149 }
1150 }  // namespace
1151 
LeftPad(int64 rank) const1152 Layout Layout::LeftPad(int64 rank) const {
1153   bool is_padding_before = true;
1154   return PadLayout(rank, is_padding_before, *this);
1155 }
1156 
ConcatenateLayouts(const Layout & layout_a,const Layout & layout_b)1157 StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
1158                                     const Layout& layout_b) {
1159   if (layout_a.mesh() != layout_b.mesh())
1160     return errors::InvalidArgument(
1161         "unable to concatenate layouts as they are on different meshes.");
1162 
1163   absl::flat_hash_set<std::string> layout_a_mesh_dims;
1164   for (int i = 0; i < layout_a.rank(); ++i)
1165     if (layout_a.sharding_spec(i) != Layout::kUnshardedDim)
1166       layout_a_mesh_dims.emplace(layout_a.sharding_spec(i));
1167 
1168   for (int i = 0; i < layout_b.rank(); ++i)
1169     if (layout_b.sharding_spec(i) != Layout::kUnshardedDim &&
1170         layout_a_mesh_dims.contains(layout_b.sharding_spec(i)))
1171       return errors::InvalidArgument(
1172           "unable to concatenate layouts as they use the same meshes "
1173           "dimension: ",
1174           layout_b.sharding_spec(i), " is used in both layouts.");
1175 
1176   LayoutProto layout_proto_a = layout_a.ToProto();
1177   LayoutProto layout_proto_b = layout_b.ToProto();
1178   LayoutProto output_layout_proto;
1179 
1180   *output_layout_proto.mutable_mesh_config() = layout_proto_a.mesh_config();
1181   for (int i = 0; i < layout_proto_a.sharding_specs_size(); ++i)
1182     *output_layout_proto.add_sharding_specs() =
1183         layout_proto_a.sharding_specs(i);
1184   for (int i = 0; i < layout_proto_b.sharding_specs_size(); ++i)
1185     *output_layout_proto.add_sharding_specs() =
1186         layout_proto_b.sharding_specs(i);
1187   return Layout::FromProto(output_layout_proto);
1188 }
1189 
1190 }  // namespace dtensor
1191 }  // namespace tensorflow
1192