xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/tensor_layout.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
17 #define TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
18 
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
36 
37 // Definitions for DTensor mesh & layout.
38 //
39 // A mesh describes how a set of devices is partitioned.
40 // A layout describes how a distributed tensor is partitioned across a mesh (and
41 // thus across devices). Defining tensor layouts in terms of mesh dimensions
42 // allows us to efficiently determine the communication required when computing
43 // an operation with tensors of different layouts.
44 namespace tensorflow {
45 namespace dtensor {
46 
47 // The location of a device in a mesh.
48 //
49 // Each device has a unique location in the mesh, which is indicated by the
50 // offset in each mesh dimension. e.g. a mesh:
51 //
52 // [x:4, y:3, z:2]
53 //
54 // Must consist of 24 devices placed densely into the corresponding 3D space.
55 using DeviceLocation = absl::InlinedVector<int64, 4>;
56 
// A Shard identifies one partition of a tensor. It holds one element per
// tensor dimension, so its length equals the tensor rank.
//
// Shards are grouped into a ShardVector, which stores the list of shards
// together with the number of shards along each dimension.
//
// Example: the layout `sharding_specs:x,y, mesh:|x=2,y=2|` is described by
// the ShardVector
//          - shards = (1,1), (1,2), (2,1), (2,2)
//          - num_shards_per_dim = (2,2).
using Shard = std::vector<int>;
68 
69 struct ShardVector {
70   bool operator==(const ShardVector& other) const;
71   bool operator!=(const ShardVector& other) const { return !(*this == other); }
72   std::string ToString() const;
73 
74   bool ContainsShard(const Shard& shard) const;
75 
76   std::vector<Shard> shards;
77   std::vector<int> num_shards_per_dim;
78 };
79 
80 struct MeshDimension {
MeshDimensionMeshDimension81   MeshDimension(const std::string& name, int64 size)
82       : name(std::move(name)), size(size) {}
83   MeshDimension() = default;
84 
85   std::string name;
86   int64 size;
87 };
88 
89 class Mesh {
90  public:
91   // Failed serialized strings are represented with en empty string, therefore
92   // we use this string representation of an empty mesh instead to avoid
93   // confusion.
94   static constexpr const char* kEmptyMeshString = "empty_mesh";
95   static Mesh Empty();
96   bool IsEmpty() const;
97   Mesh() = default;
98 
99   // Parses from MeshProto.
100   static StatusOr<Mesh> ParseFromProto(const MeshProto& proto);
101   // Parses from a human readable string version of the mesh, currently used
102   // to represent meshes in MLIR:
103   //  mesh = <name|List[MeshDim]|List[GlobalId]|List[LocalId]|List[Devices]>
104   //
105   // Example:
106   //  mesh =
107   //  <name|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
108   static StatusOr<Mesh> FromString(const std::string& str);
109   std::string ToString() const;
110   MeshProto ToProto() const;
111 
112   // Creates mesh without specific devices associated to it (aka abstract mesh).
113   // This is an experimental API. Use only if strictly needed.
114   static StatusOr<Mesh> GetAbstractMesh(
115       const std::string& name, const std::vector<MeshDimension>& mesh_dims);
116   // Creates fully defined mesh.
117   static StatusOr<Mesh> GetMesh(
118       const std::string& name, const std::vector<MeshDimension>& mesh_dims,
119       const std::vector<std::int64_t>& global_device_ids,
120       const std::vector<std::int64_t>& local_device_ids,
121       const std::vector<std::string>& local_devices,
122       const std::vector<std::string>& global_devices);
123 
is_cpu_mesh()124   bool is_cpu_mesh() const { return device_type() == "CPU"; }
is_epu_mesh()125   bool is_epu_mesh() const { return device_type() == "EPU"; }
is_tpu_mesh()126   bool is_tpu_mesh() const { return device_type() == "TPU"; }
127   // Returns whether the mesh is a remote mesh.
is_remote()128   bool is_remote() const {
129     return local_device_ids_.empty() && !global_device_ids_.empty();
130   }
131 
132   // Device information methods.
133   std::string device_type() const;
134   // Takes an index in the flattened list of devices and returns a location
135   // in the mesh.
136   StatusOr<const DeviceLocation> device_location(int offset) const;
137   int64 num_devices() const;
local_devices()138   absl::Span<const std::string> local_devices() const { return local_devices_; }
local_device_ids()139   absl::Span<const int64_t> local_device_ids() const {
140     return local_device_ids_;
141   }
142   // Parses names of local_devices according to TF's Device Name Utils.
143   StatusOr<const std::vector<DeviceNameUtils::ParsedName>> ParsedDevices()
144       const;
145   // Convert to given device type.
146   StatusOr<Mesh> ToDeviceType(const std::string& device_type) const;
147   std::vector<std::string> hosts() const;
148 
149   // Consumes a location in the mesh and returns its corresponding index in
150   // the flattened list of devices.
151   int64 GetFlattenedCoordinate(const DeviceLocation& loc) const;
152 
dim(int64 index)153   const MeshDimension& dim(int64 index) const { return mesh_dims_[index]; }
dims()154   std::vector<MeshDimension> dims() const { return mesh_dims_; }
155   // Returns size of mesh dimension.
156   StatusOr<int64> dim_size(absl::string_view name) const;
157   // Returns list of mesh dimension sizes.
158   std::vector<int64> dim_sizes() const;
dim_name(int64 index)159   const std::string& dim_name(int64 index) const {
160     return mesh_dims_[index].name;
161   }
min_global_device_id()162   int64_t min_global_device_id() const {
163     DCHECK(!global_device_ids_.empty());
164     return *std::min_element(global_device_ids_.begin(),
165                              global_device_ids_.end());
166   }
167 
global_device_ids()168   absl::Span<const int64_t> global_device_ids() const {
169     return global_device_ids_;
170   }
171 
global_devices()172   const std::vector<std::string>& global_devices() const {
173     return global_devices_;
174   }
175   // Returns index of given dim_name in the mesh.
176   StatusOr<int32> idx_for_dim(absl::string_view dim_name) const;
177 
178   // Returns the index of MeshDimension in mesh where the mesh dimension name is
179   // `mesh_name`.
180   int GetMeshDimIndexWithName(const std::string& mesh_name) const;
181   bool IsMeshDim(const std::string& dim_name) const;
182 
183   int64 rank() const;
184   int64 size() const;
name()185   const std::string& name() const { return name_; }
186 
187   // Global unique fingerprint. Same on different workers.
188   uint64 GlobalFingerprint() const;
189 
190   bool operator==(const Mesh& b) const;
191   bool operator!=(const Mesh& b) const { return !((*this) == b); }
192   bool operator<(const Mesh& b) const {
193     return this->ToString() < b.ToString();
194   }
195 
196   template <typename H>
AbslHashValue(H h,const Mesh & m)197   friend H AbslHashValue(H h, const Mesh& m) {
198     return H::combine(std::move(h), m.ToString());
199   }
200 
201   // A map from mesh names to their corresponding core ID mappings. The core ID
202   // mapping is stored as a vector. The i-th element in the vector is the ID of
203   // the core represented by global device ID of i in this mesh.
204   //
205   // The entry stored under the empty name key (the so-called "default mapping"
206   // in some comments) is special. It is always set at the end of TPU
207   // initialization. It represents the mapping for any mesh whose global device
208   // IDs follow TF task-device ordinals. Legacy and test meshes created without
209   // using the `create_tpu_mesh` helper follow that rule and can use this entry.
210   static std::map<std::string, std::vector<int>>& tpu_core_ids();
211 
212   // The host mesh associated with any user-defined TPU mesh.
213   static std::string& tpu_host_mesh();
214 
215  private:
216   std::string name_;
217   std::vector<MeshDimension> mesh_dims_;
218   std::vector<std::string> local_devices_;
219   std::vector<int64_t> local_device_ids_;
220   std::vector<int64_t> global_device_ids_;
221   std::vector<std::string> global_devices_;
222 };
223 
224 class Layout {
225  public:
226   static constexpr const char* kUnshardedDim = "unsharded";
227   // This spec should only be used to express no preferred sharding in the
228   // Layout propagation algorithm.
229   static constexpr const char* kAny = "any";
230   // Failed serialized strings are represented with en empty string, therefore
231   // we use this string representation of an empty layout instead to avoid
232   // confusion.
233   static constexpr const char* kEmptyLayoutString = "empty_layout";
234   // Used for the relayout operation, to allow relayout act as an identity on
235   // the layout for the given dimension.
236   static constexpr const char* kMatch = "match";
237 
238   // Returns empty layout.
239   static Layout Empty();
240 
241   // Parses from LayoutProto.
242   static StatusOr<Layout> FromProto(const LayoutProto& proto);
243   // Parses from a human readable string version of the layout, currently used
244   // to represent layouts in MLIR:
245   //  layout = <sharding_specs:List[specs] mesh:name|List[MeshDim]|
246   //  List[GlobalId]|List[LocalId]|List[Devices]>
247   //
248   // Example:
249   //  layout = <sharding_specs:x,not_sharded mesh:name|x=2,y=2|0,1,2,3|0,1,2,3|
250   //  /job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,
251   //  /job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
252   static StatusOr<Layout> FromString(std::string layout_str);
253   // Creates human readable string version of a layout.
254   std::string ToString() const;
255   LayoutProto ToProto() const;
256 
mesh()257   const Mesh& mesh() const { return mesh_; }
258   static Layout ReplicatedOnMesh(const Mesh& mesh, int rank);
259   static Layout AnyOnMesh(const Mesh& mesh, int rank);
260   // Creates a mesh of unique shards.
261   Mesh ReducedMesh() const;
set_mesh(Mesh mesh)262   void set_mesh(Mesh mesh) { mesh_ = mesh; }
263 
264   // Returns a layout for the transposed matrix for given layout. This assumes
265   // that only the last two dimensions are used for matrix computation and all
266   // dimensions before are batch dimensions.
267   static StatusOr<Layout> Transposed2D(const Layout& layout);
IsUnshardedDimension(const absl::string_view name)268   static bool IsUnshardedDimension(const absl::string_view name) {
269     return name == kUnshardedDim;
270   }
IsShardedDimension(const absl::string_view name)271   static bool IsShardedDimension(const absl::string_view name) {
272     return !IsUnshardedDimension(name);
273   }
IsUnshardedSpec(const ShardingSpec & spec)274   static bool IsUnshardedSpec(const ShardingSpec& spec) {
275     return IsUnshardedDimension(spec.sharding_spec());
276   }
IsShardedSpec(const ShardingSpec & spec)277   static bool IsShardedSpec(const ShardingSpec& spec) {
278     return !IsUnshardedDimension(spec.sharding_spec());
279   }
280   static StatusOr<Layout> GetLayout(
281       const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh);
282   static StatusOr<Layout> GetLayout(
283       const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh);
284 
285   // Makes a new layout from this one dropping the given dimensions.
286   // If keep_dims is true, the dimensions are replicated rather than
287   // deleted.
288   Layout GetLayoutWithReducedDims(const absl::flat_hash_set<int>& reduced_dims,
289                                   bool keep_dims) const;
290 
291   // Truncates a layout at the front or back, depending on the value of end.
292   // end = false returns the layout upto the split point,
293   // end = true returns the layout from the split point.
294   Layout Truncate(int64 split_point, bool end = false) const;
295 
296   // Left or right pad the layout to a max rank.
297   Layout LeftPad(int64 rank) const;
298 
299   bool IsFullyReplicated() const;
300   bool IsLastDimReplicated() const;
301   // Checks that the last N-1 dimensions are replicated
302   bool IsBatchParallel() const;
303   // Checks that the dimensions from [-non_batch_rank, end) are replicaed
304   bool IsBatchParallel(int non_batch_rank) const;
305   bool IsEmpty() const;
306 
307   // Compute global shape using the layout and provided local_shape.
308   std::vector<int64_t> GlobalShapeFromLocalShape(
309       const std::vector<int64_t>& local_shape) const;
310 
311   std::vector<int64_t> LocalShapeFromGlobalShape(
312       absl::Span<const int64_t> global_shape) const;
313   PartialTensorShape LocalShapeFromGlobalShape(
314       const PartialTensorShape& global_shape) const;
315 
rank()316   int64 rank() const { return sharding_specs_.size(); }
317   size_t num_shards_for_dim(const ShardingSpec& dim) const;
318   std::vector<int32> num_shards() const;
319 
dim(int64 idx)320   const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; }
sharding_specs()321   absl::Span<const ShardingSpec> sharding_specs() const {
322     return sharding_specs_;
323   }
324 
325   // Computes the corresponding shard vector to this layout.
326   ShardVector GetShardVector() const;
327 
328   // Returns sharding specs in string form.
329   std::vector<std::string> sharding_spec_strs() const;
330 
num_devices()331   int64 num_devices() const { return mesh_.num_devices(); }
device_location(int64 device_id)332   StatusOr<const DeviceLocation> device_location(int64 device_id) const {
333     return mesh_.device_location(device_id);
334   }
335   // Map hosts to shards.
336   std::map<std::string, ShardVector> HostShardMap() const;
337 
338   const std::string& sharding_spec(int idx) const;
339 
340   // Two layouts are equivalent if they would result in the same sharding for
341   // the tensor. E.g. if on is unsharded and the other is sharded on a mesh
342   // dimension of size 1.
343   bool IsEquivalent(const Layout& b) const;
344   bool operator==(const Layout& b) const;
345   bool operator!=(const Layout& b) const { return !((*this) == b); }
346   bool operator<(const Layout& b) const {
347     return this->ToString() < b.ToString();
348   }
349 
350  private:
351   std::vector<ShardingSpec> sharding_specs_;
352   Mesh mesh_;
353 };
354 
355 // Takes two layouts and concatenates their TensorDimensions. If the meshes for
356 // the two layouts are different or both layouts are using the same mesh
357 // dimension returns an error rather than a layout.
358 StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
359                                     const Layout& layout_b);
360 
361 }  // namespace dtensor
362 }  // namespace tensorflow
363 
364 #endif  // TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
365