xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/decompose_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/decompose_ops.h>
2 
3 #include <torch/csrc/jit/frontend/ir_emitter.h>
4 #include <torch/csrc/jit/passes/constant_propagation.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/shape_analysis.h>
7 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
8 #include <torch/csrc/jit/runtime/custom_operator.h>
9 #include <torch/csrc/jit/runtime/operator.h>
10 
11 #include <ATen/core/symbol.h>
12 
13 namespace torch::jit {
14 
15 namespace {
aliasAnalysisFromSchema()16 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
17   return c10::AliasAnalysisKind::FROM_SCHEMA;
18 }
19 } // namespace
20 
21 // helper to determine if an optional tensor argument/value passed in is
22 // statically defined (neither a None constant nor a Optional[Tensor] type)
23 // return yes, no, or no value if we can't tell
isDefined(Value * tensor)24 static std::optional<bool> isDefined(Value* tensor) {
25   if (tensor->type()->isSubtypeOf(*TensorType::get())) {
26     return true;
27   }
28   if (tensor->node()->mustBeNone()) {
29     return false;
30   }
31   return {};
32 }
33 
isDecomposableNorm(Node * normalize_op)34 static bool isDecomposableNorm(Node* normalize_op) {
35   static const OperatorSet decomposable_normalization_ops = {
36       "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
37       "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
38   };
39   Value* input = normalize_op->namedInput(attr::input);
40   if (!input->type()->isSubtypeOf(*TensorType::get())) {
41     return false;
42   }
43   auto device = input->type()->expectRef<TensorType>().device();
44   // As of now, we do the decomposition for batchnorm/layernorm on GPU device
45   // only
46   if (!device || !(*device).is_cuda()) {
47     return false;
48   }
49 
50   if (normalize_op->isMemberOf(decomposable_normalization_ops)) {
51     // If we can't determine if weight and bias is defined statically there's
52     // really no point in decomposing normalization into simpler ops, since it
53     // won't get fused into a single kernel.
54     return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
55         isDefined(normalize_op->namedInput(attr::bias)).has_value();
56   }
57   return false;
58 }
59 
60 RegisterOperators reg_ops(
61     {Operator(
62          "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)",
__anon90d7ec1d0202(Stack& stack) 63          [](Stack& stack) {
64            const int64_t ndim = pop(stack).toInt();
65            auto self = pop(stack).toTensor();
66            c10::SmallVector<int64_t, 8> sizes(ndim, 1);
67            AT_ASSERT(self.dim() == 1);
68            sizes.at(1) = self.size(0);
69            push(stack, self.reshape(sizes));
70          },
71          aliasAnalysisFromSchema()),
72      Operator(
73          "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)",
__anon90d7ec1d0302(Stack& stack) 74          [](Stack& stack) {
75            const int64_t normalized_ndim = pop(stack).toInt();
76            auto input_shape = pop(stack).toIntList();
77            auto self = pop(stack).toTensor();
78            const int64_t input_ndim = input_shape.size();
79            c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
80            for (int i = 0; i < input_ndim - normalized_ndim; ++i) {
81              sizes.at(i) = input_shape.get(i);
82            }
83            push(stack, self.reshape(sizes));
84          },
85          aliasAnalysisFromSchema())});
86 
DecomposeOps(Block * block,CompilationUnit & decompose_funcs)87 static bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
88   bool decomposed = false;
89   for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
90        ++it) {
91     for (auto sub : it->blocks()) {
92       DecomposeOps(sub, decompose_funcs);
93     }
94 
95     if (it->matches(
96             "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
97             /*const_inputs=*/{attr::beta, attr::alpha})) {
98       // For the case where we have an addmm where alpha and beta are Attributes
99       // and both of those scalars are equal to 1.0, decompose this into an mm
100       // followed by an add so that it can go through the existing optimization
101       // (batchmm)
102       if (it->get<at::Scalar>(attr::alpha)->toComplexDouble() != 1.0 ||
103           it->get<at::Scalar>(attr::beta)->toComplexDouble() != 1.0) {
104         continue;
105       }
106 
107       decomposed = true;
108       WithInsertPoint guard(*it);
109       std::shared_ptr<Graph> d_graph =
110           toGraphFunction(decompose_funcs.get_function("addmm")).graph();
111       Value* new_output =
112           insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
113       // Set the output of the decomposed graph to have the same output type as
114       // the original op otherwise the canonicalized graph will have TensorType
115       // as the output of this node which is incorrect
116       new_output->setType(it->output()->type());
117       it->output()->replaceAllUsesWith(new_output);
118       it.destroyCurrent();
119     } else if (
120         it->matches(
121             "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
122       if (!isDecomposableNorm(*it)) {
123         continue;
124       }
125       decomposed = true;
126       WithInsertPoint insert_guard{*it};
127       Graph* graph = it->owningGraph();
128       Value* input = it->namedInput(attr::input);
129       Value* input_dim = graph->insert(aten::dim, {input});
130       std::vector<Value*> inputs{
131           input,
132           it->namedInput(attr::running_mean),
133           it->namedInput(attr::running_var),
134           it->namedInput(attr::training),
135           it->namedInput(attr::momentum),
136           it->namedInput(attr::eps)};
137 
138       // inline the compiled decomposed batchnorm
139       std::shared_ptr<Graph> d_graph =
140           toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
141       Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
142 
143       // post processing the graph
144       Value* weight = it->namedInput(attr::weight);
145       Value* bias = it->namedInput(attr::bias);
146       if (isDefined(weight).value()) {
147         Value* expanded_weight =
148             graph->insert(aten::_ncf_unsqueeze, {weight, input_dim});
149         new_output = graph->insert(aten::mul, {new_output, expanded_weight});
150       }
151       if (isDefined(bias).value()) {
152         Value* expanded_bias =
153             graph->insert(aten::_ncf_unsqueeze, {bias, input_dim});
154         new_output = graph->insert(aten::add, {new_output, expanded_bias});
155       }
156       it->output()->replaceAllUsesWith(new_output);
157       it.destroyCurrent();
158     } else if (
159         it->matches(
160             "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor")) {
161       if (!isDecomposableNorm(*it)) {
162         continue;
163       }
164       decomposed = true;
165       WithInsertPoint insert_guard{*it};
166       Graph* graph = it->owningGraph();
167       std::vector<Value*> inputs{
168           it->namedInput(attr::input),
169           it->namedInput(attr::normalized_shape),
170           it->namedInput(attr::eps),
171           it->namedInput(attr::cudnn_enable)};
172 
173       // inline the compiled decomposed layernorm
174       std::shared_ptr<Graph> d_graph =
175           toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
176       Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);
177 
178       // post processing the graph
179       Value* weight = it->namedInput(attr::weight);
180       Value* bias = it->namedInput(attr::bias);
181       if (isDefined(weight).value()) {
182         new_output = graph->insert(aten::mul, {new_output, weight});
183       }
184       if (isDefined(bias).value()) {
185         new_output = graph->insert(aten::add, {new_output, bias});
186       }
187       it->output()->replaceAllUsesWith(new_output);
188       it.destroyCurrent();
189     }
190   }
191   return decomposed;
192 }
193 
DecomposeOps(std::shared_ptr<Graph> & graph)194 void DecomposeOps(std::shared_ptr<Graph>& graph) {
195   static CompilationUnit decompose_funcs(R"SCRIPT(
196       def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
197           return self + mat1.mm(mat2)
198 
199       def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
200           if training:
201               norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
202           else:
203               norm_mean = torch._unwrap_optional(running_mean)
204               norm_var = torch._unwrap_optional(running_var)
205           norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
206           norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
207           norm_invstd = 1 / (torch.sqrt(norm_var + eps))
208           return ((input - norm_mean) * norm_invstd)
209 
210       def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
211           input_ndim = input.dim()
212           normalized_ndim = len(normalized_shape)
213           n = 1
214           for i in range(input_ndim - normalized_ndim):
215               n *= input.size(i)
216           input_reshape = input.contiguous().view(1, n, -1)
217           mean, invstd = torch.batch_norm_stats(input_reshape, eps)
218           input_shape = input.size()
219           mean = torch._ncf_view(mean, input_shape, normalized_ndim)
220           invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
221 
222           return (input - mean) * invstd
223       )SCRIPT");
224   bool is_decomposed = DecomposeOps(graph->block(), decompose_funcs);
225   if (is_decomposed) {
226     // we only re-run those passes when the graph get decomposed
227     PropagateInputShapes(graph);
228     ConstantPropagation(graph);
229     EliminateDeadCode(graph);
230   }
231 }
232 
233 } // namespace torch::jit
234