#include <torch/csrc/utils/tensor_apply.h>

#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_scalars.h>

using namespace at;

namespace torch::utils {

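// Bundles a tensor's raw data pointer, strides, and element size so the
// apply/map loops below can walk its storage one element at a time.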
struct StridedData {
  StridedData(const Tensor& tensor)
      : data(tensor.data_ptr()),
        strides(tensor.strides()),
        elementSize(tensor.element_size()) {}

  void* data;
  IntArrayRef strides;
  int64_t elementSize;

  // Advance the data pointer by one step along dimension `dim`.
  void step(int dim) {
    data = (char*)data + (strides[dim] * elementSize);
  }
};

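// Recursively walks `sizes`, one dimension per recursion level. At the
// innermost level it loads the current scalar from each of the N tensors,
// calls the Python callable `fn` with those scalars, and stores the returned
// value back into the first tensor in-place.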
template <size_t N>
static void recursive_apply(
    IntArrayRef sizes,
    ScalarType scalarType,
    int64_t dim,
    PyObject* fn,
    std::array<StridedData, N> strided_data) {
  int64_t ndim = static_cast<int64_t>(sizes.size());
  if (dim == ndim) {
    // Base case: call fn(scalar_0, ..., scalar_{N-1}) and write the result
    // back into the first tensor's storage.
    auto args = THPObjectPtr(PyTuple_New(N));
    if (!args)
      throw python_error();
    for (const auto i : c10::irange(N)) {
      PyObject* arg = load_scalar(strided_data[i].data, scalarType);
      if (!arg)
        throw python_error();
      PyTuple_SET_ITEM(args.get(), i, arg);
    }
    auto ret = THPObjectPtr(PyObject_CallObject(fn, args.get()));
    if (!ret)
      throw python_error();
    store_scalar(strided_data[0].data, scalarType, ret.get());
    return;
  }

  // Recurse into the next dimension, stepping every tensor's data pointer
  // along `dim` after each iteration.
  auto n = sizes[dim];
  for (const auto i : c10::irange(n)) {
    (void)i; // Suppress unused variable warning
    recursive_apply(sizes, scalarType, dim + 1, fn, strided_data);
    for (auto& td : strided_data) {
      td.step(dim);
    }
  }
}

64
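// In-place elementwise apply: invokes the Python callable `fn` once per
// element of `self` and stores the returned value back into that element.
// Only implemented for CPU tensors; meta tensors are silently skipped.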
const Tensor& apply_(const Tensor& self, PyObject* fn) {
  if (self.is_meta()) {
    return self; // Just skip
  }
  TORCH_CHECK_TYPE(
      self.device().is_cpu(), "apply_ is only implemented on CPU tensors");
  auto scalarType = self.scalar_type();
  recursive_apply<1>(self.sizes(), scalarType, 0, fn, {{self}});
  return self;
}

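// In-place binary map: for each position, calls `fn(self_elem, other_elem)`
// and writes the result into `self`. `other_` must have the same type as
// `self` and is expanded (broadcast) to `self`'s shape. CPU only; meta
// tensors are silently skipped.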
const Tensor& map_(const Tensor& self, const Tensor& other_, PyObject* fn) {
  TORCH_CHECK_TYPE(
      other_.options().type_equal(self.options()),
      "map_: expected ",
      self.toString(),
      " for 'other' (got ",
      other_.toString(),
      ")");
  if (self.is_meta()) {
    return self; // Just skip
  }
  TORCH_CHECK_TYPE(
      self.device().is_cpu(), "map_ is only implemented on CPU tensors");
  c10::MaybeOwned<Tensor> other = expand_inplace(self, other_, "map_");
  auto scalarType = self.scalar_type();
  recursive_apply<2>(self.sizes(), scalarType, 0, fn, {{self, *other}});
  return self;
}

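// In-place ternary map: for each position, calls
// `fn(self_elem, x_elem, y_elem)` and writes the result into `self`. `x_`
// and `y_` must match `self`'s type and are expanded to `self`'s shape.
// All three tensors must live on the CPU; a meta `self` is silently skipped.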
const Tensor& map2_(
    const Tensor& self,
    const Tensor& x_,
    const Tensor& y_,
    PyObject* fn) {
  TORCH_CHECK_TYPE(
      x_.options().type_equal(self.options()),
      "map2_: expected ",
      self.toString(),
      " for argument 'x' (got ",
      x_.toString(),
      ")");
  TORCH_CHECK_TYPE(
      y_.options().type_equal(self.options()),
      "map2_: expected ",
      self.toString(),
      " for argument 'y' (got ",
      y_.toString(),
      ")");
  if (self.is_meta()) {
    return self; // Just skip
  }
  TORCH_CHECK_TYPE(
      (self.device().is_cpu() && x_.device().is_cpu() && y_.device().is_cpu()),
      "map2_ is only implemented on CPU tensors");
  auto others = expand_inplace(self, x_, y_, "map2_");
  auto scalarType = self.scalar_type();
  recursive_apply<3>(
      self.sizes(),
      scalarType,
      0,
      fn,
      {{self, *std::get<0>(others), *std::get<1>(others)}});
  return self;
}

} // namespace torch::utils