xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/shape_layout.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
18 
19 #include <string>
20 
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/lib/core/status.h"
25 
26 namespace xla {
27 
28 // A ShapeLayout object encapsulates the layout of a particular shape (including
29 // tuples). This differs from the Layout proto which describes the layout of a
30 // single array. ShapeLayout contains a Layout proto for each array in the shape
31 // (a tuple can have more than one array). For array shapes, this object
32 // trivially holds a single Layout. Logically, ShapeLayout holds a nonmutable
33 // shape with mutable layouts.
34 class ShapeLayout {
35  public:
36   // Constructs a ShapeLayout of the given shape. Layouts are copied from the
37   // shape parameter.
ShapeLayout(const Shape & shape)38   explicit ShapeLayout(const Shape& shape) : shape_(shape) {}
39 
40   // Assigns the layouts in this ShapeLayout to the Layout fields of the given
41   // shape. 'to_shape' and the shape of the ShapeLayout object must be
42   // compatible.
43   Status AssignLayoutToShape(Shape* to_shape) const;
44 
45   // Returns true if the Layouts in this ShapeLayout match the layouts in the
46   // given shape. Returns false otherwise. If the given shape is not compatible
47   // with the ShapeLayout's shape, then false is returned. If
48   // `ignore_fully_empty_tiling` is true, tiling info is ignored if one of the
49   // shapes has no tiling at all in all its subshapes.
50   bool MatchesLayoutInShape(const Shape& shape,
51                             bool minor_to_major_only = false,
52                             bool ignore_fully_empty_tiling = false) const;
53 
54   // Copies the layout from the given shape into this ShapeLayout. 'other_shape'
55   // must be compatible with the ShapeLayout's shape.
56   Status CopyLayoutFromShape(const Shape& other_shape);
57 
58   // Clears (Layout::Clear) all the Layouts stored in this object.
59   void Clear();
60 
61   // Sets all Layouts stored in this object to the default layout.
62   void SetToDefaultLayout();
63 
64   // Returns the shape (with layouts).
shape()65   const Shape& shape() const { return shape_; }
66 
67   // Clear dynamic dimensions of this module. Pretending the module creates
68   // static results. Useful in inspecting full outputs when testing.
ClearDynamicShape()69   void ClearDynamicShape() { shape_.clear_dynamic_dimensions(); }
70 
71   // Checks that a layout is set for the shape, and returns a reference to the
72   // layout directly on the shape. Shape must not be a tuple.
73   const Layout& layout() const;
74 
75   // Returns true if all layouts have been set for this ShapeLayout object. That
76   // is, every array has a layout.
77   bool LayoutIsSet() const;
78 
79   // Resets the layout on the shape to the provided layout. Shape must not be a
80   // tuple.
81   void ResetLayout(const Layout& layout);
82 
83   // Resets the layout on the shape at the provided ShapeIndex to the provided
84   // layout. Shape must be a tuple.
85   void ResetLayout(const Layout& layout, ShapeIndexView shape_index);
86 
87   // Returns a string representation of this object.
ToString()88   std::string ToString() const { return shape_.ToString(true); }
89 
90   // Tests for equality of both shape and layout (ShapeUtil::Equal).
91   bool operator==(const ShapeLayout& other) const;
92   bool operator!=(const ShapeLayout& other) const;
93 
94  private:
95   Shape shape_;
96 };
97 
98 }  // namespace xla
99 
100 #endif  // TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
101