// torch/csrc/jit/runtime/static/ops.h
#pragma once

#include <ATen/Utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace at::native {
at::Tensor& reshape_copy_out(
    at::Tensor& out,
    const at::Tensor& self,
    const at::DimVector& proposed_shape,
    bool infer_size = true);
at::Tensor& to_copy_out(
    Tensor& out,
    const Tensor& self,
    bool non_blocking,
    bool copy_strides,
    std::optional<MemoryFormat> memory_format);
} // namespace at::native

namespace torch::jit {

using SROpFunctor = SROperator (*)(Node* n);
struct SROperatorFunctor {
  virtual SROperator Generate(Node*) {
    SROperator out;
    return out;
  }
  virtual ~SROperatorFunctor() = default;
};

TORCH_DECLARE_REGISTRY(SROperatorRegistry, SROperatorFunctor);

#define REGISTER_OPERATOR_FUNCTOR(name, id, ...)             \
  struct SROperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__;                      \
    SROperator Generate(Node* n) override {                  \
      return fn(n);                                          \
    }                                                        \
  };                                                         \
  C10_REGISTER_CLASS(SROperatorRegistry, name, SROperatorFunctor_##id);

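// An illustrative sketch (not part of this header) of how an out-variant op
// is typically registered from the ops implementation file: the lambda
// inspects the node and returns an SROperator that writes into the node's
// preallocated output. The op name and body below are examples only.
//
//   REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
//     return [](ProcessedNode* p_node) {
//       // read p_node->Input(i), allocate/reuse p_node->Output(0),
//       // and write the result into it
//     };
//   });
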
TORCH_DECLARE_REGISTRY(SRNativeOperatorRegistry, SROperatorFunctor);
#define REGISTER_NATIVE_OPERATOR_FUNCTOR(name, id, ...)            \
  struct SRNativeOperatorFunctor_##id : public SROperatorFunctor { \
    const SROpFunctor fn = __VA_ARGS__;                            \
    SROperator Generate(Node* n) override {                        \
      return fn(n);                                                \
    }                                                              \
  };                                                               \
  C10_REGISTER_CLASS(                                              \
      SRNativeOperatorRegistry, name, SRNativeOperatorFunctor_##id);

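// Illustrative sketch: native ops (e.g. prim ops that the JIT interpreter
// otherwise handles directly) register against SRNativeOperatorRegistry in
// the same way; the op below is an example only.
//
//   REGISTER_NATIVE_OPERATOR_FUNCTOR(
//       prim::TupleConstruct, prim_TupleConstruct, [](Node* n) -> SROperator {
//         return [](ProcessedNode* p_node) {
//           // build the tuple from the inputs and store it in Output(0)
//         };
//       });
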
// Helpers for allocating empty CPU tensors, typically used to create the
// outputs of out-variant ops; the create_empty_from overloads copy any
// unspecified properties (dtype, layout, device, memory format) from `t`.
inline at::Tensor create_empty_from(const at::Tensor& t) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(
    at::IntArrayRef sizes,
    const at::Tensor& t) {
  return at::detail::empty_cpu(
      sizes,
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty(c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0}, dtype, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype) {
  return at::detail::empty_cpu(
      {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt);
}

inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      layout,
      t.device(),
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      device,
      std::nullopt,
      std::nullopt);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::MemoryFormat memory_format) {
  return at::detail::empty_cpu(
      {0},
      c10::typeMetaToScalarType(t.dtype()),
      t.layout(),
      t.device(),
      std::nullopt,
      memory_format);
}

inline at::Tensor create_empty_from(
    const at::Tensor& t,
    c10::ScalarType dtype,
    c10::MemoryFormat memory_format) {
  return at::detail::empty_cpu(
      {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format);
}

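// Illustrative sketch: inside an SROperator, the output tensor is typically
// allocated lazily with create_empty_from on the first run and reused
// afterwards. `p_node` is the ProcessedNode*; the _out kernel call is elided.
//
//   const auto& self = p_node->Input(0).toTensor();
//   if (p_node->Output(0).isNone()) {
//     p_node->Output(0) = create_empty_from(self);
//   }
//   auto& out = p_node->Output(0).toTensor();
//   // ... call the corresponding _out kernel to fill `out` ...
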
inline bool checkResizedDataPtr(at::Tensor& t) {
  auto const prev_data_ptr = t.data_ptr();
  t.resize_({0});
  return prev_data_ptr == t.data_ptr();
}

inline void fastResizeToZero(at::Tensor& t) {
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(checkResizedDataPtr(t));
}
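
// Illustrative sketch: when an output tensor is reused across iterations,
// out variants commonly call fastResizeToZero on it before invoking the
// _out kernel, so that the later resize to the result shape does not need
// to preserve any old contents.
//
//   auto& out = p_node->Output(0).toTensor();
//   fastResizeToZero(out);
//   // ... run the _out kernel, which resizes `out` to the proper shape ...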

// Check whether an op has an out variant registered in Static Runtime.
bool opIsRegistered(const c10::Symbol& op_name);
// Check whether Static Runtime can run an op natively.
// Prim ops that are implemented directly in the JIT interpreter are
// implemented as native ops in Static Runtime.
bool nativeOpIsRegistered(const c10::Symbol& op_name);

bool canReuseInputsOutputs(
    Node* n,
    const c10::FastMap<Node*, bool>& node_has_out_variant);
bool isOptimizableContainerType(
    Node* n,
    const c10::FastMap<Node*, bool>& node_has_out_variant);

SROperator getOutOfPlaceOperation(Node* n);
SROperator getNativeOperation(Node* n);

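// Illustrative sketch: when the runtime resolves a node, it roughly prefers
// a registered out variant, then a native implementation, and otherwise
// falls back to the regular JIT operator.
//
//   SROperator fn = getOutOfPlaceOperation(node);
//   if (!fn) {
//     fn = getNativeOperation(node);
//   }
//   // if fn is still empty, the node runs through the JIT fallback path
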
bool hasVarArgs(Node* n);

inline std::string PrintNode(const Node* node) {
  std::ostringstream ss;
  node->print(ss, 0, nullptr, false);
  return ss.str();
}

inline void LogAndDumpSchema(const Node* node) {
  VLOG(1) << "Found schema mismatch for: " << node->schema();
}

inline bool sr_schema_check(torch::jit::Node*) {
  return true;
}

template <typename Schema, typename... Schemas>
bool sr_schema_check(
    torch::jit::Node* node,
    Schema&& first,
    Schemas&&... rest) {
  auto is_match = node->matches(first) || sr_schema_check(node, rest...);
  if (!is_match) {
    torch::jit::LogAndDumpSchema(node);
  }
  return is_match;
}
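
// Illustrative sketch: registration lambdas typically guard op construction
// with sr_schema_check so that nodes with unexpected schemas are rejected
// and fall back to the JIT operator; the schema string is an example only.
//
//   if (!sr_schema_check(n, "aten::abs(Tensor self) -> Tensor")) {
//     return nullptr;
//   }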

bool sr_schema_check_kind(torch::jit::Node* node, c10::Symbol node_kind);

} // namespace torch::jit