xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_apply.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/tensor_apply.h>
2 
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/TensorUtils.h>
5 #include <c10/util/irange.h>
6 
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/utils/python_numbers.h>
9 #include <torch/csrc/utils/python_scalars.h>
10 
11 using namespace at;
12 
13 namespace torch::utils {
14 
15 struct StridedData {
StridedDatatorch::utils::StridedData16   StridedData(const Tensor& tensor)
17       : data(tensor.data_ptr()),
18         strides(tensor.strides()),
19         elementSize(tensor.element_size()) {}
20 
21   void* data;
22   IntArrayRef strides;
23   int64_t elementSize;
24 
steptorch::utils::StridedData25   void step(int dim) {
26     data = (char*)data + (strides[dim] * elementSize);
27   }
28 };
29 
30 template <size_t N>
recursive_apply(IntArrayRef sizes,ScalarType scalarType,int64_t dim,PyObject * fn,std::array<StridedData,N> strided_data)31 static void recursive_apply(
32     IntArrayRef sizes,
33     ScalarType scalarType,
34     int64_t dim,
35     PyObject* fn,
36     std::array<StridedData, N> strided_data) {
37   int64_t ndim = static_cast<int64_t>(sizes.size());
38   if (dim == ndim) {
39     auto args = THPObjectPtr(PyTuple_New(N));
40     if (!args)
41       throw python_error();
42     for (const auto i : c10::irange(N)) {
43       PyObject* arg = load_scalar(strided_data[i].data, scalarType);
44       if (!arg)
45         throw python_error();
46       PyTuple_SET_ITEM(args.get(), i, arg);
47     }
48     auto ret = THPObjectPtr(PyObject_CallObject(fn, args.get()));
49     if (!ret)
50       throw python_error();
51     store_scalar(strided_data[0].data, scalarType, ret.get());
52     return;
53   }
54 
55   auto n = sizes[dim];
56   for (const auto i : c10::irange(n)) {
57     (void)i; // Suppress unused variable warning
58     recursive_apply(sizes, scalarType, dim + 1, fn, strided_data);
59     for (auto& td : strided_data) {
60       td.step(dim);
61     }
62   }
63 }
64 
// In-place element-wise application of the Python callable `fn` to `self`:
// every element is loaded as a Python object (load_scalar), passed to `fn`,
// and overwritten with the returned value (store_scalar).
//
// Meta tensors are returned untouched without invoking `fn`. Any other
// non-CPU tensor is rejected via TORCH_CHECK_TYPE. Returns `self`.
const Tensor& apply_(const Tensor& self, PyObject* fn) {
  if (!self.is_meta()) {
    TORCH_CHECK_TYPE(
        self.device().is_cpu(), "apply_ is only implemented on CPU tensors");
    recursive_apply<1>(self.sizes(), self.scalar_type(), 0, fn, {{self}});
  }
  return self;
}
75 
// In-place binary map: each element of `self` is replaced by
// fn(self_elem, other_elem), where `other_` is broadcast to `self`'s shape
// with expand_inplace.
//
// `other_` must have options type-equal to `self` (checked first, even for
// meta tensors). Meta tensors are then returned untouched; any other
// non-CPU `self` is rejected via TORCH_CHECK_TYPE. Returns `self`.
const Tensor& map_(const Tensor& self, const Tensor& other_, PyObject* fn) {
  TORCH_CHECK_TYPE(
      other_.options().type_equal(self.options()),
      "map_: expected ",
      self.toString(),
      " for 'other' (got ",
      other_.toString(),
      ")");

  if (!self.is_meta()) {
    TORCH_CHECK_TYPE(
        self.device().is_cpu(), "map_ is only implemented on CPU tensors");
    // Keep the (possibly owning) expanded view alive for the whole traversal,
    // since the cursor built from it holds a raw data pointer.
    const c10::MaybeOwned<Tensor> broadcast_other =
        expand_inplace(self, other_, "map_");
    recursive_apply<2>(
        self.sizes(), self.scalar_type(), 0, fn, {{self, *broadcast_other}});
  }
  return self;
}
94 
// In-place ternary map: each element of `self` is replaced by
// fn(self_elem, x_elem, y_elem), where `x_` and `y_` are broadcast to
// `self`'s shape with expand_inplace.
//
// `x_` and `y_` must have options type-equal to `self` (checked first, even
// for meta tensors). Meta tensors are then returned untouched; otherwise all
// three tensors must be on CPU, checked via TORCH_CHECK_TYPE. Returns `self`.
const Tensor& map2_(
    const Tensor& self,
    const Tensor& x_,
    const Tensor& y_,
    PyObject* fn) {
  // Shared type check for the two operands; `arg_name` is spliced into the
  // message so the text matches "... for argument 'x' (got ...)".
  const auto require_same_type = [&self](
                                     const Tensor& arg, const char* arg_name) {
    TORCH_CHECK_TYPE(
        arg.options().type_equal(self.options()),
        "map2_: expected ",
        self.toString(),
        " for argument '",
        arg_name,
        "' (got ",
        arg.toString(),
        ")");
  };
  require_same_type(x_, "x");
  require_same_type(y_, "y");

  if (!self.is_meta()) {
    const bool every_operand_on_cpu = self.device().is_cpu() &&
        x_.device().is_cpu() && y_.device().is_cpu();
    TORCH_CHECK_TYPE(
        every_operand_on_cpu, "map2_ is only implemented on CPU tensors");
    // The expanded views must stay alive for the whole traversal, since the
    // cursors built from them hold raw data pointers.
    auto [x_broadcast, y_broadcast] = expand_inplace(self, x_, y_, "map2_");
    recursive_apply<3>(
        self.sizes(),
        self.scalar_type(),
        0,
        fn,
        {{self, *x_broadcast, *y_broadcast}});
  }
  return self;
}
130 
131 } // namespace torch::utils
132