xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/pointwise.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
2 #include <torch/csrc/jit/tensorexpr/operators/pointwise.h>
3 
4 namespace torch::jit::tensorexpr {
5 
6 using namespace torch::jit::tensorexpr;
7 
computeSign(const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::optional<std::vector<ExprHandle>> & outputStrides)8 Tensor computeSign(
9     const std::vector<ArgValue>& inputValues,
10     const std::vector<ExprHandle>& outputShape,
11     const std::optional<std::vector<ExprHandle>>& outputStrides) {
12   return Compute(
13       "aten_sign", outputShape, outputStrides, [&](ParameterList& axes) {
14         std::vector<ExprHandle> indices(axes.begin(), axes.end());
15         std::vector<ExprHandle> inputs = {
16             tensorOrConstant(inputValues[0], indices)};
17         auto inp = inputs[0];
18         auto zero = ExprHandle(immLike(inp, 0.0f));
19         auto res = (zero < inp) - (inp < zero);
20         return promoteToDtype(res, inp.dtype().scalar_type());
21       });
22 }
23 
computeOneOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &)> & innerExpr,const int checkParamTypes)24 Tensor computeOneOperand(
25     const std::string& name,
26     const std::vector<ArgValue>& inputValues,
27     const std::vector<ExprHandle>& outputShape,
28     const std::vector<ExprHandle>& outputStrides,
29     const std::optional<ScalarType>& outputType,
30     const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
31     const int checkParamTypes) {
32   return Compute(
33       name,
34       outputShape,
35       outputStrides,
36       [inputValues, outputType, innerExpr, checkParamTypes](
37           const std::vector<VarHandle>& axes) {
38         std::vector<ExprHandle> indices(axes.begin(), axes.end());
39         std::vector<ExprHandle> inputs = {
40             tensorOrConstant(inputValues[0], indices)};
41         promoteInputs(inputs, checkParamTypes);
42         ExprHandle compute = innerExpr(inputs[0]);
43         return demoteOutput(compute, outputType);
44       });
45 }
46 
computeTwoOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)47 Tensor computeTwoOperand(
48     const std::string& name,
49     const std::vector<ArgValue>& inputValues,
50     const std::vector<ExprHandle>& outputShape,
51     const std::vector<ExprHandle>& outputStrides,
52     const std::optional<ScalarType>& outputType,
53     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
54         innerExpr) {
55   return Compute(
56       name,
57       outputShape,
58       outputStrides,
59       [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
60         std::vector<ExprHandle> indices(axes.begin(), axes.end());
61         std::vector<ExprHandle> inputs = {
62             tensorOrConstant(inputValues[0], indices),
63             tensorOrConstant(inputValues[1], indices),
64         };
65 
66         promoteInputs(inputs);
67         ExprHandle compute = innerExpr(inputs[0], inputs[1]);
68         return demoteOutput(compute, outputType);
69       });
70 }
71 
computeTwoOperandWithAlpha(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)72 Tensor computeTwoOperandWithAlpha(
73     const std::string& name,
74     const std::vector<ArgValue>& inputValues,
75     const std::vector<ExprHandle>& outputShape,
76     const std::vector<ExprHandle>& outputStrides,
77     const std::optional<ScalarType>& outputType,
78     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
79         innerExpr) {
80   return Compute(
81       name,
82       outputShape,
83       outputStrides,
84       [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
85         std::vector<ExprHandle> indices(axes.begin(), axes.end());
86         std::vector<ExprHandle> inputs = {
87             tensorOrConstant(inputValues[0], indices),
88             tensorOrConstant(inputValues[1], indices),
89             tensorOrConstant(inputValues[2], indices),
90         };
91 
92         promoteInputs(inputs);
93         ExprHandle compute = innerExpr(inputs[0], inputs[2] * inputs[1]);
94         return demoteOutput(compute, outputType);
95       });
96 }
97 
computeConditionWithTwoOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr)98 Tensor computeConditionWithTwoOperand(
99     const std::string& name,
100     const std::vector<ArgValue>& inputValues,
101     const std::vector<ExprHandle>& outputShape,
102     const std::vector<ExprHandle>& outputStrides,
103     const std::optional<ScalarType>& outputType,
104     const std::function<
105         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
106         innerExpr) {
107   return Compute(
108       name,
109       outputShape,
110       outputStrides,
111       [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
112         std::vector<ExprHandle> indices(axes.begin(), axes.end());
113         std::vector<ExprHandle> inputs = {
114             tensorOrConstant(inputValues[1], indices),
115             tensorOrConstant(inputValues[2], indices),
116         };
117 
118         promoteInputs(inputs);
119         // First expr is the condition, which we don't promote
120         inputs.emplace(
121             inputs.begin(), tensorOrConstant(inputValues[0], indices));
122         ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
123         return demoteOutput(compute, outputType);
124       });
125 }
126 
computeThreeOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr,bool promote_inputs)127 Tensor computeThreeOperand(
128     const std::string& name,
129     const std::vector<ArgValue>& inputValues,
130     const std::vector<ExprHandle>& outputShape,
131     const std::vector<ExprHandle>& outputStrides,
132     const std::optional<ScalarType>& outputType,
133     const std::function<
134         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
135         innerExpr,
136     bool promote_inputs) {
137   return Compute(
138       name,
139       outputShape,
140       outputStrides,
141       [inputValues, outputType, innerExpr, promote_inputs](
142           const std::vector<VarHandle>& axes) {
143         std::vector<ExprHandle> indices(axes.begin(), axes.end());
144         std::vector<ExprHandle> inputs = {
145             tensorOrConstant(inputValues[0], indices),
146             tensorOrConstant(inputValues[1], indices),
147             tensorOrConstant(inputValues[2], indices),
148         };
149 
150         if (promote_inputs) {
151           promoteInputs(inputs);
152         }
153         ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
154         return demoteOutput(compute, outputType);
155       });
156 }
computeFourOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr)157 Tensor computeFourOperand(
158     const std::string& name,
159     const std::vector<ArgValue>& inputValues,
160     const std::vector<ExprHandle>& outputShape,
161     const std::vector<ExprHandle>& outputStrides,
162     const std::optional<ScalarType>& outputType,
163     const std::function<ExprHandle(
164         const ExprHandle&,
165         const ExprHandle&,
166         const ExprHandle&,
167         const ExprHandle&)>& innerExpr) {
168   return Compute(
169       name,
170       outputShape,
171       outputStrides,
172       [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
173         std::vector<ExprHandle> indices(axes.begin(), axes.end());
174         std::vector<ExprHandle> inputs = {
175             tensorOrConstant(inputValues[0], indices),
176             tensorOrConstant(inputValues[1], indices),
177             tensorOrConstant(inputValues[2], indices),
178             tensorOrConstant(inputValues[3], indices),
179         };
180 
181         promoteInputs(inputs);
182         ExprHandle compute =
183             innerExpr(inputs[0], inputs[1], inputs[2], inputs[3]);
184         return demoteOutput(compute, outputType);
185       });
186 }
187 
computeNoop(const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)188 Tensor computeNoop(
189     const std::vector<ArgValue>& inputValues,
190     const std::vector<ExprHandle>& outputShape,
191     const std::vector<ExprHandle>& outputStrides,
192     const std::optional<ScalarType>& outputType,
193     at::Device device) {
194   return computeOneOperand(
195       "copy",
196       inputValues,
197       outputShape,
198       outputStrides,
199       outputType,
200       [](const ExprHandle& a) { return a; });
201 }
202 
computeScalar(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)203 Tensor computeScalar(
204     const std::string& name,
205     const std::vector<ArgValue>& inputValues,
206     const std::vector<ExprHandle>& outputShape,
207     const std::vector<ExprHandle>& outputStrides,
208     const std::optional<ScalarType>& outputType,
209     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
210         innerExpr) {
211   auto dt = Dtype(*outputType);
212   VarPtr let_var = alloc<Var>(name + "_var", dt);
213   std::vector<ExprHandle> inputs = {
214       scalarOrConstant(inputValues[0]), scalarOrConstant(inputValues[1])};
215   promoteInputs(inputs);
216   ExprHandle compute = innerExpr(inputs[0], inputs[1]);
217   StmtPtr let_stmt =
218       Let::make(VarHandle(let_var), demoteOutput(compute, outputType));
219   std::vector<ExprPtr> dims;
220   BufPtr buf = alloc<Buf>(let_var, dims, dt);
221   return Tensor(buf, let_stmt);
222 }
223 
224 } // namespace torch::jit::tensorexpr
225