#pragma once

#include <ATen/Utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace at::native {
at::Tensor& reshape_copy_out(
    at::Tensor& out,
    const at::Tensor& self,
    const at::DimVector& proposed_shape,
    bool infer_size = true);
at::Tensor& to_copy_out(
    Tensor& out,
    const Tensor& self,
    bool non_blocking,
    bool copy_strides,
    std::optional<MemoryFormat> memory_format);
} // namespace at::native
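
// Usage sketch (illustrative, not taken from this file): these out variants
// are typically called on a preallocated `out` tensor from inside a Static
// Runtime op implementation, e.g.
//   at::native::to_copy_out(
//       out, self, /*non_blocking=*/false, /*copy_strides=*/false,
//       std::nullopt);
//   at::native::reshape_copy_out(
//       out, self, at::DimVector{2, 3}, /*infer_size=*/true);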

namespace torch::jit {

using SROpFunctor = SROperator (*)(Node* n);
struct SROperatorFunctor {
  virtual SROperator Generate(Node*) {
    SROperator out;
    return out;
  }
  virtual ~SROperatorFunctor() = default;
};

TORCH_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);

#define REGISTER_OPERATOR_FUNCTOR(name, id, ...)             \
  struct SROperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__;                      \
    SROperator Generate(Node* n) override {                  \
      return fn(n);                                          \
    }                                                        \
  };                                                         \
  C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);
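
// Registration sketch (illustrative skeleton, not a verbatim kernel;
// `aten::foo` is a placeholder): an out-variant functor maps a Node to an
// SROperator that reads inputs from and writes outputs to the ProcessedNode,
// e.g.
//   REGISTER_OPERATOR_FUNCTOR(aten::foo, aten_foo, [](Node* n) -> SROperator {
//     if (!sr_schema_check(n, "aten::foo(Tensor self) -> Tensor")) {
//       return nullptr;
//     }
//     return [](ProcessedNode* p_node) {
//       const auto& self = p_node->Input(0).toTensor();
//       if (p_node->Output(0).isNone()) {
//         p_node->Output(0) = create_empty_from(self);
//       }
//       auto& out = p_node->Output(0).toTensor();
//       fastResizeToZero(out);
//       // ... call the op's out variant on `out` ...
//     };
//   });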

TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...)            \
  struct SRNativeOperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__;                            \
    SROperator Generate(Node* n) override {                        \
      return fn(n);                                                \
    }                                                              \
  };                                                               \
  C10_REGISTER_CLASS(                                              \
      SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);
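
// Sketch (illustrative; `prim::Foo` is a placeholder symbol): native functors
// are registered the same way but implement the op directly instead of
// calling an out variant, e.g.
//   REGISTER_NATIVE_OPERATOR_FUNCTOR(
//       prim::Foo, prim_Foo, [](Node* n) -> SROperator {
//         if (!sr_schema_check_kind(n, prim::Foo)) {
//           return nullptr;
//         }
//         return [](ProcessedNode* p_node) {
//           p_node->Output(0) = p_node->Input(0); // e.g. forward the input
//         };
//       });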

// Helpers for allocating an empty CPU tensor to use as an op output; most
// overloads take dtype/layout/device (and optionally memory format) from an
// existing tensor unless explicitly overridden.
inline at::Tensor create_empty_from(const at::Tensor& t) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(
    at::IntArrayRef sizes,
    const at::Tensor& t) {
  return at::detail::empty_cpu(
      sizes,
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty(c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0}, dtype, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
}

inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      layout,
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      device,
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::MemoryFormat memory_format) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      memory_format);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype,
    c10::MemoryFormat memory_format) {
  return at::detail::empty_cpu(
      {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format);
}

// Returns true if resizing to zero leaves the tensor's data pointer unchanged
// (i.e. no reallocation happens); used to validate fastResizeToZero in debug
// builds.
inline bool checkResizedDataPtr(at::Tensor& t) {
  auto const prev_data_ptr = t.data_ptr();
  t.resize_({0});
  return prev_data_ptr == t.data_ptr();
}

// Resets the tensor's sizes to {0} without going through the resize_
// dispatch; the underlying storage is kept so a later *_out call can reuse it.
inline void fastResizeToZero(at::Tensor& t) {
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}
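
// Typical out-variant pattern (illustrative sketch, not code from this file):
// the first invocation allocates the output with create_empty_from, and later
// invocations clear it with fastResizeToZero so the out-variant kernel can
// reuse the existing storage, e.g.
//   if (p_node->Output(0).isNone()) {
//     p_node->Output(0) = create_empty_from(self);
//   }
//   auto& out = p_node->Output(0).toTensor();
//   fastResizeToZero(out);
//   some_op_out(out, self); // hypothetical out-variant call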

// Check if an op has an out variant registered in Static Runtime.
bool opIsRegistered(const c10::Symbol& op_name);
// Check if Static Runtime can run an op natively.
// Prim ops that are implemented directly in the JIT interpreter are
// implemented as native ops in Static Runtime.
bool nativeOpIsRegistered(const c10::Symbol& op_name);

bool canReuseInputsOutputs(
    Node* n,
    const c10::FastMap<Node*, bool>& node_has_out_variant);
bool isOptimizableContainerType(
    Node* n,
    const c10::FastMap<Node*, bool>& node_has_out_variant);

SROperator getOutOfPlaceOperation(Node* n);
SROperator getNativeOperation(Node* n);
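
// Lookup sketch (illustrative): for each graph node, Static Runtime first
// looks for a registered out variant and then for a native implementation,
// falling back to the JIT interpreter implementation if neither exists, e.g.
//   SROperator op = getOutOfPlaceOperation(n);
//   if (!op) {
//     op = getNativeOperation(n);
//   }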

bool hasVarArgs(Node* n);

inline std::string PrintNode(const Node* node) {
  std::ostringstream ss;
  node->print(ss, 0, nullptr, false);
  return ss.str();
}

inline void LogAndDumpSchema(const Node* node) {
  VLOG(1) << "Found schema mismatch for: " << node->schema();
}

inline bool sr_schema_check(torch::jit::Node*) {
  return true;
}

template <typename Schema, typename... Schemas>
bool sr_schema_check(
    torch::jit::Node* node,
    Schema&& first,
    Schemas&&... rest) {
  auto is_match = node->matches(first) || sr_schema_check(node, rest...);
  if (!is_match) {
    torch::jit::LogAndDumpSchema(node);
  }
  return is_match;
}
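
// Usage sketch (illustrative; the schema strings are placeholders): op
// functors typically guard registration with sr_schema_check and bail out on
// unexpected schemas, e.g.
//   if (!sr_schema_check(
//           n,
//           "aten::foo(Tensor self) -> Tensor",
//           "aten::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")) {
//     return nullptr;
//   }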

// Check that the node's kind matches `node_kind`; log the node's schema if it
// does not.
bool sr_schema_check_kind(torch::jit::Node* node, c10::Symbol node_kind);

} // namespace torch::jit