1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <memory> 6 7 namespace torch { 8 namespace jit { 9 // Moved from shape_analysis.cpp 10 11 // Requirements: 12 // dims : preserved from the first argument 13 // scalar type : preserved from the first argument (doesn't have to 14 // match other arguments) 15 // device : always matching and preserved 16 // tensor inputs : * 17 // tensor outputs : 1 18 // NB: those ops (with slight adjustments) are good candidates for restarts. 19 // Knowing the type and device of weights or biases is usually enough to 20 // infer the output type. 21 std::shared_ptr<OperatorSet> nn_ops_first_input_preserving(); 22 23 // Requirements: 24 // dims : Changed from first argument 25 // scalar type : preserved from the first argument 26 // device : always matching and preserved 27 // tensor inputs : 1 28 // tensor outputs : 1 29 std::shared_ptr<OperatorSet> ops_one_tensor_in_shape_transform(); 30 } // namespace jit 31 } // namespace torch 32