1 #pragma once
#include <torch/csrc/onnx/diagnostics/generated/rules.h>
#include <torch/csrc/utils/pybind.h>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
5
6 namespace torch::onnx::diagnostics {
7
/**
 * @brief Severity of a diagnostic.
 * @details These levels mirror the ones defined by the SARIF specification and
 * are fixed; they are not meant to be extended. Other kinds of categorization
 * are intended to go through Tag instead (see @todo).
 * @todo Introduce Tag to C++ api.
 * @note The enumerator values are used directly as indices into kPyLevelNames,
 * so they must stay dense and start at zero.
 */
enum class Level : uint8_t {
  kNone = 0,
  kNote = 1,
  kWarning = 2,
  kError = 3,
};
20
/// @brief Python attribute names of the SARIF levels, indexed by the
/// underlying value of Level (entry i names the enumerator with value i).
/// @details Declared `inline constexpr` rather than `static constexpr` so that
/// every translation unit including this header shares a single definition
/// instead of each getting its own internal-linkage copy.
inline constexpr const char* kPyLevelNames[] = {
    "NONE",
    "NOTE",
    "WARNING",
    "ERROR",
};
27
// Thin C++ wrappers that forward to the Python module
// `torch.onnx._internal.diagnostics`.
// TODO: Move these definitions out of the header into a .cpp file.
30
_PyDiagnostics()31 inline py::object _PyDiagnostics() {
32 return py::module::import("torch.onnx._internal.diagnostics");
33 }
34
_PyRule(Rule rule)35 inline py::object _PyRule(Rule rule) {
36 return _PyDiagnostics().attr("rules").attr(
37 kPyRuleNames[static_cast<uint32_t>(rule)]);
38 }
39
_PyLevel(Level level)40 inline py::object _PyLevel(Level level) {
41 return _PyDiagnostics().attr("levels").attr(
42 kPyLevelNames[static_cast<uint32_t>(level)]);
43 }
44
45 inline void Diagnose(
46 Rule rule,
47 Level level,
48 std::unordered_map<std::string, std::string> messageArgs = {}) {
49 py::object py_rule = _PyRule(rule);
50 py::object py_level = _PyLevel(level);
51
52 // TODO: statically check that size of messageArgs matches with rule.
53 py::object py_message =
54 py_rule.attr("format_message")(**py::cast(messageArgs));
55
56 // to use the `_a` literal for arguments
57 using namespace pybind11::literals;
58 _PyDiagnostics().attr("diagnose")(
59 py_rule, py_level, py_message, "cpp_stack"_a = true);
60 }
61
62 } // namespace torch::onnx::diagnostics
63