#include <torch/csrc/jit/tensorexpr/operators/reduction.h>

using namespace torch::jit::tensorexpr;

// Remove all indices from axes positions.
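// E.g., indices = {i, j, k} with axes = {1} yields {i, k}.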
static std::vector<VarHandle> squeezeIndices(
    const ParameterList& indices,
    const std::vector<size_t>& axes) {
  std::vector<VarHandle> indices_squeezed;
  for (size_t dim = 0; dim < indices.size(); ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      indices_squeezed.push_back(indices[dim]);
    }
  }
  return indices_squeezed;
}

namespace torch::jit::tensorexpr {

Tensor computeSum(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  std::vector<size_t> axes;
  bool keepdim = false;
  // aten::sum takes the input tensor named self.
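  // When present, inputs[1] holds the dim list and inputs[2] the keepdim flag.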
  auto sizes = valueShape(inputs[0]);

  size_t rank = sizes.size();
  if (inputs.size() > 2) {
    if (auto emptyAxes = std::get_if<BufList>(&inputs[1])) {
      // If the dim-array is an empty list, it appears as a BufList instead of
      // an IntList and needs special handling: in that case we sum over all
      // axes.
      TORCH_INTERNAL_ASSERT(emptyAxes->empty());
      axes.resize(rank);
      std::iota(axes.begin(), axes.end(), 0);
    } else if (rank > 0) {
      auto nodeAxes = std::get<IntList>(inputs[1]);
      // Canonicalize axes: wrap around, sort and make unique.
      for (auto axis : nodeAxes) {
        axes.push_back(at::maybe_wrap_dim(axis, rank));
      }
      std::sort(axes.begin(), axes.end());
      axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
    }
    keepdim = std::get<bool>(inputs[2]);
  } else {
    axes.resize(rank);
    std::iota(axes.begin(), axes.end(), 0);
  }
  // Axes go into reduction dimensions.
  std::vector<ExprHandle> reductionDims;
  reductionDims.reserve(rank);
  for (size_t axis : axes) {
    reductionDims.emplace_back(sizes[axis]);
  }
  std::vector<ExprHandle> outputDims;
  // Output dimensions are the complement of axes. When keepdim is set, a
  // one-sized dimension is inserted for each axis.
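  // E.g., sizes = {A, B, C} with axes = {1}: outputDims is {A, C}, or
  // {A, 1, C} when keepdim is set; reductionDims is {B}.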
  for (size_t dim = 0; dim < rank; ++dim) {
    if (!std::count(axes.begin(), axes.end(), dim)) {
      outputDims.emplace_back(sizes[dim]);
    } else if (keepdim) {
      outputDims.emplace_back(1);
    }
  }

  return Reduce(
      "sum",
      outputDims,
      outputStrides,
      Sum(),
      [&](ParameterList& indices) {
        // "Squeeze" out indices inserted when keepdim is set.
        auto indices_squeezed =
            keepdim ? squeezeIndices(indices, axes) : indices;
        TORCH_INTERNAL_ASSERT(axes.size() <= indices_squeezed.size());
        // Move innermost indices into axes positions:
        //   1. Fill the outermost indices first.
        //   2. Insert the innermost indices into the correct axis position,
        //      displacing the outermost indices as needed.
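        // E.g., rank 3 with axes = {1} and keepdim = false: indices arrive as
        // {o0, o1, r} with the reduction index r innermost; the two loops
        // below rebuild them as {o0, r, o1}, so r lands on input axis 1.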
        std::vector<ExprHandle> indices_exprs;
        size_t i = 0;
        for (; i < indices_squeezed.size() - axes.size(); ++i) {
          indices_exprs.push_back(indices_squeezed[i]);
        }
        for (auto axis : axes) {
          indices_exprs.insert(
              indices_exprs.begin() + axis, indices_squeezed[i]);
          ++i;
        }
        auto indexed = tensorOrConstant(inputs[0], indices_exprs);
        if (outputType) {
          return Cast::make(ToDtype(*outputType), indexed);
        } else {
          return indexed;
        }
      },
      reductionDims);
}

Tensor computeMean(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  bool keepdim = false;
  BufHandle ResultBuf("mean", outputShape, dtype);
  BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
  std::vector<ExprHandle> extra_args;
  if (inputs.size() > 2) {
    keepdim = std::get<bool>(inputs[2]);
  }

  if (auto mean_dims = std::get_if<IntList>(&inputs[1])) {
    extra_args = c10::fmap<ExprHandle>(*mean_dims);
  } else {
    // When the dims argument is not specified, reduce over all dimensions.
    for (int64_t idx = 0; idx < static_cast<int64_t>(InputBuf.ndim()); ++idx) {
      extra_args.emplace_back(idx);
    }
  }
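  // Pass the reduction dims followed by the keepdim flag as the extra
  // arguments of the external call.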
  extra_args.push_back(LongImm::make(static_cast<int64_t>(keepdim)));
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_mean", {InputBuf}, extra_args));
}

Tensor computeMax(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("max", outputShape, dtype);
  BufHandle InputBuf = std::get<BufHandle>(inputs[0]);
  std::vector<ExprHandle> max_dims_expr;
  auto max_dim = std::get<int64_t>(inputs[1]);
  auto keep_dim = std::get<bool>(inputs[2]);
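  // Pass the reduction dim and the keepdim flag (as an int64) as scalar extra
  // arguments of the external call.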
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_max_red",
          {InputBuf},
          {max_dim, (int64_t)keep_dim}));
}

Tensor computeAdaptiveAvgPool2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }
  BufHandle ResultBuf("adaptive_avgpool2d", outputShape, dtype);
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto out_size_param = std::get<IntList>(inputs[1]);
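  // out_size_param is the requested output size (H, W) of
  // aten::adaptive_avg_pool2d; forward it as the call's extra arguments.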
  return Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_adaptive_avg_pool2d",
          {std::get<BufHandle>(inputs[0])},
          c10::fmap<ExprHandle>(out_size_param)));
}

} // namespace torch::jit::tensorexpr