1 #include <torch/csrc/jit/passes/decompose_ops.h>
2
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 #include <torch/csrc/jit/passes/constant_propagation.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/shape_analysis.h>
7 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
8 #include <torch/csrc/jit/runtime/custom_operator.h>
9 #include <torch/csrc/jit/runtime/operator.h>
10
11 #include <ATen/core/symbol.h>
12
13 namespace torch::jit {
14
15 namespace {
aliasAnalysisFromSchema()16 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
17 return c10::AliasAnalysisKind::FROM_SCHEMA;
18 }
19 } // namespace
20
21 // helper to determine if an optional tensor argument/value passed in is
22 // statically defined (neither a None constant nor a Optional[Tensor] type)
23 // return yes, no, or no value if we can't tell
isDefined(Value * tensor)24 static std::optional<bool> isDefined(Value* tensor) {
25 if (tensor->type()->isSubtypeOf(*TensorType::get())) {
26 return true;
27 }
28 if (tensor->node()->mustBeNone()) {
29 return false;
30 }
31 return {};
32 }
33
isDecomposableNorm(Node * normalize_op)34 static bool isDecomposableNorm(Node* normalize_op) {
35 static const OperatorSet decomposable_normalization_ops = {
36 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
37 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
38 };
39 Value* input = normalize_op->namedInput(attr::input);
40 if (!input->type()->isSubtypeOf(*TensorType::get())) {
41 return false;
42 }
43 auto device = input->type()->expectRef<TensorType>().device();
44 // As of now, we do the decomposition for batchnorm/layernorm on GPU device
45 // only
46 if (!device || !(*device).is_cuda()) {
47 return false;
48 }
49
50 if (normalize_op->isMemberOf(decomposable_normalization_ops)) {
51 // If we can't determine if weight and bias is defined statically there's
52 // really no point in decomposing normalization into simpler ops, since it
53 // won't get fused into a single kernel.
54 return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
55 isDefined(normalize_op->namedInput(attr::bias)).has_value();
56 }
57 return false;
58 }
59
60 RegisterOperators reg_ops(
61 {Operator(
62 "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)",
__anon90d7ec1d0202(Stack& stack) 63 [](Stack& stack) {
64 const int64_t ndim = pop(stack).toInt();
65 auto self = pop(stack).toTensor();
66 c10::SmallVector<int64_t, 8> sizes(ndim, 1);
67 AT_ASSERT(self.dim() == 1);
68 sizes.at(1) = self.size(0);
69 push(stack, self.reshape(sizes));
70 },
71 aliasAnalysisFromSchema()),
72 Operator(
73 "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)",
__anon90d7ec1d0302(Stack& stack) 74 [](Stack& stack) {
75 const int64_t normalized_ndim = pop(stack).toInt();
76 auto input_shape = pop(stack).toIntList();
77 auto self = pop(stack).toTensor();
78 const int64_t input_ndim = input_shape.size();
79 c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
80 for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
81 sizes.at(i) = input_shape.get(i);
82 }
83 push(stack, self.reshape(sizes));
84 },
85 aliasAnalysisFromSchema())});
86
DecomposeOps(Block * block,CompilationUnit & decompose_funcs)87 static bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
88 bool decomposed = false;
89 for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
90 ++it) {
91 for (auto sub : it->blocks()) {
92 DecomposeOps(sub, decompose_funcs);
93 }
94
95 if (it->matches(
96 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
97 /*const_inputs=*/{attr::beta, attr::alpha})) {
98 // For the case where we have an addmm where alpha and beta are Attributes
99 // and both of those scalars are equal to 1.0, decompose this into an mm
100 // followed by an add so that it can go through the existing optimization
101 // (batchmm)
102 if (it->get<at::Scalar>(attr::alpha)->toComplexDouble() != 1.0 ||
103 it->get<at::Scalar>(attr::beta)->toComplexDouble() != 1.0) {
104 continue;
105 }
106
107 decomposed = true;
108 WithInsertPoint guard(*it);
109 std::shared_ptr<Graph> d_graph =
110 toGraphFunction(decompose_funcs.get_function("addmm")).graph();
111 Value* new_output =
112 insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
113 // Set the output of the decomposed graph to have the same output type as
114 // the original op otherwise the canonicalized graph will have TensorType
115 // as the output of this node which is incorrect
116 new_output->setType(it->output()->type());
117 it->output()->replaceAllUsesWith(new_output);
118 it.destroyCurrent();
119 } else if (
120 it->matches(
121 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
122 if (!isDecomposableNorm(*it)) {
123 continue;
124 }
125 decomposed = true;
126 WithInsertPoint insert_guard{*it};
127 Graph* graph = it->owningGraph();
128 Value* input = it->namedInput(attr::input);
129 Value* input_dim = graph->insert(aten::dim, {input});
130 std::vector<Value*> inputs{
131 input,
132 it->namedInput(attr::running_mean),
133 it->namedInput(attr::running_var),
134 it->namedInput(attr::training),
135 it->namedInput(attr::momentum),
136 it->namedInput(attr::eps)};
137
138 // inline the compiled decomposed batchnorm
139 std::shared_ptr<Graph> d_graph =
140 toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
141 Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
142
143 // post processing the graph
144 Value* weight = it->namedInput(attr::weight);
145 Value* bias = it->namedInput(attr::bias);
146 if (isDefined(weight).value()) {
147 Value* expanded_weight =
148 graph->insert(aten::_ncf_unsqueeze, {weight, input_dim});
149 new_output = graph->insert(aten::mul, {new_output, expanded_weight});
150 }
151 if (isDefined(bias).value()) {
152 Value* expanded_bias =
153 graph->insert(aten::_ncf_unsqueeze, {bias, input_dim});
154 new_output = graph->insert(aten::add, {new_output, expanded_bias});
155 }
156 it->output()->replaceAllUsesWith(new_output);
157 it.destroyCurrent();
158 } else if (
159 it->matches(
160 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor")) {
161 if (!isDecomposableNorm(*it)) {
162 continue;
163 }
164 decomposed = true;
165 WithInsertPoint insert_guard{*it};
166 Graph* graph = it->owningGraph();
167 std::vector<Value*> inputs{
168 it->namedInput(attr::input),
169 it->namedInput(attr::normalized_shape),
170 it->namedInput(attr::eps),
171 it->namedInput(attr::cudnn_enable)};
172
173 // inline the compiled decomposed layernorm
174 std::shared_ptr<Graph> d_graph =
175 toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
176 Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
177
178 // post processing the graph
179 Value* weight = it->namedInput(attr::weight);
180 Value* bias = it->namedInput(attr::bias);
181 if (isDefined(weight).value()) {
182 new_output = graph->insert(aten::mul, {new_output, weight});
183 }
184 if (isDefined(bias).value()) {
185 new_output = graph->insert(aten::add, {new_output, bias});
186 }
187 it->output()->replaceAllUsesWith(new_output);
188 it.destroyCurrent();
189 }
190 }
191 return decomposed;
192 }
193
DecomposeOps(std::shared_ptr<Graph> & graph)194 void DecomposeOps(std::shared_ptr<Graph>& graph) {
195 static CompilationUnit decompose_funcs(R"SCRIPT(
196 def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
197 return self + mat1.mm(mat2)
198
199 def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
200 if training:
201 norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
202 else:
203 norm_mean = torch._unwrap_optional(running_mean)
204 norm_var = torch._unwrap_optional(running_var)
205 norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
206 norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
207 norm_invstd = 1 / (torch.sqrt(norm_var + eps))
208 return ((input - norm_mean) * norm_invstd)
209
210 def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
211 input_ndim = input.dim()
212 normalized_ndim = len(normalized_shape)
213 n = 1
214 for i in range(input_ndim - normalized_ndim):
215 n *= input.size(i)
216 input_reshape = input.contiguous().view(1, n, -1)
217 mean, invstd = torch.batch_norm_stats(input_reshape, eps)
218 input_shape = input.size()
219 mean = torch._ncf_view(mean, input_shape, normalized_ndim)
220 invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
221
222 return (input - mean) * invstd
223 )SCRIPT");
224 bool is_decomposed = DecomposeOps(graph->block(), decompose_funcs);
225 if (is_decomposed) {
226 // we only re-run those passes when the graph get decomposed
227 PropagateInputShapes(graph);
228 ConstantPropagation(graph);
229 EliminateDeadCode(graph);
230 }
231 }
232
233 } // namespace torch::jit
234