1 #pragma once 2 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 6 namespace torch::jit { 7 8 // Utility functions for PyTorch to ONNX conversion. 9 10 static const int OPSET_VERSION_1 = 1; 11 static const int OPSET_VERSION_9 = 9; 12 static const int OPSET_VERSION_10 = 10; 13 static const int OPSET_VERSION_11 = 11; 14 static const int OPSET_VERSION_12 = 12; 15 static const int OPSET_VERSION_13 = 13; 16 static const int OPSET_VERSION_14 = 14; 17 static const int OPSET_VERSION_15 = 15; 18 static const int OPSET_VERSION_16 = 16; 19 20 using ValueToParamPairMap = std::map<Value*, std::pair<std::string, IValue>>; 21 22 using ParamMap = std::map<std::string, IValue>; 23 24 TORCH_API void buildParamsMapFromValueToParamsMap( 25 const ValueToParamPairMap& valsToParamsMap, 26 ParamMap& paramsDict); 27 TORCH_API ValueToParamPairMap 28 buildValueToParamsMap(Block* b, const ParamMap& paramsDict); 29 TORCH_API void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap); 30 TORCH_API void eraseUnusedBlockInputs(Block* b); 31 TORCH_API void buildParamsMapFromValueToParamsMap( 32 const ValueToParamPairMap& valsToParamsMap, 33 ParamMap& paramsDict); 34 35 TORCH_API Node* addNodeToBlock( 36 Block* block, 37 Symbol kind, 38 ArrayRef<Value*> inputs); 39 40 TORCH_API Value* addInputToBlock(Block* block); 41 42 TORCH_API std::optional<at::ScalarType> ONNXTypeToATenType(int32_t onnx_type); 43 44 // Use int return type as no sable way exists to forward declare protobuf enum 45 TORCH_API int ATenTypeToOnnxType(at::ScalarType at_type); 46 47 TORCH_API void ONNXLintGraph(const std::shared_ptr<Graph>& graph); 48 49 Node* createONNXUnsqueeze( 50 Graph* graph, 51 Node* n_to_insert_before, 52 Value* input, 53 int axis, 54 int opset_version); 55 Node* createONNXConstant( 56 Graph* graph, 57 Node* n_to_insert_before, 58 at::Tensor value); 59 60 bool isValidToTransformToONNXConcatNode(Node* lc_node); 61 62 Node* transformToONNXConcatNode( 63 Graph* graph, 64 Node* lc_node, 65 bool need_new_input, 66 int opset_version); 67 68 class ScalarTypeHashFunction { 69 public: operator()70 size_t operator()(const c10::ScalarType& type) const { 71 return static_cast<size_t>(type); 72 } 73 }; 74 75 } // namespace torch::jit 76