/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/prim_ops/et_copy_index.h>

#include <cstdint>
#include <cstring>

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

using exec_aten::SizesType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::resize_tensor;

namespace torch {
namespace executor {
namespace function {

constexpr size_t kTensorDimensionLimit = 16;

// This operator is currently intended only to support the map operator.
// Below is a model that uses the map operator:
//
//    def map_fn(x, y):
//        return x + y
//
//    class TestMapCond(torch.nn.Module):
//        def __init__(self):
//            super().__init__()
//
//        def forward(self, x, y):
//            return control_flow.map(map_fn, x, y)
//
// Corresponding graph:
//
//    def forward(self, arg0_1, arg1_1):
//        submodule_0 = self.submodule_0
//        map_1 = torch.ops.map(submodule_0, arg0_1, arg1_1);
//        submodule_0 = arg0_1 = arg1_1 = None
//        return [map_1]
//
//    def forward(self, arg0_1, arg1_1):
//        add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1);
//        arg0_1 = arg1_1 = None
//        return add_tensor
//
// After the transformations applied by the emitter to handle the map loop,
// the submodule that map calls looks like this:
//
//    def forward(self, arg0_1, arg1_1):
//        sym_size = torch.ops.aten.sym_size(arg0_1)
//        # The emitter creates a variable here to track the iteration index.
//        select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
//        add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1);
//        arg0_1 = arg1_1 = None
//        output_of_map = torch.ops.executorch.prim.et_copy_index(
//            output_of_map, add_tensor, iteration_index)
//        iteration_index = torch.ops.executorch.prim.add.int(
//            iteration_index, 1, iteration_index)
//        done_bool = torch.ops.executorch.prim.eq.int(
//            iteration_index, sym_size, done_bool)
//        # The emitter inserts an instruction here: if done_bool == False,
//        # jump back to the select_copy op; otherwise continue.
//        return add_tensor
//
// The output of each iteration (copy_from) is copied into the copy_to tensor
// at the specified index. This operator is supported in both ATen and lean
// modes.
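//
// Illustrative example (hypothetical shapes): if copy_from has shape (4, 5)
// and index == 2, copy_to is resized to shape (3, 4, 5) if its leading
// dimension is still smaller than 3, and the slice copy_to[2] receives a
// copy of copy_from's data.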
void et_copy_index(KernelRuntimeContext& context, EValue** stack) {
  (void)context;
  SizesType expected_output_size[kTensorDimensionLimit];

  auto copy_to = (*stack[0]).toTensor();
  auto copy_from = (*stack[1]).toTensor();
  auto index = (*stack[2]).toInt();

  // Number of bytes we need to copy over from copy_from tensor.
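  // For example (illustrative), a float32 copy_from with shape (4, 5) needs
  // 4 * 5 * sizeof(float) = 80 bytes.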
  size_t size_copy_from = (copy_from.element_size()) * (copy_from.numel());

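  // copy_to stacks one copy_from slice per loop iteration along an extra
  // leading dimension, so its rank must be exactly one greater than
  // copy_from's rank.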
  ET_CHECK_MSG(
      (copy_to.sizes().size() - copy_from.sizes().size()) == 1,
      "Ranks of copy_to and copy_from tensors should only differ by 1.");

  // Here we calculate the size of the out tensor after copy_from has
  // been copied into it. This will be passed on to the resize call.
  expected_output_size[0] = index + 1;
  for (size_t i = 0; i < copy_from.sizes().size(); i++) {
    // If we're copying past the first index, the shape of copy_from and
    // copy_to without the leading dimension should be the same, i.e.
    // copy_to.sizes()[1:] == copy_from.sizes()[:].
    if (index > 0) {
      ET_CHECK_MSG(
          copy_to.sizes()[i + 1] == copy_from.sizes()[i],
          "Mismatch in shape between copy_to and copy_from tensors");
    }
    expected_output_size[i + 1] = copy_from.sizes()[i];
  }
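  // For example (illustrative), with index == 2 and copy_from of shape (4, 5),
  // expected_output_size becomes {3, 4, 5}.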

  if (copy_to.sizes()[0] < expected_output_size[0]) {
    // Resize `copy_to` to the expected output size.
    const void* data_ptr = copy_to.const_data_ptr();
    Error err =
        resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()});
    ET_CHECK(err == Error::Ok);
    ET_CHECK_MSG(
        data_ptr == copy_to.const_data_ptr(),
        "Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors");
  }

  auto copy_to_ptr = copy_to.const_data_ptr();
  auto copy_from_ptr = copy_from.const_data_ptr();

  // If we've reached here, it means the copy_to tensor has been
  // successfully resized so we can now copy over the data from
  // copy_from into the copy_to tensor.
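  // The destination offset is index * size_copy_from bytes, i.e. the start of
  // the copy_to[index] slice. copy_to is this op's output, so its buffer is
  // expected to be writable; const_data_ptr() is cast below only for the
  // pointer arithmetic.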
  memcpy(
      (void*)((uintptr_t)copy_to_ptr + index * size_copy_from),
      copy_from_ptr,
      size_copy_from);
}

} // namespace function
} // namespace executor
} // namespace torch