xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/onnx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/serialization/onnx.h>
3 #include <torch/csrc/onnx/onnx.h>
4 
5 #include <sstream>
6 #include <string>
7 
8 namespace torch::jit {
9 
10 namespace {
11 namespace onnx = ::ONNX_NAMESPACE;
12 
// Pretty-printing layout for ONNX protos: every nesting level contributes
// `indent_multiplier` copies of `indent_char` to the line prefix.
constexpr char indent_char{' '};
constexpr size_t indent_multiplier{2};

// Returns the whitespace prefix for nesting level `indent`.
std::string idt(size_t indent) {
  const size_t width = indent * indent_multiplier;
  return std::string(width, indent_char);
}

// Returns a line break followed by the whitespace prefix for level `indent`.
std::string nlidt(size_t indent) {
  std::string line_break(1, '\n');
  line_break += idt(indent);
  return line_break;
}
24 
dump(const onnx::TensorProto & tensor,std::ostream & stream)25 void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
26   stream << "TensorProto shape: [";
27   for (const auto i : c10::irange(tensor.dims_size())) {
28     stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
29   }
30   stream << "]";
31 }
32 
dump(const onnx::TensorShapeProto & shape,std::ostream & stream)33 void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
34   for (const auto i : c10::irange(shape.dim_size())) {
35     auto& dim = shape.dim(i);
36     if (dim.has_dim_value()) {
37       stream << dim.dim_value();
38     } else {
39       stream << "?";
40     }
41     stream << (i == shape.dim_size() - 1 ? "" : " ");
42   }
43 }
44 
dump(const onnx::TypeProto_Tensor & tensor_type,std::ostream & stream)45 void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
46   stream << "Tensor dtype: ";
47   if (tensor_type.has_elem_type()) {
48     stream << tensor_type.elem_type();
49   } else {
50     stream << "None.";
51   }
52   stream << ", ";
53   stream << "Tensor dims: ";
54   if (tensor_type.has_shape()) {
55     dump(tensor_type.shape(), stream);
56   } else {
57     stream << "None.";
58   }
59 }
60 
61 void dump(const onnx::TypeProto& type, std::ostream& stream);
62 
dump(const onnx::TypeProto_Optional & optional_type,std::ostream & stream)63 void dump(const onnx::TypeProto_Optional& optional_type, std::ostream& stream) {
64   stream << "Optional<";
65   if (optional_type.has_elem_type()) {
66     dump(optional_type.elem_type(), stream);
67   } else {
68     stream << "None";
69   }
70   stream << ">";
71 }
72 
dump(const onnx::TypeProto_Sequence & sequence_type,std::ostream & stream)73 void dump(const onnx::TypeProto_Sequence& sequence_type, std::ostream& stream) {
74   stream << "Sequence<";
75   if (sequence_type.has_elem_type()) {
76     dump(sequence_type.elem_type(), stream);
77   } else {
78     stream << "None";
79   }
80   stream << ">";
81 }
82 
dump(const onnx::TypeProto & type,std::ostream & stream)83 void dump(const onnx::TypeProto& type, std::ostream& stream) {
84   if (type.has_tensor_type()) {
85     dump(type.tensor_type(), stream);
86   } else if (type.has_sequence_type()) {
87     dump(type.sequence_type(), stream);
88   } else if (type.has_optional_type()) {
89     dump(type.optional_type(), stream);
90   } else {
91     stream << "None";
92   }
93 }
94 
dump(const onnx::ValueInfoProto & value_info,std::ostream & stream)95 void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
96   stream << "{name: \"" << value_info.name() << "\", type:";
97   dump(value_info.type(), stream);
98   stream << "}";
99 }
100 
101 void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);
102 
dump(const onnx::AttributeProto & attr,std::ostream & stream,size_t indent)103 void dump(
104     const onnx::AttributeProto& attr,
105     std::ostream& stream,
106     size_t indent) {
107   stream << "{ name: '" << attr.name() << "', type: ";
108   if (attr.has_f()) {
109     stream << "float, value: " << attr.f();
110   } else if (attr.has_i()) {
111     stream << "int, value: " << attr.i();
112   } else if (attr.has_s()) {
113     stream << "string, value: '" << attr.s() << "'";
114   } else if (attr.has_g()) {
115     stream << "graph, value:\n";
116     dump(attr.g(), stream, indent + 1);
117     stream << nlidt(indent);
118   } else if (attr.has_t()) {
119     stream << "tensor, value:";
120     dump(attr.t(), stream);
121   } else if (attr.floats_size()) {
122     stream << "floats, values: [";
123     for (const auto i : c10::irange(attr.floats_size())) {
124       stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
125     }
126     stream << "]";
127   } else if (attr.ints_size()) {
128     stream << "ints, values: [";
129     for (const auto i : c10::irange(attr.ints_size())) {
130       stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
131     }
132     stream << "]";
133   } else if (attr.strings_size()) {
134     stream << "strings, values: [";
135     for (const auto i : c10::irange(attr.strings_size())) {
136       stream << "'" << attr.strings(i) << "'"
137              << (i == attr.strings_size() - 1 ? "" : " ");
138     }
139     stream << "]";
140   } else if (attr.tensors_size()) {
141     stream << "tensors, values: [";
142     for (auto& t : attr.tensors()) {
143       dump(t, stream);
144     }
145     stream << "]";
146   } else if (attr.graphs_size()) {
147     stream << "graphs, values: [";
148     for (auto& g : attr.graphs()) {
149       dump(g, stream, indent + 1);
150     }
151     stream << "]";
152   } else {
153     stream << "UNKNOWN";
154   }
155   stream << "}";
156 }
157 
dump(const onnx::NodeProto & node,std::ostream & stream,size_t indent)158 void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
159   stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
160   for (const auto i : c10::irange(node.input_size())) {
161     stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
162   }
163   stream << "], outputs: [";
164   for (const auto i : c10::irange(node.output_size())) {
165     stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
166   }
167   stream << "], attributes: [";
168   for (const auto i : c10::irange(node.attribute_size())) {
169     dump(node.attribute(i), stream, indent + 1);
170     stream << (i == node.attribute_size() - 1 ? "" : ",");
171   }
172   stream << "]}";
173 }
174 
dump(const onnx::GraphProto & graph,std::ostream & stream,size_t indent)175 void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
176   stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
177          << graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
178   for (const auto i : c10::irange(graph.input_size())) {
179     dump(graph.input(i), stream);
180     stream << (i == graph.input_size() - 1 ? "" : ",");
181   }
182   stream << "]" << nlidt(indent + 1) << "outputs: [";
183   for (const auto i : c10::irange(graph.output_size())) {
184     dump(graph.output(i), stream);
185     stream << (i == graph.output_size() - 1 ? "" : ",");
186   }
187   stream << "]" << nlidt(indent + 1) << "value_infos: [";
188   for (const auto i : c10::irange(graph.value_info_size())) {
189     dump(graph.value_info(i), stream);
190     stream << (i == graph.value_info_size() - 1 ? "" : ",");
191   }
192   stream << "]" << nlidt(indent + 1) << "initializers: [";
193   for (const auto i : c10::irange(graph.initializer_size())) {
194     dump(graph.initializer(i), stream);
195     stream << (i == graph.initializer_size() - 1 ? "" : ",");
196   }
197   stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
198   for (const auto i : c10::irange(graph.node_size())) {
199     dump(graph.node(i), stream, indent + 2);
200     if (i != graph.node_size() - 1) {
201       stream << "," << nlidt(indent + 2);
202     }
203   }
204   stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
205 }
206 
dump(const onnx::OperatorSetIdProto & operator_set_id,std::ostream & stream)207 void dump(
208     const onnx::OperatorSetIdProto& operator_set_id,
209     std::ostream& stream) {
210   stream << "OperatorSetIdProto { domain: " << operator_set_id.domain()
211          << ", version: " << operator_set_id.version() << "}";
212 }
213 
dump(const onnx::ModelProto & model,std::ostream & stream,size_t indent)214 void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
215   stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
216          << "producer_name: \"" << model.producer_name() << "\""
217          << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
218          << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
219   if (model.has_graph()) {
220     stream << nlidt(indent + 1) << "graph:\n";
221     dump(model.graph(), stream, indent + 2);
222   }
223   if (model.opset_import_size()) {
224     stream << idt(indent + 1) << "opset_import: [";
225     for (auto& opset_imp : model.opset_import()) {
226       dump(opset_imp, stream);
227     }
228     stream << "],\n";
229   }
230   stream << idt(indent) << "}\n";
231 }
232 
233 } // namespace
234 
prettyPrint(const::ONNX_NAMESPACE::ModelProto & model)235 std::string prettyPrint(const ::ONNX_NAMESPACE::ModelProto& model) {
236   std::ostringstream ss;
237   dump(model, ss, 0);
238   return ss.str();
239 }
240 
241 } // namespace torch::jit
242