xref: /aosp_15_r20/external/pytorch/torch/csrc/onnx/back_compat.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <onnx/onnx_pb.h>
4 
5 namespace torch::onnx {
6 
7 // The following constants are defined here to avoid breaking Meta's internal
8 // usage of ONNX which pre-dates ONNX 1.14 and thus does not support FLOAT8:
9 // cf. https://github.com/pytorch/pytorch/pull/106379#issuecomment-1675189340
10 // -abock, 2023-08-25
11 //
12 // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN
13 constexpr auto TensorProto_DataType_FLOAT8E4M3FN =
14     static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(17);
15 // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ
16 constexpr auto TensorProto_DataType_FLOAT8E4M3FNUZ =
17     static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(18);
18 // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2
19 constexpr auto TensorProto_DataType_FLOAT8E5M2 =
20     static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(19);
21 // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ
22 constexpr auto TensorProto_DataType_FLOAT8E5M2FNUZ =
23     static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(20);
24 
25 } // namespace torch::onnx
26