/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/prim_ops/et_copy_index.h>

#include <cstring>

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

using exec_aten::SizesType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::resize_tensor;

namespace torch {
namespace executor {
namespace function {

// Maximum tensor rank supported by the scratch buffer used below to
// compute the expected output size.
constexpr size_t kTensorDimensionLimit = 16;
// This operator is currently only intended to support the map operator.
// Below is a model that contains the map operator:
//
// def map_fn(x, y):
//     return x + y
//
// class TestMapCond(torch.nn.Module):
//     def __init__(self):
//         super().__init__()
//
//     def forward(self, x, y):
//         return control_flow.map(map_fn, x, y)
//
// Corresponding graph:
//
// def forward(self, arg0_1, arg1_1):
//     submodule_0 = self.submodule_0
//     map_1 = torch.ops.map(submodule_0, arg0_1, arg1_1); submodule_0 = arg0_1 = arg1_1 = None
//     return [map_1]
//
// Graph of the submodule that map calls (map_fn):
//
// def forward(self, arg0_1, arg1_1):
//     add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
//     return add_tensor
// After the transformations applied by the emitter to handle the map loop,
// the submodule that map calls looks like this:
//
// def forward(self, arg0_1, arg1_1):
//     sym_size = torch.ops.aten.sym_size(arg0_1)
//     # The emitter creates a variable here to track the iteration index.
//     select_copy_tensor = torch.ops.aten.select(arg0_1, 0, iteration_index)
//     add_tensor = torch.ops.aten.add.Tensor(select_copy_tensor, arg1_1); arg0_1 = arg1_1 = None
//     output_of_map = torch.ops.executorch.prim.et_copy_index(output_of_map, add_tensor, iteration_index)
//     iteration_index = torch.ops.executorch.prim.add.int(iteration_index, 1, iteration_index)
//     done_bool = torch.ops.executorch.prim.eq.int(iteration_index, sym_size, done_bool)
//     # The emitter inserts an instruction here: if done_bool is False,
//     # jump back to the select op; otherwise continue.
//     return add_tensor
//
// The output of each iteration (copy_from) is copied into the copy_to tensor
// at the specified index. This operator is supported in both ATen and lean
// modes.
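//
// A concrete illustration (hypothetical shapes, for illustration only):
// with copy_to of size [4, 2, 3] and copy_from of size [2, 3], a call with
// index == 2 behaves like copy_to[2] = copy_from: copy_to is resized (if
// needed) so that its leading dimension covers index + 1 = 3 entries, and
// the 6 elements of copy_from are copied to element offset 2 * 6 in
// copy_to's buffer.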
void et_copy_index(KernelRuntimeContext& context, EValue** stack) {
  (void)context;
  SizesType expected_output_size[kTensorDimensionLimit];

  auto copy_to = (*stack[0]).toTensor();
  auto copy_from = (*stack[1]).toTensor();
  auto index = (*stack[2]).toInt();

  // Number of bytes we need to copy over from copy_from tensor.
  size_t size_copy_from = (copy_from.element_size()) * (copy_from.numel());
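  // For example (hypothetical values, for illustration only): a float32
  // copy_from of size [2, 3] has element_size() == 4 and numel() == 6, so
  // size_copy_from == 24 bytes are copied per iteration.
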
  ET_CHECK_MSG(
      (copy_to.sizes().size() - copy_from.sizes().size()) == 1,
      "Ranks of copy_to and copy_from tensors should only differ by 1.");
  // Guard the fixed-size expected_output_size buffer against tensors whose
  // rank exceeds what this op supports.
  ET_CHECK_MSG(
      copy_to.sizes().size() <= kTensorDimensionLimit,
      "Rank of copy_to tensor must not exceed kTensorDimensionLimit.");
  // Calculate the size of copy_to after copy_from has been copied into
  // it. This will be passed on to the resize call below.
  expected_output_size[0] = index + 1;
  for (size_t i = 0; i < copy_from.sizes().size(); i++) {
    // If we're copying past the first index then the shapes of copy_from
    // and copy_to, minus the leading dimension, should be the same, i.e.
    // copy_to.size[1:] == copy_from.size[:].
    if (index > 0) {
      ET_CHECK_MSG(
          copy_to.sizes()[i + 1] == copy_from.sizes()[i],
          "Mismatch in shape between copy_to and copy_from tensors");
    }
    expected_output_size[i + 1] = copy_from.sizes()[i];
  }
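
  // For example (hypothetical values, for illustration only): with
  // index == 2 and copy_from of size [2, 3], expected_output_size becomes
  // [3, 2, 3], i.e. enough room in copy_to for iterations 0 through 2.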

  if (copy_to.sizes()[0] < expected_output_size[0]) {
    // Resize `copy_to` to the expected output size.
    const void* data_ptr = copy_to.const_data_ptr();
    Error err =
        resize_tensor(copy_to, {expected_output_size, copy_to.sizes().size()});
    ET_CHECK(err == Error::Ok);
    ET_CHECK_MSG(
        data_ptr == copy_to.const_data_ptr(),
        "Data ptr of copy_to tensor changed after resize which isn't allowed for static/upper-bounded tensors");
  }
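
  // Note: for the static/upper-bounded tensors used in lean mode,
  // resize_tensor() is expected to only update the size metadata within
  // the tensor's preallocated storage, so the data pointer should remain
  // stable; the check above turns any reallocation into a hard failure
  // rather than a silent copy into a stale buffer.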

  char* copy_to_ptr = static_cast<char*>(copy_to.mutable_data_ptr());
  const char* copy_from_ptr =
      static_cast<const char*>(copy_from.const_data_ptr());

  // If we've reached here, it means the copy_to tensor has been
  // successfully resized, so we can now copy the data from copy_from
  // into the slot of copy_to selected by `index`.
  memcpy(copy_to_ptr + index * size_copy_from, copy_from_ptr, size_copy_from);
}
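
// A minimal sketch of how the runtime might drive this op; the names and
// setup below are illustrative assumptions, not part of this file:
//
//   EValue copy_to_val(out_tensor);    // accumulated map output, rank N+1
//   EValue copy_from_val(iter_out);    // current iteration's result, rank N
//   EValue index_val((int64_t)i);      // loop counter kept by the emitter
//   EValue* stack[3] = {&copy_to_val, &copy_from_val, &index_val};
//   et_copy_index(context, stack);     // copies iter_out into out_tensor[i]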

} // namespace function
} // namespace executor
} // namespace torch