/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/tensor_layout.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

constexpr const char* Layout::kUnshardedDim;
constexpr const char* Layout::kAny;
constexpr const char* Layout::kEmptyLayoutString;
constexpr const char* Layout::kMatch;
constexpr const char* Mesh::kEmptyMeshString;

namespace {
// Obtain all possible forms of indexing a mesh.
//
// e.g. given a mesh with dimensions [x=2, y=3], returns {
//   [0, 0], [0, 1], [0, 2],
//   [1, 0], [1, 1], [1, 2]
// }
inline std::vector<DeviceLocation> ComputeDeviceLocations(const Mesh* mesh) {
  std::vector<DeviceLocation> mesh_locs(mesh->size());
  for (size_t i = 0; i < mesh->size(); ++i)
    mesh_locs[i] = *(mesh->device_location(i));
  return mesh_locs;
}
}  // namespace

namespace {
// Expands a ShardVector into the size defined in new_num_shards_per_dim.
//
// For example, the inputs:
//  - shard_vec: shards = [(1,1)] num_shards_per_dim = [1,1]
//  - new_num_shards_per_dim = [2,2]
//
// Would lead to:
//  shard_vec: shards = [(1,1),(1,2),(2,1),(2,2)] num_shards_per_dim = [2,2]
//
// This is used to check whether two ShardVectors contain the same information
// while having different number of shards per dimension. The two ShardVectors
// above are an example of this.
ShardVector ExpandShardVector(const ShardVector& shard_vec,
                              const std::vector<int>& new_num_shards_per_dim) {
  if (shard_vec.shards.empty()) return shard_vec;

  // Takes a single shard and expands it into multiple shards.
  auto ExpandShard = [shard_vec, new_num_shards_per_dim](
                         const Shard& shard,
                         int dim_ind) -> std::vector<Shard> {
    int original_dim_size = shard_vec.num_shards_per_dim[dim_ind];
    int new_dim_size = new_num_shards_per_dim[dim_ind];
    int size_ratio = new_dim_size / original_dim_size;

    std::vector<Shard> expanded_shards;
    expanded_shards.reserve(size_ratio);
    for (int i = 0; i < size_ratio; ++i) {
      int original_coord = shard[dim_ind];
      int shifted_coord = (original_coord - 1) * size_ratio + 1 + i;
      // Copy original shard, then modify it.
      Shard new_shard = shard;
      new_shard[dim_ind] = shifted_coord;
      expanded_shards.push_back(new_shard);
    }
    return expanded_shards;
  };
  // Iterates over the dimensions of the shard, expanding at each
  // dimension.
  std::vector<Shard> total_expanded_shards = shard_vec.shards;
  for (int dim_ind = 0; dim_ind < new_num_shards_per_dim.size(); ++dim_ind) {
    std::vector<Shard> dim_expanded_shards;
    for (const auto& shard : total_expanded_shards) {
      std::vector<Shard> expanded_shards = ExpandShard(shard, dim_ind);
      // Concatenate newly created shards.
      dim_expanded_shards.insert(dim_expanded_shards.end(),
                                 expanded_shards.begin(),
                                 expanded_shards.end());
    }
    // Copy newly created shards and delete old ones.
    total_expanded_shards = dim_expanded_shards;
  }
  std::sort(total_expanded_shards.begin(), total_expanded_shards.end());
  ShardVector expanded_shard_vec;
  expanded_shard_vec.shards = total_expanded_shards;
  expanded_shard_vec.num_shards_per_dim = new_num_shards_per_dim;
  return expanded_shard_vec;
}
}  // namespace

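// Two ShardVectors compare equal when they describe the same sharded tensor:
// each dimension is first expanded to the least common multiple of the two
// shard counts, then the expanded shard lists are compared. E.g.
// (illustrative) num_shards_per_dim [1,2] and [2,2] are both expanded to
// [2,2] before comparing.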
bool ShardVector::operator==(const ShardVector& other) const {
  // Check same number of shards.
  if (this->shards.empty() && other.shards.empty()) return true;
  if (this->shards.empty() || other.shards.empty()) return false;

  // Check number of shard dimensions match.
  if (this->num_shards_per_dim.size() != other.num_shards_per_dim.size())
    return false;

  // Compute lowest common multiple for each of the shard dimensions.
  Shard first_shard_this = this->shards[0];
  Shard first_shard_other = other.shards[0];
  std::vector<int> new_sizes;
  for (size_t i = 0; i < first_shard_this.size(); ++i) {
    int lcm = this->num_shards_per_dim[i] * other.num_shards_per_dim[i] /
              MathUtil::GCD(static_cast<unsigned>(this->num_shards_per_dim[i]),
                            static_cast<unsigned>(other.num_shards_per_dim[i]));
    new_sizes.push_back(lcm);
  }

  // Expand and compare.
  return ExpandShardVector(*this, new_sizes).shards ==
         ExpandShardVector(other, new_sizes).shards;
}

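// Renders a ShardVector in a compact human-readable form, e.g.
// (illustrative) "shards:[(1,1),(2,1)] num_shards_per_dim:(2,1)".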
std::string ShardVector::ToString() const {
  std::string string = "shards:[";
  // Convert each Shard into string.
  std::vector<std::string> shard_strs;
  shard_strs.reserve(shards.size());
  for (const Shard& shard : shards)
    shard_strs.push_back("(" + absl::StrJoin(shard, ",") + ")");
  // Join shards, and append dimensions.
  absl::StrAppend(&string, absl::StrJoin(shard_strs, ","));
  absl::StrAppend(&string, "] num_shards_per_dim:(");
  absl::StrAppend(&string, absl::StrJoin(num_shards_per_dim, ",") + ")");
  return string;
}

bool ShardVector::ContainsShard(const Shard& shard) const {
  for (const auto& shard_in_vec : shards)
    if (shard_in_vec == shard) return true;
  return false;
}

// static
std::map<std::string, std::vector<int>>& Mesh::tpu_core_ids() {
  static auto tpu_core_ids = new std::map<std::string, std::vector<int>>();
  return *tpu_core_ids;
}

// static
std::string& Mesh::tpu_host_mesh() {
  static auto tpu_host_mesh = new std::string;
  return *tpu_host_mesh;
}

// static
StatusOr<Mesh> Mesh::ParseFromProto(const MeshProto& proto) {
  Mesh mesh;
  mesh.name_ = proto.name();

  for (const auto& device : proto.local_devices()) {
    mesh.local_devices_.push_back(device);
  }

  // Define local device ids.
  for (const auto& device_id : proto.local_device_ids()) {
    mesh.local_device_ids_.push_back(device_id);
  }

  for (const auto& device_id : proto.global_device_ids()) {
    mesh.global_device_ids_.push_back(device_id);
  }

  for (const auto& device : proto.global_devices()) {
    mesh.global_devices_.push_back(device);
  }

  // Assign Mesh Dimensions.
  mesh.mesh_dims_.resize(proto.mesh_dimensions_size());
  for (int i = 0; i < proto.mesh_dimensions_size(); ++i) {
    const MeshDimensionProto& dim = proto.mesh_dimensions(i);
    mesh.mesh_dims_[i].name = dim.name();
    mesh.mesh_dims_[i].size = dim.size();
  }

  // Check invariants.
  int64 mesh_size = mesh.size();
  int num_devices = proto.global_device_ids_size();
  if (mesh_size > 0 && mesh_size != num_devices) {
    TF_RETURN_WITH_CONTEXT(
        errors::InvalidArgument("Number of devices ", num_devices,
                                " not matching mesh size ", mesh_size));
  }
  return mesh;
}

// static
StatusOr<Mesh> Mesh::GetAbstractMesh(
    const std::string& name, const std::vector<MeshDimension>& mesh_dims) {
  Mesh mesh;
  mesh.name_ = name;
  mesh.mesh_dims_ = mesh_dims;

  // Check no repeated mesh dimension names.
  std::set<std::string> dims_set;
  for (const MeshDimension& dim : mesh.dims()) {
    if (dims_set.find(dim.name) != dims_set.end())
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("repeated mesh dimension"));
    if (dim.name == Layout::kAny || dim.name == Layout::kMatch ||
        dim.name == Layout::kUnshardedDim)
      TF_RETURN_WITH_CONTEXT(errors::InvalidArgument("mesh dimension name ",
                                                     dim.name, " is reserved"));
    dims_set.insert(dim.name);
  }

  return mesh;
}

// static
StatusOr<Mesh> Mesh::GetMesh(const std::string& name,
                             const std::vector<MeshDimension>& mesh_dims,
                             const std::vector<std::int64_t>& global_device_ids,
                             const std::vector<std::int64_t>& local_device_ids,
                             const std::vector<std::string>& local_devices,
                             const std::vector<std::string>& global_devices) {
  TF_ASSIGN_OR_RETURN(Mesh mesh, GetAbstractMesh(name, mesh_dims));
  mesh.global_device_ids_ = global_device_ids;
  mesh.local_device_ids_ = local_device_ids;
  mesh.local_devices_ = local_devices;
  mesh.global_devices_ = global_devices;

  // Check number of devices matches conditions.
  size_t global_n = mesh.global_device_ids_.size();
  size_t local_n = mesh.local_device_ids_.size();
  size_t dev_n = mesh.local_devices_.size();

  if (!(global_n >= local_n && dev_n == local_n))
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "number of global_device_ids ", std::to_string(global_n),
        " local_device_ids ", std::to_string(local_n), " and local devices ",
        std::to_string(dev_n), " not meeting requirements"));

  // If empty device list, return empty mesh.
  if (global_n == 0) return Mesh::Empty();

  if (local_n && !(global_n % local_n == 0))
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Uneven local clusters with global_ids ", std::to_string(global_n),
        " and local_device_ids ", std::to_string(local_n)));

  // Check mesh size matches number of devices.
  if (mesh.size() != global_n)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "mesh size doesn't match number of devices"));
286
287 // Check local device invariants.
288 TF_ASSIGN_OR_RETURN(const auto& parsed_devs, mesh.ParsedDevices());
289 std::set<std::string> types_set;
290 for (const DeviceNameUtils::ParsedName& dev : parsed_devs) {
291 if (!dev.has_job || !dev.has_task || !dev.has_type)
292 return errors::InvalidArgument(
293 "Failed to either identify host or device type");
294 types_set.insert(dev.type);
295 if (types_set.size() > 1)
296 return errors::InvalidArgument(
297 "More than one device type per mesh not supported. Found ",
298 types_set.size());
299 }
300
301 return mesh;
302 }
303
dim_size(absl::string_view name) const304 StatusOr<int64_t> Mesh::dim_size(absl::string_view name) const {
305 for (const auto& mesh_dim : dims()) {
306 if (name == mesh_dim.name) {
307 return mesh_dim.size;
308 }
309 }
310
311 std::vector<std::string> dim_names;
312 for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name);
313
  return errors::NotFound(
      "Dimension ", name, " does not exist in mesh. ",
      "Available dimensions: ", absl::StrJoin(dim_names, ","));
}

std::vector<int64_t> Mesh::dim_sizes() const {
  std::vector<int64_t> dim_sizes;
  if (mesh_dims_.empty()) return dim_sizes;
  for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size);
  return dim_sizes;
}

bool Mesh::operator==(const Mesh& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}

bool Mesh::IsEmpty() const { return global_device_ids_.empty(); }

StatusOr<const std::vector<DeviceNameUtils::ParsedName>> Mesh::ParsedDevices()
    const {
  std::vector<DeviceNameUtils::ParsedName> parsed_devices(
      local_devices_.size());
  for (std::size_t i = 0; i < local_devices_.size(); ++i)
    if (!DeviceNameUtils::ParseFullOrLocalName(
            absl::string_view(local_devices_[i]), &parsed_devices[i]))
      return errors::InvalidArgument("Failed to parse local_devices");

  return parsed_devices;
}

StatusOr<Mesh> Mesh::ToDeviceType(const std::string& device_type) const {
  std::vector<std::string> to_local_devices;
  DeviceNameUtils::ParsedName parsed_dev;
  for (const std::string& local_dev : local_devices_) {
    if (!DeviceNameUtils::ParseFullOrLocalName(absl::string_view(local_dev),
                                               &parsed_dev)) {
      return errors::InvalidArgument("Failed to parse local devices");
    }
    // Converted mesh using full task name with job, replica and task ids.
    to_local_devices.push_back(
        DeviceNameUtils::FullName(parsed_dev.job, parsed_dev.replica,
                                  parsed_dev.task, device_type, parsed_dev.id));
    parsed_dev.Clear();
  }
  return GetMesh(name_, mesh_dims_, global_device_ids_, local_device_ids_,
                 to_local_devices, /*global_devices=*/{});
}

namespace {
std::string HostFromParsedDev(const DeviceNameUtils::ParsedName& dev) {
  return "/job:" + dev.job + "/task:" + std::to_string(dev.task);
}
}  // namespace

std::vector<std::string> Mesh::hosts() const {
  std::vector<std::string> host_list;
  if (IsEmpty()) return host_list;

  const auto parsed_devices = ParsedDevices().ValueOrDie();
  for (const DeviceNameUtils::ParsedName& dev : parsed_devices) {
    std::string host = HostFromParsedDev(dev);
    if (std::find(host_list.begin(), host_list.end(), host) == host_list.end())
      host_list.push_back(host);
  }
  return host_list;
}

std::string Mesh::device_type() const {
  if (IsEmpty()) return std::string();
  std::string device;
  if (!global_devices_.empty()) {
    device = global_devices_[0];
  } else {
    device = local_devices_[0];
  }
  DeviceNameUtils::ParsedName dev;
  DeviceNameUtils::ParseFullOrLocalName(device, &dev);
  return dev.type;
}

bool Mesh::IsMeshDim(const std::string& dim_name) const {
  for (const auto& mesh_dim : dims())
    if (dim_name == mesh_dim.name) return true;
  return false;
}

int Mesh::GetMeshDimIndexWithName(const std::string& mesh_name) const {
  int mesh_index = -1;
  for (int i = 0; i < dims().size(); ++i) {
    const auto mesh_dim = dim(i);
    if (mesh_dim.name == mesh_name) mesh_index = i;
  }
  assert(mesh_index >= 0);
  return mesh_index;
}

int64 Mesh::rank() const { return mesh_dims_.size(); }

int64 Mesh::size() const {
  if (mesh_dims_.empty()) return 0;

  int64 size = 1;
  for (const MeshDimension& dim : mesh_dims_) size *= dim.size;
  return size;
}

Mesh Mesh::Empty() { return Mesh(); }

MeshProto Mesh::ToProto() const {
  MeshProto mesh_proto;
  mesh_proto.set_name(name());

  for (const auto& d : local_devices_) {
    mesh_proto.add_local_devices(d);
  }

  for (const auto& i : local_device_ids_) {
    mesh_proto.add_local_device_ids(i);
  }

  for (const auto& i : global_device_ids_) {
    mesh_proto.add_global_device_ids(i);
  }

  for (const auto& dim : mesh_dims_) {
    MeshDimensionProto* mesh_dim_proto = mesh_proto.add_mesh_dimensions();
    mesh_dim_proto->set_name(dim.name);
    mesh_dim_proto->set_size(dim.size);
  }

  for (const auto& d : global_devices_) {
    mesh_proto.add_global_devices(d);
  }
  return mesh_proto;
}

std::string Mesh::ToString() const {
  if (Mesh::IsEmpty()) return kEmptyMeshString;

  // We use "|" to separate name, mesh dimensions and devices.
  std::string mesh_str = absl::StrCat(Mesh::name(), "|");

  // Add mesh dimensions
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  // Add flattened list of local device ids
  mesh_str += absl::StrJoin(local_device_ids_, ",") + "|";

  // Add flattened list of local devices
  mesh_str += absl::StrJoin(local_devices_, ",");

  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  return mesh_str;
}

uint64 Mesh::GlobalFingerprint() const {
  if (Mesh::IsEmpty()) return Fingerprint64(kEmptyMeshString);

  std::string mesh_str;
  // Add mesh dimensions
  absl::InlinedVector<std::string, 4> mesh_dim_lst;
  for (const auto& dim : mesh_dims_)
    mesh_dim_lst.push_back(absl::StrCat(dim.name, "=", dim.size));
  mesh_str += absl::StrJoin(mesh_dim_lst, ",") + "|";

  // Ignore local_device_ids_, local_devices and name, which might not be
  // globally unique.
  // Add flattened list of global device ids
  mesh_str += absl::StrJoin(global_device_ids_, ",") + "|";

  if (!global_devices_.empty()) {
    // Add flattened list of global devices
    mesh_str += "|";
    mesh_str += absl::StrJoin(global_devices_, ",");
  }
  // mesh dims | global device ids (| global devices)
  return Fingerprint64(mesh_str);
}

namespace {
MeshDimension StrToMeshDimension(const std::string& str) {
  MeshDimension mesh_dim;
  if (str.empty()) return mesh_dim;

  std::vector<std::string> mesh_dim_parts = absl::StrSplit(str, '=');

  mesh_dim.name = mesh_dim_parts[0];
  mesh_dim.size = std::stoi(mesh_dim_parts[1]);
  return mesh_dim;
}

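// Generates localhost devices for a test mesh from an instruction of the form
// "<prefix>*<device_type>". E.g. (illustrative) the instruction "*CPU" on a
// mesh with dims [x=2, y=2] produces the four devices
// /job:localhost/task:0/device:CPU:0 through CPU:3.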
StatusOr<Mesh> GenerateMeshDevicesForTests(
    const std::string& name, const std::vector<MeshDimension>& mesh_dims,
    const std::string& mesh_gen_instruction) {
  // Parse mesh generation instruction.
  std::vector<std::string> instruction_parts =
      absl::StrSplit(mesh_gen_instruction, '*');
  if (instruction_parts.size() != 2)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected a * in mesh_gen_instructions but found ",
        mesh_gen_instruction));
  std::string device_type = instruction_parts[1];

  // Get Mesh Size.
  int64 mesh_size = 0;
  if (!mesh_dims.empty()) {
    mesh_size = 1;
    for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size;
  }

  // Generate device ids.
  std::vector<int64_t> global_device_ids;
  std::vector<int64_t> local_device_ids;
  std::vector<std::string> local_devices;
  for (std::size_t i = 0; i < mesh_size; ++i) {
    global_device_ids.push_back(i);
    local_device_ids.push_back(i);
    local_devices.push_back("/job:localhost/task:0/device:" + device_type +
                            ":" + std::to_string(i));
  }

  TF_ASSIGN_OR_RETURN(
      Mesh mesh,
      Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
                    local_devices, /*global_devices=*/{}));
  return mesh;
}
}  // namespace

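// Parses the serialized form produced by Mesh::ToString above, e.g.
// (illustrative):
//   "mesh|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,..."
// i.e. name|mesh dims|global device ids|local device ids|local devices
// (|global devices), or name|mesh dims|*<device_type> for test meshes.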
// static
StatusOr<Mesh> Mesh::FromString(const std::string& str) {
  if (str == kEmptyMeshString) return Mesh::Empty();

  std::vector<std::string> mesh_parts = absl::StrSplit(str, '|');

  // Check formatting error.
  if (mesh_parts.size() != 3 && mesh_parts.size() != 5 &&
      mesh_parts.size() != 6)
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected either 3, 5 or 6 mesh parts but found ", mesh_parts.size()));

  // Populate mesh.
  std::string name = mesh_parts[0];

  // Add mesh dimensions.
  std::vector<MeshDimension> mesh_dims;
  if (!mesh_parts[1].empty()) {
    std::vector<std::string> mesh_dim_strs = absl::StrSplit(mesh_parts[1], ',');
    mesh_dims.reserve(mesh_dim_strs.size());
    for (const std::string& mesh_dim_str : mesh_dim_strs)
      mesh_dims.push_back(StrToMeshDimension(mesh_dim_str));
  }

  // Check if mesh is set to be autogenerated.
  if (mesh_parts.size() == 3)
    return GenerateMeshDevicesForTests(name, mesh_dims, mesh_parts[2]);

  // Add global device ids list.
  std::vector<int64_t> global_device_ids;
  if (!mesh_parts[2].empty()) {
    std::vector<std::string> global_device_ids_strs =
        absl::StrSplit(mesh_parts[2], ',');

    global_device_ids.reserve(global_device_ids_strs.size());
    for (const std::string& id : global_device_ids_strs)
      global_device_ids.push_back(std::stoi(id));
  }

  // Add local device ids list.
  std::vector<int64_t> local_device_ids;
  if (!mesh_parts[3].empty()) {
    std::vector<std::string> local_device_ids_strs =
        absl::StrSplit(mesh_parts[3], ',');

    local_device_ids.reserve(local_device_ids_strs.size());
    for (const std::string& id : local_device_ids_strs)
      local_device_ids.push_back(std::stoi(id));
  }
  // Add local devices.
  std::vector<std::string> local_devices;
  if (!mesh_parts[4].empty())
    local_devices = absl::StrSplit(mesh_parts[4], ',');

  std::vector<std::string> global_devices;
  if (mesh_parts.size() == 6) {
    // Add global devices.
    if (!mesh_parts[5].empty())
      global_devices = absl::StrSplit(mesh_parts[5], ',');
  }

  TF_ASSIGN_OR_RETURN(
      Mesh mesh,
      Mesh::GetMesh(name, mesh_dims, global_device_ids, local_device_ids,
                    local_devices, global_devices));
  return mesh;
}

int64 Mesh::num_devices() const { return global_device_ids_.size(); }

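// Unflattens a device offset into per-dimension coordinates, iterating from
// the minor-most (last) mesh dimension. E.g. (illustrative) on a mesh
// [x=2, y=3], offset 4 maps to [1, 1] since 4 = 1 * 3 + 1.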
StatusOr<const DeviceLocation> Mesh::device_location(int offset) const {
  if (offset < 0 || offset > size() - 1)
    return errors::InvalidArgument(
        "Mesh offset cannot be negative or exceed Mesh's size. Offset size:",
        offset, " and Mesh size:", size());

  DeviceLocation dev_loc;
  std::vector<int64> mesh_dim_lengths = dim_sizes();
  int64 i = mesh_dim_lengths.size() - 1;
  while (i >= 0) {
    dev_loc.insert(dev_loc.begin(), offset % mesh_dim_lengths[i]);
    offset /= mesh_dim_lengths[i];
    --i;
  }
  return dev_loc;
}

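// Inverse of device_location: flattens per-dimension coordinates back into a
// device offset. E.g. (illustrative) on a mesh [x=2, y=3], location [1, 1]
// maps back to offset 1 * 3 + 1 = 4.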
int64 Mesh::GetFlattenedCoordinate(const DeviceLocation& loc) const {
  const std::vector<int64> mesh_dim_sizes = dim_sizes();
  int64 i = mesh_dim_sizes.size() - 1;
  int64 acc = 1;
  int64 device_pos = 0;
  while (i >= 0) {
    device_pos += loc[i] * acc;
    acc *= mesh_dim_sizes[i];
    --i;
  }
  return device_pos;
}

StatusOr<int32> Mesh::idx_for_dim(absl::string_view dim_name) const {
  for (int i = 0; i < mesh_dims_.size(); ++i) {
    if (mesh_dims_[i].name == dim_name) return i;
  }
  return errors::InvalidArgument("dim name :", dim_name,
                                 " does not exist on mesh : ", ToString());
}

StatusOr<Layout> Layout::GetLayout(
    const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh) {
  // Re-format sharding specs.
  std::vector<ShardingSpec> sharding_specs;
  sharding_specs.reserve(sharding_spec_strs.size());
  for (const std::string& spec_str : sharding_spec_strs) {
    ShardingSpec spec;
    spec.set_sharding_spec(spec_str);
    sharding_specs.push_back(spec);
  }
  return GetLayout(sharding_specs, mesh);
}

StatusOr<Layout> Layout::GetLayout(
    const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh) {
  Layout layout;
  // Append mesh, then check sharding_specs are legal.
  layout.mesh_ = mesh;

  // Check sharding_specs are either mesh dimension or special value.
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    if (!(sharding_spec == kUnshardedDim || sharding_spec == kAny ||
          sharding_spec == kMatch || mesh.IsMeshDim(sharding_spec) ||
          sharding_spec == "scalar"))
      TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
          "sharding spec (", sharding_spec,
          ") refers to mesh dimension not contained in mesh ",
          mesh.ToString()));
  }
  // Check same tensor dimensions not sharded over same mesh dimension twice.
  std::set<std::string> dims_set;
  for (const auto& dim : sharding_specs) {
    const std::string& sharding_spec = dim.sharding_spec();
    if (sharding_spec == kUnshardedDim || sharding_spec == kAny) continue;
    // If scalar, delete all sharding specs.
    if (sharding_spec == "scalar") {
      if (sharding_specs.size() > 1)
        TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
            "A scalar sharding_spec can only be used as a single sharding_spec "
            "instruction, not as part of list of sharding_specs as attempted "
            "here with ",
            sharding_specs.size(), " sharding_specs"))
      // Return layout with empty spec to represent scalar behavior.
      return layout;
    }
    if (dims_set.find(sharding_spec) != dims_set.end())
      TF_RETURN_WITH_CONTEXT(
          errors::InvalidArgument("Attempted to shard two or more tensor "
                                  "dimensions over mesh dimension ",
                                  sharding_spec))
    dims_set.insert(sharding_spec);
  }
  // After checking sharding_specs are legal, append and return layout.
  layout.sharding_specs_ = sharding_specs;
  return layout;
}

Layout Layout::Empty() {
  Layout result;
  return result;
}

bool Layout::IsEmpty() const { return mesh_.IsEmpty(); }

namespace {
Mesh ReducedAbstractMesh(const Layout* layout) {
  const std::vector<std::string>& shard_spec_strs =
      layout->sharding_spec_strs();
  std::vector<MeshDimension> reduced_mesh_dims;
  reduced_mesh_dims.reserve(layout->mesh().dims().size());
  for (const MeshDimension& mesh_dim : layout->mesh().dims()) {
    bool IsMeshDimInShardingSpecs =
        std::find(shard_spec_strs.begin(), shard_spec_strs.end(),
                  mesh_dim.name) != shard_spec_strs.end();
    // If dimension not in sharding_spec, flip size to 1.
    MeshDimension reduced_dim =
        IsMeshDimInShardingSpecs ? mesh_dim : MeshDimension(mesh_dim.name, 1);
    reduced_mesh_dims.push_back(reduced_dim);
  }
  return Mesh::GetAbstractMesh("", reduced_mesh_dims).ValueOrDie();
}

}  // namespace

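// Shrinks the mesh to the dimensions actually used by this layout, e.g.
// (illustrative) layout "sharding_specs:x," on mesh [x=2, y=2] reduces to a
// mesh with dims [x=2, y=1], keeping one representative device per shard.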
Mesh Layout::ReducedMesh() const {
  // Set replicated mesh dimensions to size 1, and create reduced abstract
  // mesh.
  Mesh reduced_mesh = ReducedAbstractMesh(this);

  // Populate reduced mesh with global devices from original mesh.
  std::vector<int64_t> reduced_global_device_ids;
  std::vector<std::string> reduced_global_devs;
  for (const DeviceLocation& loc : ComputeDeviceLocations(&reduced_mesh)) {
    int64 pos = mesh().GetFlattenedCoordinate(loc);
    reduced_global_device_ids.push_back(mesh().global_device_ids().at(pos));
    if (!mesh().global_devices().empty()) {
      reduced_global_devs.push_back(mesh().global_devices().at(pos));
    }
  }

  // Track the set of global device IDs in the abstract mesh.
  std::set<int64_t> reduced_global_device_ids_set(
      reduced_global_device_ids.begin(), reduced_global_device_ids.end());

  // Populate reduced mesh with local devices in the same order as the original
  // mesh.
  std::vector<int64_t> reduced_local_device_ids;
  std::vector<std::string> reduced_local_devs;
  for (size_t i = 0; i < mesh().local_device_ids().size(); ++i) {
    int64_t device_id = mesh().local_device_ids().at(i);
    if (reduced_global_device_ids_set.find(device_id) !=
        reduced_global_device_ids_set.end()) {
      reduced_local_device_ids.push_back(device_id);
      reduced_local_devs.push_back(mesh().local_devices().at(i));
    }
  }

  return Mesh::GetMesh(reduced_mesh.name(), reduced_mesh.dims(),
                       reduced_global_device_ids, reduced_local_device_ids,
                       reduced_local_devs, reduced_global_devs)
      .ValueOrDie();
}

namespace {
Layout ReducedLayout(const Layout* layout) {
  // Copy the layout's sharding specs.
  std::vector<ShardingSpec> shard_specs(layout->sharding_specs().size());
  for (size_t i = 0; i < shard_specs.size(); ++i)
    shard_specs[i] = layout->dim(i);
  // Retrieve layout.
  return Layout::GetLayout(shard_specs, layout->ReducedMesh()).ValueOrDie();
}

// Returns the index of the given mesh dimension, or an error if not found.
StatusOr<int> IndexOfMeshDimension(const Mesh& mesh,
                                   const std::string& dim_name) {
  for (size_t i = 0; i < mesh.dims().size(); ++i)
    if (dim_name == mesh.dims()[i].name) return i;
  return errors::InvalidArgument("Mesh dimension not found");
}
}  // namespace

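// Maps every device in the mesh to the (1-based) shard of the tensor it
// holds. E.g. (illustrative) for layout "sharding_specs:x,unsharded," on mesh
// [x=2, y=2], the four devices yield shards [(1,1), (1,1), (2,1), (2,1)] with
// num_shards_per_dim (2,1).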
ShardVector Layout::GetShardVector() const {
  // Copy the layout's sharding specs.
  std::vector<ShardingSpec> shard_specs(sharding_specs().size());
  for (size_t i = 0; i < shard_specs.size(); ++i) shard_specs[i] = dim(i);
  // Obtain a shard position (i.e. sharded section of a tensor) from a mesh
  // location, using the sharding specs.
  auto GetShardFromDeviceLocation = [&](const DeviceLocation& loc) -> Shard {
    Shard shard;
    for (size_t i = 0; i < shard_specs.size(); ++i) {
      // If unsharded, there is only one shard, that is 1.
      std::string spec = shard_specs[i].sharding_spec();
      if (spec == Layout::kUnshardedDim) {
        shard.push_back(1);
      } else {
        int mesh_index =
            IndexOfMeshDimension(mesh(), sharding_spec(i)).ValueOrDie();
        int shard_number = loc[mesh_index] + 1;
        shard.push_back(shard_number);
      }
    }
    return shard;
  };
  // Obtain dims of shard vector.
  auto ShardVectorDims = [&]() -> std::vector<int> {
    std::vector<int> num_shards_per_dim(shard_specs.size());
    for (size_t i = 0; i < sharding_specs().size(); ++i) {
      ShardingSpec spec = sharding_specs()[i];
      if (Layout::IsShardedSpec(spec)) {
        StatusOr<int64> dim_size = mesh().dim_size(spec.sharding_spec());
        num_shards_per_dim[i] = dim_size.ValueOrDie();
      } else {
        num_shards_per_dim[i] = 1;
      }
    }
    return num_shards_per_dim;
  };
  // Compute mesh locations and obtain shards from them.
  ShardVector shard_vec;
  for (const DeviceLocation& mesh_loc : ComputeDeviceLocations(&mesh()))
    shard_vec.shards.push_back(GetShardFromDeviceLocation(mesh_loc));
  // Calculate dims.
  shard_vec.num_shards_per_dim = ShardVectorDims();
  return shard_vec;
}

std::map<std::string, ShardVector> Layout::HostShardMap() const {
  Layout reduced_layout = ReducedLayout(this);
  Mesh reduced_mesh = reduced_layout.mesh();
  using HostName = std::string;

  // Build a map: {Host : Shards}
  std::map<HostName, ShardVector> host_shards_map;
  ShardVector shard_vec_in_red_layout = reduced_layout.GetShardVector();

  const auto parsed_devs = reduced_mesh.ParsedDevices().ValueOrDie();
  for (size_t i = 0; i < parsed_devs.size(); ++i) {
    HostName host = HostFromParsedDev(parsed_devs[i]);
    Shard shard_in_device = shard_vec_in_red_layout.shards[i];

    // Check if host in hashtable and append shard.
    auto it = host_shards_map.find(host);
    if (it == host_shards_map.end()) {
      ShardVector shard_vec_in_host;
      shard_vec_in_host.shards.push_back(shard_in_device);
      shard_vec_in_host.num_shards_per_dim =
          shard_vec_in_red_layout.num_shards_per_dim;
      host_shards_map.insert(
          std::pair<HostName, ShardVector>(host, shard_vec_in_host));
    } else {
      bool isShardInShardVector = it->second.ContainsShard(shard_in_device);
      if (!isShardInShardVector) {
        it->second.shards.push_back(shard_in_device);
      }
    }
  }
  // Sort shards inside each host.
  for (auto it = host_shards_map.begin(); it != host_shards_map.end(); ++it) {
    std::sort(it->second.shards.begin(), it->second.shards.end());
  }
  return host_shards_map;
}

const std::string& Layout::sharding_spec(int idx) const {
  return sharding_specs_[idx].sharding_spec();
}

std::vector<int32> Layout::num_shards() const {
  std::vector<int32> num_shards;
  num_shards.reserve(sharding_specs_.size());
  for (const auto& sharding_spec : sharding_specs_) {
    num_shards.push_back(num_shards_for_dim(sharding_spec));
  }
  return num_shards;
}

size_t Layout::num_shards_for_dim(const ShardingSpec& dim) const {
  absl::string_view name = dim.sharding_spec();
  if (name == Layout::kUnshardedDim) return 1;
  if (name == Layout::kMatch) return -1;

  return mesh().dim_size(name).ValueOrDie();
}

bool Layout::IsFullyReplicated() const {
  for (const auto& sharding_spec : sharding_specs_) {
    if (num_shards_for_dim(sharding_spec) > 1) {
      return false;
    }
  }
  return true;
}

bool Layout::IsLastDimReplicated() const {
  return (sharding_specs_.empty()) ||
         (num_shards_for_dim(sharding_specs_.back()) == 1);
}

bool Layout::IsBatchParallel() const {
  if (sharding_specs_.empty()) {
    return true;
  }

  for (int i = 1; i < sharding_specs_.size(); ++i) {
    const auto& dim = sharding_specs_[i];
    if (num_shards_for_dim(dim) != 1) {
      return false;
    }
  }
  return true;
}

// TODO(samuelslee) Replace this with the IsBatchParallel() everywhere
bool Layout::IsBatchParallel(int non_batch_rank) const {
  if (sharding_specs_.empty()) return true;
  for (int i = rank() - non_batch_rank; i < rank(); ++i) {
    if (num_shards_for_dim(sharding_specs_[i]) != 1) return false;
  }
  return true;
}

LayoutProto Layout::ToProto() const {
  LayoutProto proto;
  *proto.mutable_mesh_config() = mesh_.ToProto();
  for (const auto& dim : sharding_specs_) {
    *proto.add_sharding_specs() = dim;
  }
  return proto;
}

bool Layout::IsEquivalent(const Layout& b) const {
  if (this->rank() != b.rank()) return false;
  if (this->mesh() != b.mesh()) return false;
  for (int i = 0; i < this->rank(); ++i) {
    if (this->sharding_specs_[i].sharding_spec() !=
        b.sharding_specs_[i].sharding_spec()) {
      if ((this->num_shards_for_dim(this->sharding_specs_[i]) != 1) ||
          (b.num_shards_for_dim(b.sharding_specs_[i]) != 1))
        return false;
    }
  }
  return true;
}

bool Layout::operator==(const Layout& b) const {
  return protobuf::util::MessageDifferencer::Equals(ToProto(), b.ToProto());
}

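// Scales each local dimension by its shard count, e.g. (illustrative) a local
// shape [2, 4] under num_shards() == [2, 1] becomes the global shape [4, 4].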
std::vector<int64_t> Layout::GlobalShapeFromLocalShape(
    const std::vector<int64_t>& local_shape) const {
  if (IsFullyReplicated()) {
    return local_shape;
  }
  std::vector<int64_t> global_shape;
  global_shape.reserve(sharding_specs().size());
  for (int i = 0; i < sharding_specs().size(); ++i) {
    int64_t l_shape = local_shape.empty() ? 1 : local_shape[i];
    int64_t dim_shards = num_shards()[i];
    global_shape.emplace_back(l_shape * dim_shards);
  }
  return global_shape;
}

std::vector<int64_t> Layout::LocalShapeFromGlobalShape(
    absl::Span<const int64_t> global_shape) const {
  if (IsFullyReplicated()) {
    return std::vector<int64_t>(global_shape.begin(), global_shape.end());
  }
  std::vector<int32> shards = num_shards();
  std::vector<int64_t> local_shape;
  for (int i = 0; i < sharding_specs().size(); ++i) {
    int64_t dim_shards = shards[i];
    // TODO(hthu): Shape might not be always divisible.
    local_shape.emplace_back(global_shape[i] / dim_shards);
  }
  return local_shape;
}

PartialTensorShape Layout::LocalShapeFromGlobalShape(
    const PartialTensorShape& global_shape) const {
  if (IsFullyReplicated() || global_shape.dims() == -1) {
    return global_shape;
  }
  std::vector<int32> shards = num_shards();
  PartialTensorShape local_shape({});
  for (int spec_index = 0; spec_index < sharding_specs().size(); ++spec_index) {
    int64_t dim_size = global_shape.dim_size(spec_index);
    local_shape.AddDim(dim_size == -1 ? -1 : dim_size / shards[spec_index]);
  }
  return local_shape;
}

StatusOr<Layout> Layout::FromProto(const LayoutProto& proto) {
  Layout layout;
  for (const auto& spec : proto.sharding_specs())
    layout.sharding_specs_.push_back(spec);

  TF_ASSIGN_OR_RETURN(auto mesh, Mesh::ParseFromProto(proto.mesh_config()));
  layout.mesh_ = std::move(mesh);

  return GetLayout(layout.sharding_specs_, layout.mesh_);
}

Layout Layout::ReplicatedOnMesh(const Mesh& mesh, int rank) {
  std::vector<std::string> specs(rank, kUnshardedDim);
  return Layout::GetLayout(specs, mesh).ValueOrDie();
}

Layout Layout::AnyOnMesh(const Mesh& mesh, int rank) {
  std::vector<std::string> specs(rank, kAny);
  return Layout::GetLayout(specs, mesh).ValueOrDie();
}

StatusOr<Layout> Layout::Transposed2D(const Layout& layout) {
  if (layout.rank() < 2) {
    return errors::InvalidArgument("Transposed2D requires rank to be >= 2");
  }
  std::vector<std::string> transposed_specs = layout.sharding_spec_strs();
  std::iter_swap(transposed_specs.end() - 2, transposed_specs.end() - 1);
  return Layout::GetLayout(transposed_specs, layout.mesh()).ValueOrDie();
}

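// Parses the serialized form produced by Layout::ToString below, e.g.
// (illustrative):
//   "sharding_specs:x,unsharded, mesh:<serialized mesh>"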
// static
StatusOr<Layout> Layout::FromString(std::string layout_str) {
  if (layout_str == kEmptyLayoutString) return Layout::Empty();

  // Split into sharding spec and mesh parts.
  std::vector<absl::string_view> layout_parts = absl::StrSplit(layout_str, ' ');
  // Check formatting error.
  if (layout_parts.size() != 2) {
    TF_RETURN_WITH_CONTEXT(errors::InvalidArgument(
        "Expected 2 items but found ", layout_parts.size(), ": ",
        layout_parts[0]));
  }
  // Strip prefixes.
  absl::string_view sharding_spec_str = layout_parts[0];
  absl::ConsumePrefix(&sharding_spec_str, "sharding_specs:");

  absl::string_view mesh_str = layout_parts[1];
  absl::ConsumePrefix(&mesh_str, "mesh:");

  // Add sharding specs.
  std::vector<std::string> sharding_spec_strs =
      absl::StrSplit(sharding_spec_str, ',');
  sharding_spec_strs.pop_back();

  // Add mesh.
  TF_ASSIGN_OR_RETURN(Mesh mesh, Mesh::FromString(string(mesh_str)));
  // Try to create layout.
  TF_ASSIGN_OR_RETURN(Layout layout,
                      Layout::GetLayout(sharding_spec_strs, mesh));
  return layout;
}

std::vector<std::string> Layout::sharding_spec_strs() const {
  std::vector<std::string> sharding_spec_strs(sharding_specs().size());
  for (size_t i = 0; i < sharding_specs().size(); ++i)
    sharding_spec_strs[i] = sharding_spec(i);
  return sharding_spec_strs;
}

std::string Layout::ToString() const {
  if (Layout::IsEmpty()) return kEmptyLayoutString;

  std::string layout_str = "sharding_specs:";
  // Print sharding specs.
  for (const ShardingSpec& dim : sharding_specs_) {
    std::string dim_name = dim.sharding_spec();
    absl::StrAppend(&layout_str, dim_name + ",");
  }
  // Append mesh.
  absl::StrAppend(&layout_str, " mesh:", mesh_.ToString());
  return layout_str;
}

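// Drops the sharding specs of the reduced tensor dimensions, e.g.
// (illustrative) reducing dim 1 of "sharding_specs:x,y," yields
// "sharding_specs:x," (keep_dims=false) or "sharding_specs:x,unsharded,"
// (keep_dims=true).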
Layout Layout::GetLayoutWithReducedDims(
    const absl::flat_hash_set<int>& reduced_dims, bool keep_dims) const {
  dtensor::LayoutProto output_layout;
  *output_layout.mutable_mesh_config() = mesh().ToProto();

  for (int i = 0; i < rank(); ++i) {
    // reduced_dims may contain negative values.
    if (!reduced_dims.contains(i) && !reduced_dims.contains(i - rank())) {
      *output_layout.add_sharding_specs() = dim(i);
    } else if (keep_dims) {
      auto* replicated_dim = output_layout.add_sharding_specs();
      replicated_dim->set_sharding_spec(kUnshardedDim);
    }
  }
  return Layout::FromProto(output_layout).ValueOrDie();
}

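// Keeps sharding specs [0, split_point) when end is false, or
// [split_point, rank) when end is true. E.g. (illustrative) truncating
// "sharding_specs:x,y," at split_point=1 with end=true keeps only "y".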
Layout Layout::Truncate(int64 split_point, bool end) const {
  if ((split_point == 0 && end) || (split_point == rank() && !end))
    return *this;

  dtensor::LayoutProto output_layout;
  *output_layout.mutable_mesh_config() = mesh().ToProto();

  if (end) {
    for (int i = split_point; i < rank(); ++i)
      *output_layout.add_sharding_specs() = dim(i);
  } else {
    for (int i = 0; i < split_point; ++i)
      *output_layout.add_sharding_specs() = dim(i);
  }
  return Layout::FromProto(output_layout).ValueOrDie();
}

namespace {
// Adds unsharded sharding specs to layout.
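// E.g. (illustrative) padding "sharding_specs:x," to rank 3 with
// is_padding_before=true yields "sharding_specs:unsharded,unsharded,x,".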
Layout PadLayout(const int64 rank, const bool is_padding_before,
                 const Layout& layout) {
  if (rank <= layout.rank()) return layout;

  // Create list of padding sharding specs.
  const int n = rank - layout.rank();
  std::vector<ShardingSpec> new_specs(n);
  for (int i = 0; i < n; ++i)
    new_specs[i].set_sharding_spec(Layout::kUnshardedDim);

  // Define concatenation point of layout specs.
  auto concat_point = is_padding_before ? new_specs.end() : new_specs.begin();

  // Concatenate old layout specs and new unsharded specs.
  new_specs.insert(concat_point, layout.sharding_specs().begin(),
                   layout.sharding_specs().end());
  return Layout::GetLayout(new_specs, layout.mesh()).ValueOrDie();
}
}  // namespace

Layout Layout::LeftPad(int64 rank) const {
  bool is_padding_before = true;
  return PadLayout(rank, is_padding_before, *this);
}

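// Appends layout_b's sharding specs after layout_a's, e.g. (illustrative)
// concatenating "sharding_specs:x," and "sharding_specs:y," on the same mesh
// gives "sharding_specs:x,y,". Fails if the layouts live on different meshes
// or shard over the same mesh dimension.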
StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b) {
  if (layout_a.mesh() != layout_b.mesh())
    return errors::InvalidArgument(
        "unable to concatenate layouts as they are on different meshes.");

  absl::flat_hash_set<std::string> layout_a_mesh_dims;
  for (int i = 0; i < layout_a.rank(); ++i)
    if (layout_a.sharding_spec(i) != Layout::kUnshardedDim)
      layout_a_mesh_dims.emplace(layout_a.sharding_spec(i));

  for (int i = 0; i < layout_b.rank(); ++i)
    if (layout_b.sharding_spec(i) != Layout::kUnshardedDim &&
        layout_a_mesh_dims.contains(layout_b.sharding_spec(i)))
      return errors::InvalidArgument(
          "unable to concatenate layouts as they use the same mesh "
          "dimension: ",
          layout_b.sharding_spec(i), " is used in both layouts.");

  LayoutProto layout_proto_a = layout_a.ToProto();
  LayoutProto layout_proto_b = layout_b.ToProto();
  LayoutProto output_layout_proto;

  *output_layout_proto.mutable_mesh_config() = layout_proto_a.mesh_config();
  for (int i = 0; i < layout_proto_a.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_a.sharding_specs(i);
  for (int i = 0; i < layout_proto_b.sharding_specs_size(); ++i)
    *output_layout_proto.add_sharding_specs() =
        layout_proto_b.sharding_specs(i);
  return Layout::FromProto(output_layout_proto);
}

}  // namespace dtensor
}  // namespace tensorflow