#include <c10/util/irange.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/erase_number_types.h>
#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>

#include <ATen/ScalarOps.h>

#include <iostream>

// EDITING THIS FILE? READ THIS FIRST!
// see Note [Edit Pattern Conversion] in pattern_conversion.h

namespace torch {
namespace jit {

// Converting inplace index_put to ONNX
namespace {

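// Emits aten::size(input, dim) just before `insertBefore` and returns the
// resulting value.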
Value* CreateSizeOfDim(Value* input, int64_t dim, Node* insertBefore) {
  auto graph = input->owningGraph();
  WithInsertPoint guard(insertBefore);
  auto size = graph->insert(aten::size, {input, dim});
  return size;
}

Value* ConvertSelectToIndex(Value* index, Node* insertBefore) {
  // Create index tensor based on index input of aten::select node.
  auto graph = insertBefore->owningGraph();
  WithInsertPoint guard(insertBefore);
  return graph->insert(aten::unsqueeze, {index, 0});
}

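// For example (illustrative), a slice x[2:8:2] over a dimension of size 10
// becomes aten::arange(10)[2:8:2], i.e. the index tensor [2, 4, 6].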
Value* ConvertSliceToIndex(Node* slice, Value* size, Node* insertBefore) {
  // Create index tensor based on aten::slice node.
  auto graph = slice->owningGraph();
  WithInsertPoint guard(insertBefore);
  TORCH_INTERNAL_ASSERT((slice->inputs()).size() == 5);
  auto start = slice->inputs()[2];
  auto end = slice->inputs()[3];
  auto step = slice->inputs()[4];
  auto index =
      graph->insert(aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
  auto sliced_index_n = graph->create(
      aten::slice,
      {index,
       graph->insertConstant(
           scalar_to_tensor(at::Scalar(0)), std::nullopt, slice->scope()),
       start,
       end,
       step});

  sliced_index_n->copyMetadata(insertBefore);
  auto sliced_index = sliced_index_n->insertBefore(insertBefore)->output();
  return sliced_index;
}

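// Associates the index tensor created for a dimension with the kind of node
// (aten::slice / aten::select / aten::index) it was derived from.
// ReshapeToAdvancedIndexingFormat later uses orig_node_kind to pick the
// view shape used for broadcasting.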
struct ConvertedIndex {
  ConvertedIndex(Value* index, c10::Symbol orig_node_kind)
      : index(index), orig_node_kind(orig_node_kind) {}

  Value* index = nullptr;
  c10::Symbol orig_node_kind;
};

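// Builds a map from each dimension of the original tensor to its converted
// index tensor. As an illustrative sketch, x[1:3, 0] = y on a rank-2 input
// yields roughly:
//   dim 0 -> aten::arange(size(x, 0))[1:3]          (from aten::slice)
//   dim 1 -> aten::unsqueeze(0, 0), i.e. tensor([0]) (from aten::select)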
std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
    Graph* graph,
    Node* index_put_node,
    const std::vector<Node*>& slice_and_select_nodes,
    Value* orig_data,
    const py::dict& env) {
  std::unordered_map<int64_t, ConvertedIndex> dim_index_map;

  // Loop over the fetched slice and select nodes and convert them to index
  // tensors. Keep track of which dimension the current slice/select node
  // applies to.
  int64_t cur_dim = 0;
  int64_t dim_offset = 0;
  const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs();
  for (auto it = slice_and_select_nodes.rbegin();
       it != slice_and_select_nodes.rend();
       ++it) {
    auto node = *it;
    // select does not keep dims, which creates an offset for later slice and
    // select nodes.
    // NOTE: Cannot rely on get(attr::dim), because the op no longer matches
    // the schema.
    int64_t dim = node->inputs().at(1)->node()->t(attr::value).item().toLong();

    if (dim < 0) {
      // auto input_type = env.at(orig_data)->type()->expect<TensorType>();
      auto py_value = env[py::cast(orig_data)];
      Value* value = py_value.cast<Value*>();
      auto input_type = value->type()->expect<TensorType>();
      if (input_type->dim().has_value()) {
        auto rank = static_cast<int64_t>(input_type->dim().value());
        // Rank of the original tensor to index on,
        // minus the offset created by select operators.
        dim = dim + rank - dim_offset;
      } else {
        std::cerr
            << "Error: Cannot export ellipsis indexing for input "
            << "of unknown rank. Check "
            << "https://pytorch.org/docs/stable/onnx.html#indexing "
            << "for details." << std::endl;
      }
    }
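    // A negative dim is relative to the select-reduced tensor, so it was
    // normalized against (rank - dim_offset) above. E.g. with rank 3 and one
    // preceding select (dim_offset == 1), dim == -1 resolves to 1 there and
    // to 2 after adding dim_offset back below.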
    dim = dim + dim_offset;
    while (cur_dim < dim) {
      // Handle skipped dims; these are created from ..., or tensor indices.
      // E.g.: x[torch.tensor([1, 0]), ..., 0] = update, where x has rank 3.
      // Both torch.tensor([1, 0]) and ... are skipped; we only observe an
      // aten::select node with dim == 2. Tensor indices will be handled later.
      // An ellipsis (...) is treated as a complete slice over the axes, so we
      // create index tensors here accordingly.
      if (cur_dim - dim_offset >= (int64_t)orig_tensor_indices.size() ||
          index_put_node->input(1)
              ->node()
              ->input(cur_dim - dim_offset)
              ->node()
              ->mustBeNone()) {
        auto size = CreateSizeOfDim(orig_data, cur_dim, index_put_node);
        WithInsertPoint guard(index_put_node);
        auto index_tensor = graph->insert(
            aten::arange, {size}, {NamedValue("dtype", c10::kLong)});
        dim_index_map.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(cur_dim),
            std::forward_as_tuple(index_tensor, aten::slice));
      } else if (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
        dim_index_map.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(cur_dim),
            std::forward_as_tuple(
                orig_tensor_indices[cur_dim - dim_offset], aten::index));
      }
      cur_dim++;
    }

    TORCH_INTERNAL_ASSERT(cur_dim == dim);
    if (node->kind() == aten::slice) {
      auto size = CreateSizeOfDim(orig_data, dim, index_put_node);
      auto index_tensor = ConvertSliceToIndex(node, size, index_put_node);
      dim_index_map.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(dim),
          std::forward_as_tuple(index_tensor, aten::slice));
    } else if (node->kind() == aten::select) {
      auto index_tensor = ConvertSelectToIndex(node->input(2), index_put_node);
      dim_index_map.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(dim),
          std::forward_as_tuple(index_tensor, aten::select));
      dim_offset++;
    } else {
      TORCH_CHECK(
          false,
          node->kind().toDisplayString(),
          " Expected aten::slice or aten::select.");
    }

    cur_dim++;
  }

  while (cur_dim - dim_offset < (int64_t)orig_tensor_indices.size()) {
    dim_index_map.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(cur_dim),
        std::forward_as_tuple(
            orig_tensor_indices[cur_dim - dim_offset], aten::index));
    cur_dim++;
  }

  // Each dimension should have its associated index tensor.
  TORCH_INTERNAL_ASSERT((int64_t)dim_index_map.size() == cur_dim);
  return dim_index_map;
}

// Convert slice/select operators to tensor indices.
// Reshape the tensor indices according to their axis.
// E.g.                 x[1:3, 0, ind1, ind2] = y
//  slice index shape:   [2,   1, 1 ]
//  select index shape:  [     1, 1 ]
//  ind1 shape:          [        _ ]
//  ind2 shape:          [        _ ]
// where _ is the original size of ind1 and ind2.
// ind1 and ind2 are both 1-d tensors, since currently we only support 1-d
// tensor indices.
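// Viewed this way, the four index tensors broadcast together to a common
// shape of [2, 1, _], following advanced indexing semantics.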
std::vector<Value*> ReshapeToAdvancedIndexingFormat(
    Graph* graph,
    Node* index_put_node,
    std::unordered_map<int64_t, ConvertedIndex>& dim_index_map) {
  std::vector<Value*> indices;

  size_t min_index_dim = dim_index_map.size();
  size_t max_index_dim = 0;
  size_t tensor_ind_count = 0;
  for (const auto i : c10::irange(dim_index_map.size())) {
    auto index_i = dim_index_map.find(i);
    TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
    if (index_i->second.orig_node_kind == aten::index) {
      if (i < min_index_dim)
        min_index_dim = i;
      if (i > max_index_dim)
        max_index_dim = i;
      tensor_ind_count++;
    }
  }

  if (((max_index_dim - min_index_dim + 1) != tensor_ind_count) &&
      tensor_ind_count != 0) {
    TORCH_CHECK(
        false,
        "Only consecutive 1-d tensor indices are supported in exporting aten::index_put to ONNX. ",
        "Check https://pytorch.org/docs/stable/onnx.html#indexing for details.");
  }

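  // Consecutive 1-d tensor indices broadcast into a single dimension of the
  // result, so every tensor index after the first reduces the effective rank
  // of the index subspace by one.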
  size_t tensor_ind_offset = tensor_ind_count == 0 ? 0 : tensor_ind_count - 1;
  WithInsertPoint guard(index_put_node);
  for (const auto i : c10::irange(dim_index_map.size())) {
    size_t ind_size = 0;
    auto index_i = dim_index_map.find(i);
    TORCH_INTERNAL_ASSERT(index_i != dim_index_map.end());
    Value* index = index_i->second.index;
    switch (index_i->second.orig_node_kind) {
      case aten::select:
      case aten::slice: {
        if (i < min_index_dim) {
          ind_size = dim_index_map.size() - tensor_ind_offset - i;
        } else {
          ind_size = dim_index_map.size() - i;
        }
        break;
      }

      case aten::index: {
        ind_size = dim_index_map.size() - tensor_ind_offset - min_index_dim;
        break;
      }
      default:
        TORCH_CHECK(
            false, "Unexpected node kind ", index_i->second.orig_node_kind);
    }

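    // View the 1-d index as [-1, 1, ..., 1] (ind_size dims in total) so the
    // indices of different dimensions broadcast against one another.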
    if (ind_size != 1) {
      std::vector<int64_t> view_shape(ind_size, 1);
      view_shape[0] = -1;
      auto unsqueezed_index = graph->insert(aten::view, {index, view_shape});
      indices.emplace_back(unsqueezed_index);
    } else {
      indices.emplace_back(index);
    }
  }

  return indices;
}

// Trace back all the slice & select nodes associated with the index_put node,
// and convert them to associated indices.
// E.g. The IR for x[1:3, 0] = update
//    ...
//    %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
//    ...
//    %11 : Float(2) = aten::select(%8, %9, %10)
//    ...
//    %13 : Tensor?[] = prim::ListConstruct()
//    ...
//    %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
// The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
// = prim::ListConstruct()).
//    ...
//    # Below constructs index from slice node.
//    %23 : Long() = aten::size(%0, %4)
//    %28 : Tensor = aten::arange(%23, %24, %25, %26, %27)
//    %33 : Tensor = aten::slice(%28, %4, %5, %6, %7)
//    %39 : int[] = prim::Constant[value=[-1, 1]]()
//    %40 : Tensor = aten::view(%33, %39)
//    ...
//    # Below constructs index from select node.
//    %36 : int = prim::Constant[value=0]()
//    %37 : Tensor = aten::unsqueeze(%10, %36)
//    %42 : int[] = prim::Constant[value=[-1]]()
//    %43 : Tensor = aten::view(%37, %42)
//    ...
//    # Adding the above two indices to index_put
//    %44 : Tensor?[] = prim::ListConstruct(%40, %43)
//    %45 : Float(2, 5) = aten::index_put(%0, %44, %14, %15)
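// Returns the onnx values corresponding to the aten outputs of the converted
// subblock, or an empty vector if old_node is not an index_put(_) placeholder.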
std::vector<Value*> ConvertIndexPutToONNX(
    Block* new_block,
    Node* old_node,
    py::dict& env,
    py::set& values_in_env) {
  if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder") ||
      (old_node->s(attr::name) != "index_put" &&
       old_node->s(attr::name) != "index_put_")) {
    return {};
  }

  TORCH_INTERNAL_ASSERT(old_node->blocks().size() == 1);
  auto old_graph = old_node->owningGraph();
  auto subblock = old_node->blocks()[0];
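  // Retrieve the aten::index_put(_) node that the encapsulation pass captured
  // inside the placeholder's subblock.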
  auto index_put_node = subblock->nodes().back()->prev();

  // Find slice and select operators that are associated with this index
  // operator. E.g. x[1:3, 0] = y will generate one slice operator (1:3) and
  // one select operator (0).
  std::vector<Node*> slice_and_select_nodes =
      IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
  Node* last_node = !slice_and_select_nodes.empty()
      ? slice_and_select_nodes.back()
      : index_put_node;
  // Re-wire the inner block's input so that it originates from outside the
  // placeholder node.
  last_node->replaceInput(0, old_node->input(0));
  Value* orig_data = last_node->input(0);

  // Convert slice and select operators to indices.
  std::unordered_map<int64_t, ConvertedIndex> dim_index_map =
      MergeSliceAndSelectToIndices(
          old_graph, index_put_node, slice_and_select_nodes, orig_data, env);

  // Reshape indices to advanced indexing format.
  std::vector<Value*> indices =
      ReshapeToAdvancedIndexingFormat(old_graph, index_put_node, dim_index_map);

  // Create new index_put node with converted indices.
  const auto list_indices =
      old_graph->createList(OptionalType::ofTensor(), indices)
          ->insertBefore(index_put_node)
          ->output();
  auto new_index_put_node = old_graph->create(
      aten::index_put,
      {orig_data,
       list_indices,
       index_put_node->input(2),
       index_put_node->input(3)});
  new_index_put_node->insertBefore(index_put_node);
  new_index_put_node->copyMetadata(index_put_node);
  auto new_index_put = new_index_put_node->output();
  new_index_put->copyMetadata(index_put_node->output());
  index_put_node->output()->replaceAllUsesWith(new_index_put);

  // Convert aten type to onnx type.
  EraseNumberTypesOnBlock(subblock);
  EliminateDeadCode(
      subblock,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);

  // Convert all the new aten nodes that were just created to onnx.
  // New onnx nodes are appended at the end of new_block.
  for (auto at_n : subblock->nodes()) {
    if (at_n == subblock->param_node() || at_n == subblock->return_node()) {
      continue;
    }

    NodeToONNX(
        at_n,
        new_block,
        torch::onnx::OperatorExportTypes::ONNX,
        env,
        values_in_env);
  }

  // Find onnx outputs corresponding to the aten outputs of index_put.
  std::vector<Value*> outs;
  for (auto o : subblock->return_node()->inputs()) {
    auto py_value = env[py::cast(o)];
    Value* value = py_value.cast<Value*>();
    outs.emplace_back(value);
  }
  return outs;
}

} // namespace

std::vector<Value*> ConvertPatternFromSubblock(
    Block* new_block,
    Node* old_node,
    py::dict& env,
    py::set& values_in_env) {
  std::vector<Value*> res;

  if (old_node->kind() != Symbol::fromQualString("onnx::Placeholder")) {
    return res;
  }

  // The pattern conversion code should not alter nodes outside the Placeholder
  // subblock.
  auto op_name = old_node->s(attr::name);
  if (op_name == "index_put" || op_name == "index_put_") {
    res = ConvertIndexPutToONNX(new_block, old_node, env, values_in_env);
  }

  return res;
}

} // namespace jit
} // namespace torch