xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/layout.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/layout.h"
17 
18 #include <string_view>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 
25 namespace xla {
26 
ToProto() const27 TileProto Tile::ToProto() const {
28   TileProto tile_proto;
29   for (int64_t i : dimensions()) {
30     tile_proto.add_dimensions(i);
31   }
32   return tile_proto;
33 }
34 
ToString() const35 std::string Tile::ToString() const {
36   std::vector<std::string> elements;
37   const auto& dims = dimensions();
38   elements.reserve(dims.size());
39   for (auto dim : dims) {
40     if (dim >= 0) {
41       elements.push_back(std::to_string(dim));
42     } else {
43       if (dim == kCombineDimension) {
44         elements.push_back("*");
45       } else {
46         elements.push_back(absl::StrCat("Invalid value ", dim));
47       }
48     }
49   }
50   return absl::StrCat("(", absl::StrJoin(elements, ","), ")");
51 }
52 
CreateFromProto(const LayoutProto & proto)53 /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) {
54   Layout layout;
55   layout.minor_to_major_.reserve(proto.minor_to_major_size());
56   for (const int64_t dimension : proto.minor_to_major()) {
57     layout.add_minor_to_major(dimension);
58   }
59   for (const TileProto& tile_proto : proto.tiles()) {
60     *layout.add_tiles() = Tile::CreateFromProto(tile_proto);
61   }
62   layout.set_element_size_in_bits(proto.element_size_in_bits());
63   layout.set_memory_space(proto.memory_space());
64   return layout;
65 }
66 
ToProto() const67 LayoutProto Layout::ToProto() const {
68   LayoutProto proto;
69   proto.mutable_minor_to_major()->Reserve(minor_to_major_size());
70   for (const int64_t dimension : minor_to_major()) {
71     proto.add_minor_to_major(dimension);
72   }
73   for (const Tile& tile : tiles()) {
74     *proto.add_tiles() = tile.ToProto();
75   }
76   proto.set_element_size_in_bits(element_size_in_bits());
77   proto.set_memory_space(memory_space_);
78   return proto;
79 }
80 
81 namespace {
DimLevelTypeAbbrev(DimLevelType dim_level_type)82 absl::string_view DimLevelTypeAbbrev(DimLevelType dim_level_type) {
83   switch (dim_level_type) {
84     case DIM_DENSE:
85       return "D";
86     case DIM_COMPRESSED:
87       return "C";
88     case DIM_SINGLETON:
89       return "S";
90     default:
91       LOG(FATAL) << "Invalid DimLevelType value: " << dim_level_type;
92   }
93 }
94 }  // namespace
95 
ToString() const96 std::string Layout::ToString() const {
97   std::string colon_string;
98 
99   if (!tiles().empty()) {
100     absl::StrAppend(&colon_string, "T");
101     for (const Tile& tile : tiles()) {
102       absl::StrAppend(&colon_string, tile.ToString());
103     }
104   }
105 
106   if (!dim_level_types().empty()) {
107     absl::StrAppend(
108         &colon_string, "D(",
109         absl::StrJoin(dim_level_types(), ",",
110                       [](std::string* out, DimLevelType dim_level_type) {
111                         absl::StrAppend(out,
112                                         DimLevelTypeAbbrev(dim_level_type));
113                       }),
114         ")");
115   }
116 
117   if (element_size_in_bits() != 0) {
118     absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")");
119   }
120   if (memory_space() != 0) {
121     absl::StrAppend(&colon_string, "S(", memory_space(), ")");
122   }
123   return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","),
124                       colon_string.empty() ? "" : ":", colon_string, "}");
125 }
126 
operator ()(const Layout & lhs,const Layout & rhs)127 bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
128   if (!LayoutUtil::IsDense(lhs) || !LayoutUtil::IsDense(rhs)) {
129     if (lhs.dim_level_types() != rhs.dim_level_types()) {
130       return false;
131     }
132   }
133   if (lhs.minor_to_major() != rhs.minor_to_major()) {
134     return false;
135   }
136   if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) {
137     return false;
138   }
139   if (!ignore_element_size_ &&
140       lhs.element_size_in_bits() != rhs.element_size_in_bits()) {
141     return false;
142   }
143   if (!ignore_memory_space_ && lhs.memory_space() != rhs.memory_space()) {
144     return false;
145   }
146   return true;
147 }
148 
operator ==(const Layout & other) const149 bool Layout::operator==(const Layout& other) const {
150   return Equal()(*this, other);
151 }
152 
operator <<(std::ostream & out,const Tile & tile)153 std::ostream& operator<<(std::ostream& out, const Tile& tile) {
154   out << tile.ToString();
155   return out;
156 }
157 
operator <<(std::ostream & out,const Layout & layout)158 std::ostream& operator<<(std::ostream& out, const Layout& layout) {
159   out << layout.ToString();
160   return out;
161 }
162 
163 }  // namespace xla
164