1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/evalue_util/print_evalue.h>
10
11 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <ostream>
#include <sstream>
#include <string>
18
19 using exec_aten::ScalarType;
20
21 namespace executorch {
22 namespace extension {
23
24 namespace {
25
/// Maximum number of items print_scalar_list() places on one line once its
/// output wraps onto multiple lines.
constexpr size_t kItemsPerLine{10};

/// Number of leading and trailing list items printed before the middle of a
/// long list is elided, used when the stream carries no positive override.
constexpr size_t kDefaultEdgeItems{3};
31
32 /// Returns a globally unique "iword" index that we can use to store the current
33 /// "edge items" count on arbitrary streams.
get_edge_items_xalloc()34 int get_edge_items_xalloc() {
35 // Wrapping this in a function avoids a -Wglobal-constructors warning.
36 static const int xalloc = std::ios_base::xalloc();
37 return xalloc;
38 }
39
40 /// Returns the number of "edge items" to print at the beginning and end of
41 /// lists when using the provided stream.
get_stream_edge_items(std::ostream & os)42 long get_stream_edge_items(std::ostream& os) {
43 long edge_items = os.iword(get_edge_items_xalloc());
44 return edge_items <= 0 ? kDefaultEdgeItems : edge_items;
45 }
46
/// Prints `value` to `os`, mimicking PyTorch by appending a trailing "." to
/// finite values that would otherwise look like integers. For example, 3.0
/// prints as "3." so it can be told apart from the integer 3.
///
/// Non-finite values (infinity, NaN) are printed exactly as `os << value`
/// prints them.
///
/// Note: the "looks like an integer" check formats `value` with a
/// default-configured stream, while the final output uses `os`'s own flags.
void print_double(std::ostream& os, double value) {
  if (!std::isfinite(value)) {
    // Infinity or NaN.
    os << value;
    return;
  }

  bool add_dot = false;
  if (value == 0.0) {
    // Matches both +0.0 and -0.0. The integer comparison below cannot
    // detect -0.0: it prints as "-0", while its int64_t conversion prints
    // as "0".
    add_dot = true;
  } else {
    // Converting a double outside int64_t's range is undefined behavior,
    // so only compare against the integer form when the value fits. Both
    // -2^63 and 2^63 are exactly representable as doubles.
    constexpr double kInt64Min =
        static_cast<double>(std::numeric_limits<int64_t>::min());
    if (value >= kInt64Min && value < -kInt64Min) {
      std::ostringstream oss_float;
      oss_float << value;
      std::ostringstream oss_int;
      oss_int << static_cast<int64_t>(value);
      add_dot = oss_float.str() == oss_int.str();
    }
    // Out-of-range values are at least 2^63 in magnitude. A
    // default-formatted stream prints those in exponent notation, so they
    // could never have matched an integer's digits anyway.
  }

  os << value;
  if (add_dot) {
    os << ".";
  }
}
74
75 template <class T>
print_scalar_list(std::ostream & os,exec_aten::ArrayRef<T> list,bool print_length=true,bool elide_inner_items=true)76 void print_scalar_list(
77 std::ostream& os,
78 exec_aten::ArrayRef<T> list,
79 bool print_length = true,
80 bool elide_inner_items = true) {
81 long edge_items = elide_inner_items ? get_stream_edge_items(os)
82 : std::numeric_limits<long>::max();
83 if (print_length) {
84 os << "(len=" << list.size() << ")";
85 }
86
87 // See if we'll be printing enough elements to cause us to wrap.
88 bool wrapping = false;
89 {
90 long num_printed_items;
91 if (elide_inner_items) {
92 num_printed_items =
93 std::min(static_cast<long>(list.size()), edge_items * 2);
94 } else {
95 num_printed_items = static_cast<long>(list.size());
96 }
97 wrapping = num_printed_items > kItemsPerLine;
98 }
99
100 os << "[";
101 size_t num_printed = 0;
102 for (size_t i = 0; i < list.size(); ++i) {
103 if (wrapping && num_printed % kItemsPerLine == 0) {
104 // We've printed a full line, so wrap and begin a new one.
105 os << "\n ";
106 }
107 os << executorch::runtime::EValue(exec_aten::Scalar(list[i]));
108 if (wrapping || i < list.size() - 1) {
109 // No trailing comma when not wrapping. Always a trailing comma when
110 // wrapping. This will leave a trailing space at the end of every wrapped
111 // line, but it simplifies the logic here.
112 os << ", ";
113 }
114 ++num_printed;
115 if (i + 1 == edge_items && i + edge_items + 1 < list.size()) {
116 if (wrapping) {
117 os << "\n ...,";
118 // Make the first line after the elision be the ragged line, letting us
119 // always end on a full line.
120 num_printed = kItemsPerLine - edge_items % kItemsPerLine;
121 if (num_printed % kItemsPerLine != 0) {
122 // If the line ended exactly when the elision happened, the next
123 // iteration of the loop will add this line break.
124 os << "\n ";
125 }
126 } else {
127 // Non-wrapping elision.
128 os << "..., ";
129 }
130 i = list.size() - edge_items - 1;
131 }
132 }
133 if (wrapping) {
134 // End the current line.
135 os << "\n";
136 }
137 os << "]";
138 }
139
print_tensor(std::ostream & os,exec_aten::Tensor tensor)140 void print_tensor(std::ostream& os, exec_aten::Tensor tensor) {
141 os << "tensor(sizes=";
142 // Always print every element of the sizes list.
143 print_scalar_list(
144 os, tensor.sizes(), /*print_length=*/false, /*elide_inner_items=*/false);
145 os << ", ";
146
147 // Print the data as a one-dimensional list.
148 //
149 // TODO(T159700776): Print dim_order and strides when they have non-default
150 // values.
151 //
152 // TODO(T159700776): Format multidimensional data like numpy/PyTorch does.
153 // https://github.com/pytorch/pytorch/blob/main/torch/_tensor_str.py
154 #define PRINT_TENSOR_DATA(ctype, dtype) \
155 case ScalarType::dtype: \
156 print_scalar_list( \
157 os, \
158 exec_aten::ArrayRef<ctype>( \
159 tensor.const_data_ptr<ctype>(), tensor.numel()), \
160 /*print_length=*/false); \
161 break;
162
163 switch (tensor.scalar_type()) {
164 ET_FORALL_REAL_TYPES_AND2(Bool, Half, PRINT_TENSOR_DATA)
165 default:
166 os << "[<unhandled scalar type " << (int)tensor.scalar_type() << ">]";
167 }
168 os << ")";
169
170 #undef PRINT_TENSOR_DATA
171 }
172
print_tensor_list(std::ostream & os,exec_aten::ArrayRef<exec_aten::Tensor> list)173 void print_tensor_list(
174 std::ostream& os,
175 exec_aten::ArrayRef<exec_aten::Tensor> list) {
176 os << "(len=" << list.size() << ")[";
177 for (size_t i = 0; i < list.size(); ++i) {
178 if (list.size() > 1) {
179 os << "\n [" << i << "]: ";
180 }
181 print_tensor(os, list[i]);
182 if (list.size() > 1) {
183 os << ",";
184 }
185 }
186 if (list.size() > 1) {
187 os << "\n";
188 }
189 os << "]";
190 }
191
print_list_optional_tensor(std::ostream & os,exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list)192 void print_list_optional_tensor(
193 std::ostream& os,
194 exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list) {
195 os << "(len=" << list.size() << ")[";
196 for (size_t i = 0; i < list.size(); ++i) {
197 if (list.size() > 1) {
198 os << "\n [" << i << "]: ";
199 }
200 if (list[i].has_value()) {
201 print_tensor(os, list[i].value());
202 } else {
203 os << "None";
204 }
205 if (list.size() > 1) {
206 os << ",";
207 }
208 }
209 if (list.size() > 1) {
210 os << "\n";
211 }
212 os << "]";
213 }
214
215 } // namespace
216
set_edge_items(std::ostream & os,long edge_items)217 void evalue_edge_items::set_edge_items(std::ostream& os, long edge_items) {
218 os.iword(get_edge_items_xalloc()) = edge_items;
219 }
220
221 } // namespace extension
222 } // namespace executorch
223
224 namespace executorch {
225 namespace runtime {
226
227 // This needs to live in the same namespace as EValue.
operator <<(std::ostream & os,const EValue & value)228 std::ostream& operator<<(std::ostream& os, const EValue& value) {
229 using namespace executorch::extension;
230
231 switch (value.tag) {
232 case Tag::None:
233 os << "None";
234 break;
235 case Tag::Bool:
236 if (value.toBool()) {
237 os << "True";
238 } else {
239 os << "False";
240 }
241 break;
242 case Tag::Int:
243 os << value.toInt();
244 break;
245 case Tag::Double:
246 print_double(os, value.toDouble());
247 break;
248 case Tag::String: {
249 auto str = value.toString();
250 os << std::quoted(std::string(str.data(), str.size()));
251 } break;
252 case Tag::Tensor:
253 print_tensor(os, value.toTensor());
254 break;
255 case Tag::ListBool:
256 print_scalar_list(os, value.toBoolList());
257 break;
258 case Tag::ListInt:
259 print_scalar_list(os, value.toIntList());
260 break;
261 case Tag::ListDouble:
262 print_scalar_list(os, value.toDoubleList());
263 break;
264 case Tag::ListTensor:
265 print_tensor_list(os, value.toTensorList());
266 break;
267 case Tag::ListOptionalTensor:
268 print_list_optional_tensor(os, value.toListOptionalTensor());
269 break;
270 default:
271 os << "<Unknown EValue tag " << static_cast<int>(value.tag) << ">";
272 break;
273 }
274 return os;
275 }
276
277 } // namespace runtime
278 } // namespace executorch
279