1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/layout.h"
17
18 #include <string_view>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24
25 namespace xla {
26
ToProto() const27 TileProto Tile::ToProto() const {
28 TileProto tile_proto;
29 for (int64_t i : dimensions()) {
30 tile_proto.add_dimensions(i);
31 }
32 return tile_proto;
33 }
34
ToString() const35 std::string Tile::ToString() const {
36 std::vector<std::string> elements;
37 const auto& dims = dimensions();
38 elements.reserve(dims.size());
39 for (auto dim : dims) {
40 if (dim >= 0) {
41 elements.push_back(std::to_string(dim));
42 } else {
43 if (dim == kCombineDimension) {
44 elements.push_back("*");
45 } else {
46 elements.push_back(absl::StrCat("Invalid value ", dim));
47 }
48 }
49 }
50 return absl::StrCat("(", absl::StrJoin(elements, ","), ")");
51 }
52
CreateFromProto(const LayoutProto & proto)53 /* static */ Layout Layout::CreateFromProto(const LayoutProto& proto) {
54 Layout layout;
55 layout.minor_to_major_.reserve(proto.minor_to_major_size());
56 for (const int64_t dimension : proto.minor_to_major()) {
57 layout.add_minor_to_major(dimension);
58 }
59 for (const TileProto& tile_proto : proto.tiles()) {
60 *layout.add_tiles() = Tile::CreateFromProto(tile_proto);
61 }
62 layout.set_element_size_in_bits(proto.element_size_in_bits());
63 layout.set_memory_space(proto.memory_space());
64 return layout;
65 }
66
ToProto() const67 LayoutProto Layout::ToProto() const {
68 LayoutProto proto;
69 proto.mutable_minor_to_major()->Reserve(minor_to_major_size());
70 for (const int64_t dimension : minor_to_major()) {
71 proto.add_minor_to_major(dimension);
72 }
73 for (const Tile& tile : tiles()) {
74 *proto.add_tiles() = tile.ToProto();
75 }
76 proto.set_element_size_in_bits(element_size_in_bits());
77 proto.set_memory_space(memory_space_);
78 return proto;
79 }
80
81 namespace {
DimLevelTypeAbbrev(DimLevelType dim_level_type)82 absl::string_view DimLevelTypeAbbrev(DimLevelType dim_level_type) {
83 switch (dim_level_type) {
84 case DIM_DENSE:
85 return "D";
86 case DIM_COMPRESSED:
87 return "C";
88 case DIM_SINGLETON:
89 return "S";
90 default:
91 LOG(FATAL) << "Invalid DimLevelType value: " << dim_level_type;
92 }
93 }
94 } // namespace
95
ToString() const96 std::string Layout::ToString() const {
97 std::string colon_string;
98
99 if (!tiles().empty()) {
100 absl::StrAppend(&colon_string, "T");
101 for (const Tile& tile : tiles()) {
102 absl::StrAppend(&colon_string, tile.ToString());
103 }
104 }
105
106 if (!dim_level_types().empty()) {
107 absl::StrAppend(
108 &colon_string, "D(",
109 absl::StrJoin(dim_level_types(), ",",
110 [](std::string* out, DimLevelType dim_level_type) {
111 absl::StrAppend(out,
112 DimLevelTypeAbbrev(dim_level_type));
113 }),
114 ")");
115 }
116
117 if (element_size_in_bits() != 0) {
118 absl::StrAppend(&colon_string, "E(", element_size_in_bits(), ")");
119 }
120 if (memory_space() != 0) {
121 absl::StrAppend(&colon_string, "S(", memory_space(), ")");
122 }
123 return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","),
124 colon_string.empty() ? "" : ":", colon_string, "}");
125 }
126
operator ()(const Layout & lhs,const Layout & rhs)127 bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) {
128 if (!LayoutUtil::IsDense(lhs) || !LayoutUtil::IsDense(rhs)) {
129 if (lhs.dim_level_types() != rhs.dim_level_types()) {
130 return false;
131 }
132 }
133 if (lhs.minor_to_major() != rhs.minor_to_major()) {
134 return false;
135 }
136 if (!ignore_tiles_ && lhs.tiles() != rhs.tiles()) {
137 return false;
138 }
139 if (!ignore_element_size_ &&
140 lhs.element_size_in_bits() != rhs.element_size_in_bits()) {
141 return false;
142 }
143 if (!ignore_memory_space_ && lhs.memory_space() != rhs.memory_space()) {
144 return false;
145 }
146 return true;
147 }
148
operator ==(const Layout & other) const149 bool Layout::operator==(const Layout& other) const {
150 return Equal()(*this, other);
151 }
152
operator <<(std::ostream & out,const Tile & tile)153 std::ostream& operator<<(std::ostream& out, const Tile& tile) {
154 out << tile.ToString();
155 return out;
156 }
157
operator <<(std::ostream & out,const Layout & layout)158 std::ostream& operator<<(std::ostream& out, const Layout& layout) {
159 out << layout.ToString();
160 return out;
161 }
162
163 } // namespace xla
164