/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_
#define TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/proto/layout.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

// Definitions for DTensor mesh & layout.
//
// A mesh describes how a set of devices is partitioned.
// A layout describes how a distributed tensor is partitioned across a mesh
// (and thus across devices). Defining tensor layouts in terms of mesh
// dimensions allows us to efficiently determine the communication required
// when computing an operation with tensors of different layouts.
namespace tensorflow {
namespace dtensor {

// The location of a device in a mesh.
//
// Each device has a unique location in the mesh, which is indicated by its
// offset in each mesh dimension. e.g. a mesh:
//
// [x:4, y:3, z:2]
//
// must consist of 24 devices placed densely into the corresponding 3D space.
using DeviceLocation = absl::InlinedVector<int64, 4>;

// A shard refers to a partition of a tensor. Shards are arranged in
// ShardVectors, which contain a list of Shards and a list of integers
// representing the number of shards in each dimension.
//
// Example: layout = sharding_specs:x,y, mesh:|x=2,y=2|. This can be
// represented with a ShardVector:
// - shards = (1,1), (1,2), (2,1), (2,2)
// - num_shards_per_dim = (2,2).
//
// The number of elements in each shard matches the tensor rank.
using Shard = std::vector<int>;

struct ShardVector {
  bool operator==(const ShardVector& other) const;
  bool operator!=(const ShardVector& other) const { return !(*this == other); }
  std::string ToString() const;

  bool ContainsShard(const Shard& shard) const;

  std::vector<Shard> shards;
  std::vector<int> num_shards_per_dim;
};
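
// A minimal usage sketch (illustrative only), building the ShardVector for
// the example above, i.e. layout sharding_specs:x,y on mesh |x=2,y=2|:
//
//   ShardVector shard_vec;
//   shard_vec.shards = {{1, 1}, {1, 2}, {2, 1}, {2, 2}};
//   shard_vec.num_shards_per_dim = {2, 2};
//   bool has_shard = shard_vec.ContainsShard({2, 1});  // true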

struct MeshDimension {
  MeshDimension(const std::string& name, int64 size)
      : name(std::move(name)), size(size) {}
  MeshDimension() = default;

  std::string name;
  int64 size;
};

class Mesh {
 public:
  // Failed serialized strings are represented with an empty string; therefore,
  // we use this string representation of an empty mesh instead to avoid
  // confusion.
  static constexpr const char* kEmptyMeshString = "empty_mesh";
  static Mesh Empty();
  bool IsEmpty() const;
  Mesh() = default;

  // Parses from MeshProto.
  static StatusOr<Mesh> ParseFromProto(const MeshProto& proto);
  // Parses from a human-readable string version of the mesh, currently used
  // to represent meshes in MLIR:
  //  mesh = <name|List[MeshDim]|List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  mesh =
  //  <name|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,/job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
  static StatusOr<Mesh> FromString(const std::string& str);
  std::string ToString() const;
  MeshProto ToProto() const;
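
  // Usage sketch (illustrative only): parsing the example mesh string above
  // and converting it back. The mesh name "name" and the CPU device strings
  // match the documented example:
  //
  //   StatusOr<Mesh> mesh = Mesh::FromString(
  //       "<name|x=2,y=2|0,1,2,3|0,1,2,3|"
  //       "/job:localhost/task:0/device:CPU:0,"
  //       "/job:localhost/task:0/device:CPU:1,"
  //       "/job:localhost/task:0/device:CPU:2,"
  //       "/job:localhost/task:0/device:CPU:3>");
  //   if (mesh.ok()) std::string roundtrip = mesh->ToString();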

  // Creates a mesh without specific devices associated with it (aka an
  // abstract mesh). This is an experimental API. Use only if strictly needed.
  static StatusOr<Mesh> GetAbstractMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims);
  // Creates a fully defined mesh.
  static StatusOr<Mesh> GetMesh(
      const std::string& name, const std::vector<MeshDimension>& mesh_dims,
      const std::vector<std::int64_t>& global_device_ids,
      const std::vector<std::int64_t>& local_device_ids,
      const std::vector<std::string>& local_devices,
      const std::vector<std::string>& global_devices);

  bool is_cpu_mesh() const { return device_type() == "CPU"; }
  bool is_epu_mesh() const { return device_type() == "EPU"; }
  bool is_tpu_mesh() const { return device_type() == "TPU"; }
  // Returns whether the mesh is a remote mesh.
  bool is_remote() const {
    return local_device_ids_.empty() && !global_device_ids_.empty();
  }

  // Device information methods.
  std::string device_type() const;
  // Takes an index in the flattened list of devices and returns a location
  // in the mesh.
  StatusOr<const DeviceLocation> device_location(int offset) const;
  int64 num_devices() const;
  absl::Span<const std::string> local_devices() const {
    return local_devices_;
  }
  absl::Span<const int64_t> local_device_ids() const {
    return local_device_ids_;
  }
  // Parses names of local_devices according to TF's DeviceNameUtils.
  StatusOr<const std::vector<DeviceNameUtils::ParsedName>> ParsedDevices()
      const;
  // Converts to the given device type.
  StatusOr<Mesh> ToDeviceType(const std::string& device_type) const;
  std::vector<std::string> hosts() const;

  // Consumes a location in the mesh and returns its corresponding index in
  // the flattened list of devices.
  int64 GetFlattenedCoordinate(const DeviceLocation& loc) const;

  const MeshDimension& dim(int64 index) const { return mesh_dims_[index]; }
  std::vector<MeshDimension> dims() const { return mesh_dims_; }
  // Returns the size of the named mesh dimension.
  StatusOr<int64> dim_size(absl::string_view name) const;
  // Returns the list of mesh dimension sizes.
  std::vector<int64> dim_sizes() const;
  const std::string& dim_name(int64 index) const {
    return mesh_dims_[index].name;
  }
  int64_t min_global_device_id() const {
    DCHECK(!global_device_ids_.empty());
    return *std::min_element(global_device_ids_.begin(),
                             global_device_ids_.end());
  }

  absl::Span<const int64_t> global_device_ids() const {
    return global_device_ids_;
  }

  const std::vector<std::string>& global_devices() const {
    return global_devices_;
  }
  // Returns the index of the given dim_name in the mesh.
  StatusOr<int32> idx_for_dim(absl::string_view dim_name) const;

  // Returns the index of the MeshDimension in the mesh whose dimension name
  // is `mesh_name`.
  int GetMeshDimIndexWithName(const std::string& mesh_name) const;
  bool IsMeshDim(const std::string& dim_name) const;

  int64 rank() const;
  int64 size() const;
  const std::string& name() const { return name_; }

  // Globally unique fingerprint. Same on different workers.
  uint64 GlobalFingerprint() const;

  bool operator==(const Mesh& b) const;
  bool operator!=(const Mesh& b) const { return !((*this) == b); }
  bool operator<(const Mesh& b) const {
    return this->ToString() < b.ToString();
  }

  template <typename H>
  friend H AbslHashValue(H h, const Mesh& m) {
    return H::combine(std::move(h), m.ToString());
  }

  // A map from mesh names to their corresponding core ID mappings. The core
  // ID mapping is stored as a vector. The i-th element in the vector is the
  // ID of the core represented by global device ID i in this mesh.
  //
  // The entry stored under the empty name key (the so-called "default
  // mapping" in some comments) is special. It is always set at the end of TPU
  // initialization. It represents the mapping for any mesh whose global
  // device IDs follow TF task-device ordinals. Legacy and test meshes created
  // without using the `create_tpu_mesh` helper follow that rule and can use
  // this entry.
  static std::map<std::string, std::vector<int>>& tpu_core_ids();

  // The host mesh associated with any user-defined TPU mesh.
  static std::string& tpu_host_mesh();

 private:
  std::string name_;
  std::vector<MeshDimension> mesh_dims_;
  std::vector<std::string> local_devices_;
  std::vector<int64_t> local_device_ids_;
  std::vector<int64_t> global_device_ids_;
  std::vector<std::string> global_devices_;
};
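
// Construction sketch (illustrative only): a fully defined 2x2 CPU mesh. The
// mesh name and device strings are hypothetical, and global_devices is left
// empty, which this sketch assumes is acceptable when global and local
// devices coincide:
//
//   std::vector<MeshDimension> dims = {MeshDimension("x", 2),
//                                      MeshDimension("y", 2)};
//   StatusOr<Mesh> mesh = Mesh::GetMesh(
//       "my_mesh", dims,
//       /*global_device_ids=*/{0, 1, 2, 3},
//       /*local_device_ids=*/{0, 1, 2, 3},
//       /*local_devices=*/{"/job:localhost/task:0/device:CPU:0",
//                          "/job:localhost/task:0/device:CPU:1",
//                          "/job:localhost/task:0/device:CPU:2",
//                          "/job:localhost/task:0/device:CPU:3"},
//       /*global_devices=*/{});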

class Layout {
 public:
  static constexpr const char* kUnshardedDim = "unsharded";
  // This spec should only be used to express no preferred sharding in the
  // layout propagation algorithm.
  static constexpr const char* kAny = "any";
  // Failed serialized strings are represented with an empty string; therefore,
  // we use this string representation of an empty layout instead to avoid
  // confusion.
  static constexpr const char* kEmptyLayoutString = "empty_layout";
  // Used for the relayout operation, to allow relayout to act as an identity
  // on the layout for the given dimension.
  static constexpr const char* kMatch = "match";

  // Returns an empty layout.
  static Layout Empty();

  // Parses from LayoutProto.
  static StatusOr<Layout> FromProto(const LayoutProto& proto);
  // Parses from a human-readable string version of the layout, currently used
  // to represent layouts in MLIR:
  //  layout = <sharding_specs:List[specs] mesh:name|List[MeshDim]|
  //           List[GlobalId]|List[LocalId]|List[Devices]>
  //
  // Example:
  //  layout = <sharding_specs:x,not_sharded mesh:name|x=2,y=2|0,1,2,3|0,1,2,3|
  //           /job:localhost/task:0/device:CPU:0,/job:localhost/task:0/device:CPU:1,
  //           /job:localhost/task:0/device:CPU:2,/job:localhost/task:0/device:CPU:3>
  static StatusOr<Layout> FromString(std::string layout_str);
  // Creates a human-readable string version of a layout.
  std::string ToString() const;
  LayoutProto ToProto() const;
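
  // Usage sketch (illustrative only): a rank-2 layout that shards the first
  // tensor dimension over mesh dimension "x" and replicates the second
  // (`mesh` here is any mesh with an "x" dimension, e.g. one built in the
  // sketches above):
  //
  //   StatusOr<Layout> layout =
  //       Layout::GetLayout({"x", Layout::kUnshardedDim}, mesh);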

  const Mesh& mesh() const { return mesh_; }
  static Layout ReplicatedOnMesh(const Mesh& mesh, int rank);
  static Layout AnyOnMesh(const Mesh& mesh, int rank);
  // Creates a mesh of unique shards.
  Mesh ReducedMesh() const;
  void set_mesh(Mesh mesh) { mesh_ = mesh; }

  // Returns a layout for the transposed matrix for the given layout. This
  // assumes that only the last two dimensions are used for matrix computation
  // and all dimensions before them are batch dimensions.
  static StatusOr<Layout> Transposed2D(const Layout& layout);
  static bool IsUnshardedDimension(const absl::string_view name) {
    return name == kUnshardedDim;
  }
  static bool IsShardedDimension(const absl::string_view name) {
    return !IsUnshardedDimension(name);
  }
  static bool IsUnshardedSpec(const ShardingSpec& spec) {
    return IsUnshardedDimension(spec.sharding_spec());
  }
  static bool IsShardedSpec(const ShardingSpec& spec) {
    return !IsUnshardedDimension(spec.sharding_spec());
  }
  static StatusOr<Layout> GetLayout(
      const std::vector<std::string>& sharding_spec_strs, const Mesh& mesh);
  static StatusOr<Layout> GetLayout(
      const std::vector<ShardingSpec>& sharding_specs, const Mesh& mesh);

  // Makes a new layout from this one, dropping the given dimensions.
  // If keep_dims is true, the dimensions are replicated rather than deleted.
  Layout GetLayoutWithReducedDims(const absl::flat_hash_set<int>& reduced_dims,
                                  bool keep_dims) const;

  // Truncates a layout at the front or back, depending on the value of end.
  // end = false returns the layout up to the split point,
  // end = true returns the layout from the split point.
  Layout Truncate(int64 split_point, bool end = false) const;

  // Left- or right-pads the layout to a max rank.
  Layout LeftPad(int64 rank) const;

  bool IsFullyReplicated() const;
  bool IsLastDimReplicated() const;
  // Checks that the last N-1 dimensions are replicated.
  bool IsBatchParallel() const;
  // Checks that the dimensions from [-non_batch_rank, end) are replicated.
  bool IsBatchParallel(int non_batch_rank) const;
  bool IsEmpty() const;

  // Computes the global shape using the layout and the provided local_shape.
  std::vector<int64_t> GlobalShapeFromLocalShape(
      const std::vector<int64_t>& local_shape) const;

  std::vector<int64_t> LocalShapeFromGlobalShape(
      absl::Span<const int64_t> global_shape) const;
  PartialTensorShape LocalShapeFromGlobalShape(
      const PartialTensorShape& global_shape) const;

  int64 rank() const { return sharding_specs_.size(); }
  size_t num_shards_for_dim(const ShardingSpec& dim) const;
  std::vector<int32> num_shards() const;

  const ShardingSpec& dim(int64 idx) const { return sharding_specs_[idx]; }
  absl::Span<const ShardingSpec> sharding_specs() const {
    return sharding_specs_;
  }

  // Computes the shard vector corresponding to this layout.
  ShardVector GetShardVector() const;

  // Returns the sharding specs in string form.
  std::vector<std::string> sharding_spec_strs() const;

  int64 num_devices() const { return mesh_.num_devices(); }
  StatusOr<const DeviceLocation> device_location(int64 device_id) const {
    return mesh_.device_location(device_id);
  }
  // Maps hosts to shards.
  std::map<std::string, ShardVector> HostShardMap() const;

  const std::string& sharding_spec(int idx) const;

  // Two layouts are equivalent if they would result in the same sharding for
  // the tensor, e.g. if one is unsharded and the other is sharded on a mesh
  // dimension of size 1.
  bool IsEquivalent(const Layout& b) const;
  bool operator==(const Layout& b) const;
  bool operator!=(const Layout& b) const { return !((*this) == b); }
  bool operator<(const Layout& b) const {
    return this->ToString() < b.ToString();
  }

 private:
  std::vector<ShardingSpec> sharding_specs_;
  Mesh mesh_;
};

// Takes two layouts and concatenates their TensorDimensions. If the meshes
// for the two layouts differ, or if both layouts use the same mesh dimension,
// returns an error rather than a layout.
StatusOr<Layout> ConcatenateLayouts(const Layout& layout_a,
                                    const Layout& layout_b);

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_CC_TENSOR_LAYOUT_H_