#include <torch/csrc/jit/tensorexpr/operators/reduction.h>

using namespace torch::jit::tensorexpr;

// Remove the indices at the positions given by axes.
static std::vector<VarHandle> squeezeIndices(
    const ParameterList& indices,
    const std::vector<size_t>& axes) {
  std::vector<VarHandle> indices_squeezed;
  for (size_t dim = 0; dim < indices.size(); ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      indices_squeezed.push_back(indices[dim]);
    }
  }
  return indices_squeezed;
}

namespace torch::jit::tensorexpr {

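// Lowering for aten::sum. inputs[0] is the self tensor; when provided,
// inputs[1] is the list of dims to reduce over and inputs[2] is the keepdim
// flag. Without those arguments, the reduction covers all axes.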
Tensor computeSum(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  std::vector<size_t> axes;
  bool keepdim = false;
  // aten::sum takes the input tensor named self.
  auto sizes = valueShape(inputs[0]);

  size_t rank = sizes.size();
  if (inputs.size() > 2) {
    if (auto emptyAxes = std::get_if<BufList>(&inputs[1])) {
      // If the dim-array is an empty list, it appears as a BufList instead of
      // an IntList and needs special handling: in that case we sum over all
      // axes.
      TORCH_INTERNAL_ASSERT(emptyAxes->empty());
      axes.resize(rank);
      std::iota(axes.begin(), axes.end(), 0);
    } else if (rank > 0) {
      auto nodeAxes = std::get<IntList>(inputs[1]);
      // Canonicalize axes: wrap around, sort and make unique.
      for (auto axis : nodeAxes) {
        axes.push_back(at::maybe_wrap_dim(axis, rank));
      }
      std::sort(axes.begin(), axes.end());
      axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
    }
    keepdim = std::get<bool>(inputs[2]);
  } else {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);
  }
  // Axes go into reduction dimensions.
  std::vector<ExprHandle> reductionDims;
  reductionDims.reserve(rank);
  for (size_t axis : axes) {
    reductionDims.emplace_back(sizes[axis]);
  }
  std::vector<ExprHandle> outputDims;
  // Output dimensions are the complement of axes. When keepdim is set, a
  // one-sized dimension is inserted for each axis.
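  // For example, with sizes = {2, 3, 4} and axes = {1}, outputDims becomes
  // {2, 4} without keepdim and {2, 1, 4} with keepdim.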
  for (size_t dim = 0; dim < rank; ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      outputDims.emplace_back(sizes[dim]);
    } else if (keepdim) {
      outputDims.emplace_back(1);
    }
  }

  return Reduce(
      "sum",
      outputDims,
      outputStrides,
      Sum(),
      [&](ParameterList& indices) {
        // "Squeeze" out indices inserted when keepdim is set.
        auto indices_squeezed =
            keepdim ? squeezeIndices(indices, axes) : indices;
        TORCH_INTERNAL_ASSERT(axes.size() <= indices_squeezed.size());
        // Move innermost indices into axes positions:
        // 1. Fill the outermost indices first.
        // 2. Insert the innermost indices into the correct axis position,
        // displacing the outermost indices as needed.
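        // E.g. with axes = {1, 3} and indices_squeezed = {i0, i1, r0, r1}
        // (reduction indices innermost), the result is {i0, r0, i1, r1}.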
        std::vector<ExprHandle> indices_exprs;
        size_t i = 0;
        for (; i < indices_squeezed.size() - axes.size(); ++i) {
          indices_exprs.push_back(indices_squeezed[i]);
        }
        for (auto axis : axes) {
          indices_exprs.insert(
              indices_exprs.begin() + axis, indices_squeezed[i]);
          ++i;
        }
        auto indexed = tensorOrConstant(inputs[0], indices_exprs);
        if (outputType) {
          return Cast::make(ToDtype(*outputType), indexed);
        } else {
          return indexed;
        }
      },
      reductionDims);
}

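// Lowering for aten::mean. Rather than building a Reduce loop nest, this is
// emitted as an external call to nnc_aten_mean, passing the reduction dims
// (or all dims when none are given) followed by the keepdim flag.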
Tensor computeMean(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  bool keepdim = false;
  BufHandle ResultBuf("mean", outputShape, dtype);
  BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
  std::vector<ExprHandle> extra_args;
  if (inputs.size() > 2) {
    keepdim = std::get<bool>(inputs[2]);
  }

  if (auto mean_dims = std::get_if<IntList>(&inputs[1])) {
    extra_args = c10::fmap<ExprHandle>(*mean_dims);
  } else {
    // When dims argument is not specified, reduce over all dimensions
    for (int64_t idx = 0; idx < static_cast<int64_t>(InputBuf.ndim()); ++idx) {
      extra_args.emplace_back(idx);
    }
  }
  extra_args.push_back(LongImm::make(static_cast<int64_t>(keepdim)));
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_mean", {InputBuf}, extra_args));
}

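// Lowering for the reduction form of aten::max (max over a single dim with
// keepdim), emitted as an external call to nnc_aten_max_red.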
Tensor computeMax(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("max", outputShape, dtype);
  BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
  std::vector<ExprHandle> max_dims_expr;
  auto max_dim = std::get<int64_t>(inputs[1]);
  auto keep_dim = std::get<bool>(inputs[2]);
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_max_red",
          {InputBuf},
          {max_dim, (int64_t)keep_dim}));
}

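// Lowering for aten::adaptive_avg_pool2d, emitted as an external call to
// nnc_aten_adaptive_avg_pool2d with the requested output size.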
Tensor computeAdaptiveAvgPool2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype);
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto out_size_param = std::get<IntList>(inputs[1]);
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_adaptive_avg_pool2d",
          {std::get<BufHandle>(inputs[0])},
          c10::fmap<ExprHandle>(out_size_param)));
}

} // namespace torch::jit::tensorexpr