1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/layout_util.h"
17
18 #include <stddef.h>
19
20 #include <algorithm>
21 #include <functional>
22 #include <numeric>
23 #include <random>
24 #include <string>
25 #include <vector>
26
27 #include "absl/hash/hash.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace xla {
namespace {

// Overwrites every element of `*minor_to_major` so that it describes the
// default XLA layout: the descending sequence {size-1, ..., 1, 0}. The
// container's size is not changed; only its elements are written. Shared by
// GetDefaultLayoutForShape and SetToDefaultLayout.
//
// The default XLA layout is major-to-minor (dim 0 is major).
// For more information on XLA layouts, see:
// https://www.tensorflow.org/performance/xla/shapes
template <typename T>
void SetDefaultLayoutToContainer(T* minor_to_major) {
  const int64_t rank = minor_to_major->size();
  auto& entries = *minor_to_major;
  // Position 0 (most minor) receives the highest dimension number.
  for (int64_t position = 0, dim = rank - 1; position < rank;
       ++position, --dim) {
    entries[position] = dim;
  }
}

}  // namespace
55
MakeLayout(absl::Span<const int64_t> minor_to_major,absl::Span<const DimLevelType> dim_level_types,absl::Span<const Tile> tiles,int64_t element_size_in_bits,int64_t memory_space)56 /* static */ Layout LayoutUtil::MakeLayout(
57 absl::Span<const int64_t> minor_to_major,
58 absl::Span<const DimLevelType> dim_level_types,
59 absl::Span<const Tile> tiles, int64_t element_size_in_bits,
60 int64_t memory_space) {
61 Layout layout;
62 for (int64_t dimension_number : minor_to_major) {
63 layout.add_minor_to_major(dimension_number);
64 }
65 for (DimLevelType dim_level_type : dim_level_types) {
66 layout.add_dim_level_type(dim_level_type);
67 }
68 for (const Tile& tile : tiles) {
69 for (int64_t dim : tile.dimensions()) {
70 if (dim < 0 && dim != Tile::kCombineDimension) {
71 LOG(FATAL)
72 << "Tile dimension size needs to be minimum int64_t value if "
73 "it's negative. Value is "
74 << dim;
75 }
76 }
77 *layout.add_tiles() = tile;
78 }
79 layout.set_element_size_in_bits(element_size_in_bits);
80 layout.set_memory_space(memory_space);
81 return layout;
82 }
83
MakeDescendingLayout(int64_t rank)84 /* static */ Layout LayoutUtil::MakeDescendingLayout(int64_t rank) {
85 std::vector<int64_t> layout(rank);
86 std::iota(layout.rbegin(), layout.rend(), static_cast<int64_t>(0));
87 return MakeLayout(layout);
88 }
89
MakeAscendingLayout(int64_t rank)90 /* static */ Layout LayoutUtil::MakeAscendingLayout(int64_t rank) {
91 std::vector<int64_t> layout(rank);
92 std::iota(layout.begin(), layout.end(), static_cast<int64_t>(0));
93 return MakeLayout(layout);
94 }
95
MakeLayoutFromMajorToMinor(absl::Span<const int64_t> major_to_minor)96 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
97 absl::Span<const int64_t> major_to_minor) {
98 Layout layout;
99 for (int i = major_to_minor.size() - 1; i >= 0; i--) {
100 layout.add_minor_to_major(major_to_minor[i]);
101 }
102 return layout;
103 }
104
105 namespace {
106
107 // Internal helper that creates a default layout for an array of the given rank.
CreateDefaultLayoutForRank(int64_t rank)108 Layout CreateDefaultLayoutForRank(int64_t rank) {
109 Layout layout;
110 auto* minor_to_major = layout.mutable_minor_to_major();
111 minor_to_major->resize(rank, 0);
112 SetDefaultLayoutToContainer(minor_to_major);
113 return layout;
114 }
115
116 } // namespace
117
GetDefaultLayoutForShape(const Shape & shape)118 /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
119 if (shape.IsOpaque() || shape.IsToken()) {
120 // Opaque and token types have empty layouts.
121 return Layout();
122 }
123
124 // A Layout proto corresponds to a single array, not a tuple.
125 CHECK(shape.IsArray());
126 return CreateDefaultLayoutForRank(shape.dimensions_size());
127 }
128
GetDefaultLayoutForRank(int64_t rank)129 /* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64_t rank) {
130 return CreateDefaultLayoutForRank(rank);
131 }
132
GetDefaultLayoutForR2()133 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
134 return CreateDefaultLayoutForRank(2);
135 }
136
GetDefaultLayoutForR3()137 /* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
138 return CreateDefaultLayoutForRank(3);
139 }
140
GetDefaultLayoutForR4()141 /* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
142 return CreateDefaultLayoutForRank(4);
143 }
144
SetToDefaultLayout(Shape * shape)145 /* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
146 if (shape->IsTuple()) {
147 // Tuple shape.
148 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
149 SetToDefaultLayout(&element_shape);
150 }
151 shape->clear_layout();
152 } else if (shape->IsArray()) {
153 auto* minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
154 minor_to_major->resize(shape->dimensions_size(), 0);
155 SetDefaultLayoutToContainer(minor_to_major);
156 } else {
157 // Opaque, token types etc. have no layout.
158 shape->clear_layout();
159 }
160 }
161
// Returns a copy of `shape` to which SetToDefaultLayout has been applied.
// Array subshapes, including those nested in tuples, get the default layout.
// `shape` itself is not modified.
/* static */ Shape LayoutUtil::GetWithDefaultLayout(const Shape& shape) {
  Shape copy(shape);
  LayoutUtil::SetToDefaultLayout(&copy);
  return copy;
}
167
SetToDefaultLayout(ProgramShape * program_shape)168 /* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
169 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
170 LayoutUtil::SetToDefaultLayout(¶meter_shape);
171 }
172 LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
173 }
174
ValidateLayoutInShape(const Shape & shape,bool allow_missing_layouts)175 /* static */ Status LayoutUtil::ValidateLayoutInShape(
176 const Shape& shape, bool allow_missing_layouts) {
177 if (shape.IsTuple()) {
178 // Tuple shape.
179 if (shape.has_layout()) {
180 return InvalidArgument("tuple should not have a layout field");
181 }
182 for (auto& element_shape : shape.tuple_shapes()) {
183 TF_RETURN_IF_ERROR(
184 ValidateLayoutInShape(element_shape, allow_missing_layouts));
185 }
186 return OkStatus();
187 } else if (shape.IsArray()) {
188 if (!shape.has_layout()) {
189 if (allow_missing_layouts) {
190 return OkStatus();
191 }
192 return InvalidArgument("shape %s does not have a layout",
193 ShapeUtil::HumanString(shape));
194 }
195 return ValidateLayoutForShape(shape.layout(), shape);
196 } else {
197 // Token, opaque, etc. shape.
198 if (shape.has_layout()) {
199 return InvalidArgument(
200 "shape of primitive type %s should not have a layout",
201 PrimitiveType_Name(shape.element_type()));
202 }
203 return OkStatus();
204 }
205 }
206
ValidateLayoutForShape(const Layout & layout,const Shape & shape)207 /* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout,
208 const Shape& shape) {
209 if (shape.IsTuple()) {
210 return InvalidArgument("a single Layout is not valid for tuple shapes");
211 }
212
213 if (!shape.IsArray()) {
214 if (layout.minor_to_major_size() != 0) {
215 return InvalidArgument(
216 "shape of primitive type %s should not have a non-trivial layout",
217 PrimitiveType_Name(shape.element_type()));
218 }
219 return OkStatus();
220 }
221
222 if (layout.minor_to_major_size() != shape.rank()) {
223 return InvalidArgument(
224 "layout minor_to_major field contains %d elements, "
225 "but shape is rank %d: {%s}; shape: %s",
226 layout.minor_to_major_size(), shape.rank(),
227 absl::StrJoin(layout.minor_to_major(), ", "), shape.ShortDebugString());
228 }
229
230 std::vector<bool> dimensions_in_layout(shape.rank(), false);
231 for (int64_t i = 0; i < shape.rank(); ++i) {
232 int64_t dim = layout.minor_to_major(i);
233 if (dim < 0 || dim >= shape.rank()) {
234 return InvalidArgument(
235 "layout minor_to_major field has out-of-bounds value: %s",
236 HumanString(layout));
237 }
238 if (dimensions_in_layout[dim]) {
239 return InvalidArgument(
240 "layout minor_to_major field has duplicate values: {%s}",
241 HumanString(layout));
242 }
243 dimensions_in_layout[dim] = true;
244 }
245
246 if (!layout.dim_level_types().empty()) {
247 if (layout.dim_level_types().size() != shape.rank()) {
248 return InvalidArgument(
249 "layout dim_level_types field contains %d elements, but shape is "
250 "rank %d: {%s}; shape: %s",
251 layout.dim_level_types_size(), shape.rank(),
252 absl::StrJoin(layout.dim_level_types(), ", ",
253 [](std::string* out, DimLevelType dim_level_type) {
254 absl::StrAppend(out,
255 DimLevelType_Name(dim_level_type));
256 }),
257 shape.ShortDebugString());
258 }
259 if (LayoutUtil::IsSparse(layout)) {
260 if (layout.tiles_size() > 0) {
261 return InvalidArgument(
262 "layout has tiles, but the shape is a sparse array: %s",
263 shape.ShortDebugString());
264 }
265 }
266 }
267
268 return OkStatus();
269 }
270
ClearLayout(Shape * shape)271 /* static */ void LayoutUtil::ClearLayout(Shape* shape) {
272 shape->clear_layout();
273 for (auto& element_shape : *shape->mutable_tuple_shapes()) {
274 ClearLayout(&element_shape);
275 }
276 }
277
ClearLayout(ProgramShape * program_shape)278 /* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
279 for (auto& parameter_shape : *program_shape->mutable_parameters()) {
280 LayoutUtil::ClearLayout(¶meter_shape);
281 }
282 LayoutUtil::ClearLayout(program_shape->mutable_result());
283 }
284
ClearTiles(Shape * shape)285 /* static */ void LayoutUtil::ClearTiles(Shape* shape) {
286 ShapeUtil::ForEachMutableSubshape(
287 shape, [](Shape* subshape, const ShapeIndex&) {
288 if (subshape->has_layout()) {
289 if (subshape->has_layout()) {
290 subshape->mutable_layout()->clear_tiles();
291 }
292 }
293 });
294 }
295
IsDenseArray(const Shape & shape)296 /* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) {
297 return shape.IsArray() && (!shape.has_layout() || IsDense(shape.layout()));
298 }
299
IsSparseArray(const Shape & shape)300 /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) {
301 return shape.IsArray() && shape.has_layout() && IsSparse(shape.layout());
302 }
303
IsCOOArray(const Shape & shape)304 /* static */ bool LayoutUtil::IsCOOArray(const Shape& shape) {
305 return shape.IsArray() && shape.has_layout() && IsCOO(shape.layout());
306 }
307
IsCSRArray(const Shape & shape)308 /* static */ bool LayoutUtil::IsCSRArray(const Shape& shape) {
309 return shape.IsArray() && shape.rank() == 2 && shape.has_layout() &&
310 IsCSR(shape.layout());
311 }
312
IsCSCArray(const Shape & shape)313 /* static */ bool LayoutUtil::IsCSCArray(const Shape& shape) {
314 return shape.IsArray() && shape.rank() == 2 && shape.has_layout() &&
315 IsCSC(shape.layout());
316 }
317
IsDense(const Layout & layout)318 /* static */ bool LayoutUtil::IsDense(const Layout& layout) {
319 return absl::c_all_of(
320 layout.dim_level_types(),
321 [](DimLevelType dim_level_type) { return dim_level_type == DIM_DENSE; });
322 }
323
IsSparse(const Layout & layout)324 /* static */ bool LayoutUtil::IsSparse(const Layout& layout) {
325 return !IsDense(layout);
326 }
327
IsCOO(const Layout & layout)328 /* static */ bool LayoutUtil::IsCOO(const Layout& layout) {
329 return !layout.dim_level_types().empty() &&
330 layout.dim_level_type(0) == DIM_COMPRESSED &&
331 absl::c_all_of(layout.dim_level_types().subspan(1),
332 [](DimLevelType dim_level_type) {
333 return dim_level_type == DIM_SINGLETON;
334 });
335 }
336
IsCSR(const Layout & layout)337 /* static */ bool LayoutUtil::IsCSR(const Layout& layout) {
338 return IsMonotonicWithDim0Major(layout) &&
339 layout.dim_level_types() ==
340 absl::Span<const DimLevelType>{DIM_DENSE, DIM_COMPRESSED};
341 }
342
IsCSC(const Layout & layout)343 /* static */ bool LayoutUtil::IsCSC(const Layout& layout) {
344 return IsMonotonicWithDim0Minor(layout) &&
345 layout.dim_level_types() ==
346 absl::Span<const DimLevelType>{DIM_DENSE, DIM_COMPRESSED};
347 }
348
IsMonotonicWithDim0Minor(const Layout & layout)349 /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
350 return std::is_sorted(layout.minor_to_major().begin(),
351 layout.minor_to_major().end());
352 }
353
IsMonotonicWithDim0Major(const Layout & layout)354 /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
355 return std::is_sorted(layout.minor_to_major().begin(),
356 layout.minor_to_major().end(), std::greater<int64_t>());
357 }
358
HasLayout(const Shape & shape)359 /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
360 if (shape.IsTuple()) {
361 // Tuple shape: all subshapes must have a layout.
362 return absl::c_all_of(shape.tuple_shapes(),
363 [](const Shape& s) { return HasLayout(s); });
364 } else if (!shape.IsArray()) {
365 // Opaque, token types etc. ignore layout.
366 return true;
367 }
368 return shape.has_layout();
369 }
370
HasLayout(const ProgramShape & program_shape)371 /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
372 for (auto& parameter_shape : program_shape.parameters()) {
373 if (!LayoutUtil::HasLayout(parameter_shape)) {
374 return false;
375 }
376 }
377 return LayoutUtil::HasLayout(program_shape.result());
378 }
379
Equal(const Layout & lhs,const Layout & rhs)380 /* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
381 return lhs == rhs;
382 }
383
MinorToMajor(const Shape & shape)384 /* static */ absl::Span<const int64_t> LayoutUtil::MinorToMajor(
385 const Shape& shape) {
386 CHECK(shape.IsArray());
387 return shape.layout().minor_to_major();
388 }
389
MinorToMajor(const Layout & layout)390 /* static */ absl::Span<const int64_t> LayoutUtil::MinorToMajor(
391 const Layout& layout) {
392 return layout.minor_to_major();
393 }
394
Major(const Layout & layout,int64_t physical_dimension_number)395 /* static */ int64_t LayoutUtil::Major(const Layout& layout,
396 int64_t physical_dimension_number) {
397 CHECK_LE(0, physical_dimension_number);
398 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
399 return Minor(layout,
400 layout.minor_to_major_size() - 1 - physical_dimension_number);
401 }
402
Minor(const Layout & layout,int64_t physical_dimension_number)403 /* static */ int64_t LayoutUtil::Minor(const Layout& layout,
404 int64_t physical_dimension_number) {
405 CHECK_LE(0, physical_dimension_number);
406 CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
407 return layout.minor_to_major(physical_dimension_number);
408 }
409
MakeLogicalToPhysical(const Layout & layout)410 /* static */ std::vector<int64_t> LayoutUtil::MakeLogicalToPhysical(
411 const Layout& layout) {
412 std::vector<int64_t> logical_to_physical(layout.minor_to_major_size());
413 for (int64_t physical = 0, end = logical_to_physical.size(); physical < end;
414 ++physical) {
415 const int64_t logical = Major(layout, physical);
416 logical_to_physical[logical] = physical;
417 }
418 return logical_to_physical;
419 }
420
HumanString(const Layout & layout)421 /* static */ std::string LayoutUtil::HumanString(const Layout& layout) {
422 return layout.ToString();
423 }
424
425 namespace {
426
427 // Internal helper for recursively copying layouts.
CopyLayoutInternal(const Shape & src,Shape * dst)428 Status CopyLayoutInternal(const Shape& src, Shape* dst) {
429 if (src.IsTuple() != dst->IsTuple()) {
430 return InvalidArgument(
431 "cannot copy layout from shape: shape structure differs");
432 }
433 if (src.IsTuple()) {
434 if (ShapeUtil::TupleElementCount(src) !=
435 ShapeUtil::TupleElementCount(*dst)) {
436 return InvalidArgument(
437 "cannot copy layout from shape: tuple element count differs");
438 }
439 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
440 TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
441 dst->mutable_tuple_shapes(i)));
442 }
443 } else {
444 if (src.has_layout()) {
445 if (src.rank() != dst->rank()) {
446 return InvalidArgument("cannot copy layout from shape: ranks differs");
447 }
448 TF_RETURN_IF_ERROR(
449 LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
450 *dst->mutable_layout() = src.layout();
451 } else {
452 dst->clear_layout();
453 }
454 }
455 return OkStatus();
456 }
457
458 } // namespace
459
460 /* static */
CopyLayoutBetweenShapes(const Shape & src,Shape * dst)461 Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
462 return CopyLayoutInternal(src, dst);
463 }
464
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)465 /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
466 const Shape& rhs) {
467 if (lhs.IsTuple()) {
468 if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
469 ShapeUtil::TupleElementCount(rhs)) {
470 return false;
471 }
472 for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
473 if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
474 return false;
475 }
476 }
477 return true;
478 }
479 if (lhs.IsArray()) {
480 if (lhs.rank() != rhs.rank()) {
481 return false;
482 }
483 if (!lhs.has_layout() && !rhs.has_layout()) {
484 return true;
485 }
486 if (!lhs.has_layout() || !rhs.has_layout()) {
487 return false;
488 }
489 return LayoutUtil::Equal(lhs.layout(), rhs.layout());
490 }
491 // Layouts of non-array and non-tuple shapes is ignored.
492 return true;
493 }
494
AreDimensionsConsecutive(const Layout & layout,absl::Span<const int64_t> dims)495 /* static */ bool LayoutUtil::AreDimensionsConsecutive(
496 const Layout& layout, absl::Span<const int64_t> dims) {
497 absl::InlinedVector<int64_t, 8> positions_in_layout;
498 for (int64_t dim : dims) {
499 positions_in_layout.push_back(
500 PositionInContainer(layout.minor_to_major(), dim));
501 }
502 absl::c_sort(positions_in_layout);
503 for (size_t i = 1; i < positions_in_layout.size(); ++i) {
504 if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
505 return false;
506 }
507 }
508 return true;
509 }
510
// Returns a copy of `layout` in which `dim` has been moved to the most-major
// position of minor_to_major. The relative order of the other dimensions is
// kept. If `dim` is already most major, `layout` is returned unchanged.
/*static*/ Layout LayoutUtil::MoveDimToMajor(const Layout& layout,
                                             int64_t dim) {
  const absl::Span<const int64_t> minor_to_major = MinorToMajor(layout);
  if (minor_to_major.back() == dim) {
    return layout;
  }

  Layout moved = layout;
  moved.clear_minor_to_major();
  for (const int64_t d : minor_to_major) {
    if (d == dim) {
      continue;
    }
    moved.add_minor_to_major(d);
  }
  moved.add_minor_to_major(dim);
  return moved;
}
524
LinearIndex(const Shape & shape,absl::Span<const int64_t> indices)525 /*static*/ int64_t LayoutUtil::LinearIndex(const Shape& shape,
526 absl::Span<const int64_t> indices) {
527 CHECK(shape.IsArray());
528 CHECK(shape.has_layout());
529 const int rank = shape.rank();
530 CHECK_EQ(rank, indices.size());
531
532 if (rank == 0) {
533 return 0;
534 }
535 if (rank == 1) {
536 return indices[0];
537 }
538
539 Tile tile = {};
540 if (!shape.layout().tiles().empty()) {
541 tile = shape.layout().tiles()[0];
542 }
543
544 int64_t linear_index = 0;
545 int64_t tile_multiplier = 1;
546 // Initialize to number of elements in a tile.
547 for (int64_t i : tile.dimensions()) {
548 tile_multiplier *= i;
549 }
550 int64_t within_tile_multiplier = 1;
551
552 // We only look at the top-level tile.
553 for (int64_t minor = 0; minor < rank; minor++) {
554 int64_t logical_dim = Minor(shape.layout(), minor);
555 int64_t shape_dim_size = shape.dimensions(logical_dim);
556 int64_t index = indices[logical_dim];
557
558 if (minor < tile.dimensions().size()) {
559 int64_t tile_dim_size =
560 tile.dimensions()[tile.dimensions().size() - 1 - minor];
561 linear_index += tile_multiplier * (index / tile_dim_size) +
562 within_tile_multiplier * (index % tile_dim_size);
563 tile_multiplier *= CeilOfRatio(shape_dim_size, tile_dim_size);
564 within_tile_multiplier *= tile_dim_size;
565 } else {
566 linear_index += index * tile_multiplier;
567 tile_multiplier *= shape_dim_size;
568 }
569 }
570 return linear_index;
571 }
572
MemorySpace(const Shape & shape)573 /*static*/ int64_t LayoutUtil::MemorySpace(const Shape& shape) {
574 return shape.has_layout() ? shape.layout().memory_space()
575 : Layout::kDefaultMemorySpace;
576 }
577
578 } // namespace xla
579