xref: /aosp_15_r20/external/pytorch/torch/csrc/onnx/diagnostics/diagnostics.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/onnx/diagnostics/generated/rules.h>
3 #include <torch/csrc/utils/pybind.h>
4 #include <string>
5 
6 namespace torch::onnx::diagnostics {
7 
/**
 * @brief Level of a diagnostic.
 * @details The levels are defined by the SARIF specification, and are not
 * modifiable. For alternative categories, please use Tag instead.
 * @todo Introduce Tag to C++ api.
 */
enum class Level : uint8_t {
  kNone,
  kNote,
  kWarning,
  kError,
};

/// Python-side attribute names for each Level, indexed by the enumerator's
/// underlying value (e.g. `kPyLevelNames[static_cast<uint32_t>(Level::kError)]`
/// is "ERROR"). The entries name attributes of the Python `levels` object, so
/// their order must match the declaration order of Level.
///
/// `inline constexpr` (C++17) gives a single definition shared by every
/// translation unit that includes this header, instead of one copy per TU.
inline constexpr const char* const kPyLevelNames[] = {
    "NONE",
    "NOTE",
    "WARNING",
    "ERROR",
};

// The table is indexed without a runtime bounds check, so keep it in lockstep
// with the enum at compile time. Assumes kError stays the last enumerator.
static_assert(
    sizeof(kPyLevelNames) / sizeof(kPyLevelNames[0]) ==
        static_cast<uint32_t>(Level::kError) + 1u,
    "kPyLevelNames must have exactly one entry per Level enumerator");
27 
28 // Wrappers around Python diagnostics.
29 // TODO: Move to .cpp file in following PR.
30 
_PyDiagnostics()31 inline py::object _PyDiagnostics() {
32   return py::module::import("torch.onnx._internal.diagnostics");
33 }
34 
_PyRule(Rule rule)35 inline py::object _PyRule(Rule rule) {
36   return _PyDiagnostics().attr("rules").attr(
37       kPyRuleNames[static_cast<uint32_t>(rule)]);
38 }
39 
_PyLevel(Level level)40 inline py::object _PyLevel(Level level) {
41   return _PyDiagnostics().attr("levels").attr(
42       kPyLevelNames[static_cast<uint32_t>(level)]);
43 }
44 
45 inline void Diagnose(
46     Rule rule,
47     Level level,
48     std::unordered_map<std::string, std::string> messageArgs = {}) {
49   py::object py_rule = _PyRule(rule);
50   py::object py_level = _PyLevel(level);
51 
52   // TODO: statically check that size of messageArgs matches with rule.
53   py::object py_message =
54       py_rule.attr("format_message")(**py::cast(messageArgs));
55 
56   // to use the `_a` literal for arguments
57   using namespace pybind11::literals;
58   _PyDiagnostics().attr("diagnose")(
59       py_rule, py_level, py_message, "cpp_stack"_a = true);
60 }
61 
62 } // namespace torch::onnx::diagnostics
63