xref: /aosp_15_r20/external/executorch/extension/evalue_util/print_evalue.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/evalue_util/print_evalue.h>
10 
11 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12 
13 #include <algorithm>
14 #include <cmath>
15 #include <iomanip>
16 #include <ostream>
17 #include <sstream>
18 
19 using exec_aten::ScalarType;
20 
21 namespace executorch {
22 namespace extension {
23 
24 namespace {
25 
/// Maximum number of list items emitted on one line once a list is being
/// wrapped across multiple lines.
constexpr size_t kItemsPerLine{10};

/// How many items are shown at each end of an elided list when the stream has
/// no positive per-stream override.
constexpr size_t kDefaultEdgeItems{3};
31 
/// Returns the `iword()` index used to store the "edge items" count on a
/// stream.
///
/// The index is obtained from std::ios_base::xalloc() the first time this
/// function is called; every later call returns that same value.
int get_edge_items_xalloc() {
  // A function-local static (rather than a namespace-scope variable) avoids a
  // -Wglobal-constructors warning.
  static const int kIwordIndex = std::ios_base::xalloc();
  return kIwordIndex;
}
39 
40 /// Returns the number of "edge items" to print at the beginning and end of
41 /// lists when using the provided stream.
get_stream_edge_items(std::ostream & os)42 long get_stream_edge_items(std::ostream& os) {
43   long edge_items = os.iword(get_edge_items_xalloc());
44   return edge_items <= 0 ? kDefaultEdgeItems : edge_items;
45 }
46 
/// Prints `value` to `os`, mimicking PyTorch by appending a trailing "." when
/// a finite value's text looks like an integer (e.g. "1." or "-0."), so that
/// it can be told apart from an actual integer.
///
/// @param os Stream to write to.
/// @param value The number to print. Infinity and NaN are streamed as-is,
///     with no trailing dot.
void print_double(std::ostream& os, double value) {
  if (!std::isfinite(value)) {
    // Infinity or NaN.
    os << value;
    return;
  }

  // Bounds of int64_t expressed as doubles; both -2^63 and 2^63 are exactly
  // representable. Converting a double outside [-2^63, 2^63) to int64_t is
  // undefined behavior, so the integer comparison below is only attempted
  // inside this range.
  constexpr double kInt64Lower = -9223372036854775808.0;
  constexpr double kInt64UpperExclusive = 9223372036854775808.0;

  bool add_dot = false;
  if (value == -0.0) {
    // Special case that won't be detected by a comparison with int: negative
    // zero prints as "-0" while the integer prints as "0". (This comparison
    // is also true for positive zero, which gets a dot either way.)
    add_dot = true;
  } else if (value >= kInt64Lower && value < kInt64UpperExclusive) {
    // Compare the default-formatted text of the double against the text of
    // its truncated integer value. Both scratch streams are freshly
    // constructed, so they use default formatting regardless of the flags
    // set on `os`.
    std::ostringstream oss_float;
    oss_float << value;
    std::ostringstream oss_int;
    oss_int << static_cast<int64_t>(value);
    add_dot = oss_float.str() == oss_int.str();
  }
  // Values outside the int64_t range never get a dot: at that magnitude the
  // default formatting uses exponent notation, which cannot match an integer
  // string.

  os << value;
  if (add_dot) {
    os << ".";
  }
}
74 
75 template <class T>
print_scalar_list(std::ostream & os,exec_aten::ArrayRef<T> list,bool print_length=true,bool elide_inner_items=true)76 void print_scalar_list(
77     std::ostream& os,
78     exec_aten::ArrayRef<T> list,
79     bool print_length = true,
80     bool elide_inner_items = true) {
81   long edge_items = elide_inner_items ? get_stream_edge_items(os)
82                                       : std::numeric_limits<long>::max();
83   if (print_length) {
84     os << "(len=" << list.size() << ")";
85   }
86 
87   // See if we'll be printing enough elements to cause us to wrap.
88   bool wrapping = false;
89   {
90     long num_printed_items;
91     if (elide_inner_items) {
92       num_printed_items =
93           std::min(static_cast<long>(list.size()), edge_items * 2);
94     } else {
95       num_printed_items = static_cast<long>(list.size());
96     }
97     wrapping = num_printed_items > kItemsPerLine;
98   }
99 
100   os << "[";
101   size_t num_printed = 0;
102   for (size_t i = 0; i < list.size(); ++i) {
103     if (wrapping && num_printed % kItemsPerLine == 0) {
104       // We've printed a full line, so wrap and begin a new one.
105       os << "\n  ";
106     }
107     os << executorch::runtime::EValue(exec_aten::Scalar(list[i]));
108     if (wrapping || i < list.size() - 1) {
109       // No trailing comma when not wrapping. Always a trailing comma when
110       // wrapping. This will leave a trailing space at the end of every wrapped
111       // line, but it simplifies the logic here.
112       os << ", ";
113     }
114     ++num_printed;
115     if (i + 1 == edge_items && i + edge_items + 1 < list.size()) {
116       if (wrapping) {
117         os << "\n  ...,";
118         // Make the first line after the elision be the ragged line, letting us
119         // always end on a full line.
120         num_printed = kItemsPerLine - edge_items % kItemsPerLine;
121         if (num_printed % kItemsPerLine != 0) {
122           // If the line ended exactly when the elision happened, the next
123           // iteration of the loop will add this line break.
124           os << "\n  ";
125         }
126       } else {
127         // Non-wrapping elision.
128         os << "..., ";
129       }
130       i = list.size() - edge_items - 1;
131     }
132   }
133   if (wrapping) {
134     // End the current line.
135     os << "\n";
136   }
137   os << "]";
138 }
139 
print_tensor(std::ostream & os,exec_aten::Tensor tensor)140 void print_tensor(std::ostream& os, exec_aten::Tensor tensor) {
141   os << "tensor(sizes=";
142   // Always print every element of the sizes list.
143   print_scalar_list(
144       os, tensor.sizes(), /*print_length=*/false, /*elide_inner_items=*/false);
145   os << ", ";
146 
147   // Print the data as a one-dimensional list.
148   //
149   // TODO(T159700776): Print dim_order and strides when they have non-default
150   // values.
151   //
152   // TODO(T159700776): Format multidimensional data like numpy/PyTorch does.
153   // https://github.com/pytorch/pytorch/blob/main/torch/_tensor_str.py
154 #define PRINT_TENSOR_DATA(ctype, dtype)                      \
155   case ScalarType::dtype:                                    \
156     print_scalar_list(                                       \
157         os,                                                  \
158         exec_aten::ArrayRef<ctype>(                          \
159             tensor.const_data_ptr<ctype>(), tensor.numel()), \
160         /*print_length=*/false);                             \
161     break;
162 
163   switch (tensor.scalar_type()) {
164     ET_FORALL_REAL_TYPES_AND2(Bool, Half, PRINT_TENSOR_DATA)
165     default:
166       os << "[<unhandled scalar type " << (int)tensor.scalar_type() << ">]";
167   }
168   os << ")";
169 
170 #undef PRINT_TENSOR_DATA
171 }
172 
print_tensor_list(std::ostream & os,exec_aten::ArrayRef<exec_aten::Tensor> list)173 void print_tensor_list(
174     std::ostream& os,
175     exec_aten::ArrayRef<exec_aten::Tensor> list) {
176   os << "(len=" << list.size() << ")[";
177   for (size_t i = 0; i < list.size(); ++i) {
178     if (list.size() > 1) {
179       os << "\n  [" << i << "]: ";
180     }
181     print_tensor(os, list[i]);
182     if (list.size() > 1) {
183       os << ",";
184     }
185   }
186   if (list.size() > 1) {
187     os << "\n";
188   }
189   os << "]";
190 }
191 
print_list_optional_tensor(std::ostream & os,exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list)192 void print_list_optional_tensor(
193     std::ostream& os,
194     exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list) {
195   os << "(len=" << list.size() << ")[";
196   for (size_t i = 0; i < list.size(); ++i) {
197     if (list.size() > 1) {
198       os << "\n  [" << i << "]: ";
199     }
200     if (list[i].has_value()) {
201       print_tensor(os, list[i].value());
202     } else {
203       os << "None";
204     }
205     if (list.size() > 1) {
206       os << ",";
207     }
208   }
209   if (list.size() > 1) {
210     os << "\n";
211   }
212   os << "]";
213 }
214 
215 } // namespace
216 
set_edge_items(std::ostream & os,long edge_items)217 void evalue_edge_items::set_edge_items(std::ostream& os, long edge_items) {
218   os.iword(get_edge_items_xalloc()) = edge_items;
219 }
220 
221 } // namespace extension
222 } // namespace executorch
223 
224 namespace executorch {
225 namespace runtime {
226 
227 // This needs to live in the same namespace as EValue.
operator <<(std::ostream & os,const EValue & value)228 std::ostream& operator<<(std::ostream& os, const EValue& value) {
229   using namespace executorch::extension;
230 
231   switch (value.tag) {
232     case Tag::None:
233       os << "None";
234       break;
235     case Tag::Bool:
236       if (value.toBool()) {
237         os << "True";
238       } else {
239         os << "False";
240       }
241       break;
242     case Tag::Int:
243       os << value.toInt();
244       break;
245     case Tag::Double:
246       print_double(os, value.toDouble());
247       break;
248     case Tag::String: {
249       auto str = value.toString();
250       os << std::quoted(std::string(str.data(), str.size()));
251     } break;
252     case Tag::Tensor:
253       print_tensor(os, value.toTensor());
254       break;
255     case Tag::ListBool:
256       print_scalar_list(os, value.toBoolList());
257       break;
258     case Tag::ListInt:
259       print_scalar_list(os, value.toIntList());
260       break;
261     case Tag::ListDouble:
262       print_scalar_list(os, value.toDoubleList());
263       break;
264     case Tag::ListTensor:
265       print_tensor_list(os, value.toTensorList());
266       break;
267     case Tag::ListOptionalTensor:
268       print_list_optional_tensor(os, value.toListOptionalTensor());
269       break;
270     default:
271       os << "<Unknown EValue tag " << static_cast<int>(value.tag) << ">";
272       break;
273   }
274   return os;
275 }
276 
277 } // namespace runtime
278 } // namespace executorch
279