1 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
2 #include <torch/csrc/jit/tensorexpr/operators/pointwise.h>
3
4 namespace torch::jit::tensorexpr {
5
6 using namespace torch::jit::tensorexpr;
7
computeSign(const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::optional<std::vector<ExprHandle>> & outputStrides)8 Tensor computeSign(
9 const std::vector<ArgValue>& inputValues,
10 const std::vector<ExprHandle>& outputShape,
11 const std::optional<std::vector<ExprHandle>>& outputStrides) {
12 return Compute(
13 "aten_sign", outputShape, outputStrides, [&](ParameterList& axes) {
14 std::vector<ExprHandle> indices(axes.begin(), axes.end());
15 std::vector<ExprHandle> inputs = {
16 tensorOrConstant(inputValues[0], indices)};
17 auto inp = inputs[0];
18 auto zero = ExprHandle(immLike(inp, 0.0f));
19 auto res = (zero < inp) - (inp < zero);
20 return promoteToDtype(res, inp.dtype().scalar_type());
21 });
22 }
23
computeOneOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &)> & innerExpr,const int checkParamTypes)24 Tensor computeOneOperand(
25 const std::string& name,
26 const std::vector<ArgValue>& inputValues,
27 const std::vector<ExprHandle>& outputShape,
28 const std::vector<ExprHandle>& outputStrides,
29 const std::optional<ScalarType>& outputType,
30 const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
31 const int checkParamTypes) {
32 return Compute(
33 name,
34 outputShape,
35 outputStrides,
36 [inputValues, outputType, innerExpr, checkParamTypes](
37 const std::vector<VarHandle>& axes) {
38 std::vector<ExprHandle> indices(axes.begin(), axes.end());
39 std::vector<ExprHandle> inputs = {
40 tensorOrConstant(inputValues[0], indices)};
41 promoteInputs(inputs, checkParamTypes);
42 ExprHandle compute = innerExpr(inputs[0]);
43 return demoteOutput(compute, outputType);
44 });
45 }
46
computeTwoOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)47 Tensor computeTwoOperand(
48 const std::string& name,
49 const std::vector<ArgValue>& inputValues,
50 const std::vector<ExprHandle>& outputShape,
51 const std::vector<ExprHandle>& outputStrides,
52 const std::optional<ScalarType>& outputType,
53 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
54 innerExpr) {
55 return Compute(
56 name,
57 outputShape,
58 outputStrides,
59 [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
60 std::vector<ExprHandle> indices(axes.begin(), axes.end());
61 std::vector<ExprHandle> inputs = {
62 tensorOrConstant(inputValues[0], indices),
63 tensorOrConstant(inputValues[1], indices),
64 };
65
66 promoteInputs(inputs);
67 ExprHandle compute = innerExpr(inputs[0], inputs[1]);
68 return demoteOutput(compute, outputType);
69 });
70 }
71
computeTwoOperandWithAlpha(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)72 Tensor computeTwoOperandWithAlpha(
73 const std::string& name,
74 const std::vector<ArgValue>& inputValues,
75 const std::vector<ExprHandle>& outputShape,
76 const std::vector<ExprHandle>& outputStrides,
77 const std::optional<ScalarType>& outputType,
78 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
79 innerExpr) {
80 return Compute(
81 name,
82 outputShape,
83 outputStrides,
84 [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
85 std::vector<ExprHandle> indices(axes.begin(), axes.end());
86 std::vector<ExprHandle> inputs = {
87 tensorOrConstant(inputValues[0], indices),
88 tensorOrConstant(inputValues[1], indices),
89 tensorOrConstant(inputValues[2], indices),
90 };
91
92 promoteInputs(inputs);
93 ExprHandle compute = innerExpr(inputs[0], inputs[2] * inputs[1]);
94 return demoteOutput(compute, outputType);
95 });
96 }
97
computeConditionWithTwoOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr)98 Tensor computeConditionWithTwoOperand(
99 const std::string& name,
100 const std::vector<ArgValue>& inputValues,
101 const std::vector<ExprHandle>& outputShape,
102 const std::vector<ExprHandle>& outputStrides,
103 const std::optional<ScalarType>& outputType,
104 const std::function<
105 ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
106 innerExpr) {
107 return Compute(
108 name,
109 outputShape,
110 outputStrides,
111 [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
112 std::vector<ExprHandle> indices(axes.begin(), axes.end());
113 std::vector<ExprHandle> inputs = {
114 tensorOrConstant(inputValues[1], indices),
115 tensorOrConstant(inputValues[2], indices),
116 };
117
118 promoteInputs(inputs);
119 // First expr is the condition, which we don't promote
120 inputs.emplace(
121 inputs.begin(), tensorOrConstant(inputValues[0], indices));
122 ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
123 return demoteOutput(compute, outputType);
124 });
125 }
126
computeThreeOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr,bool promote_inputs)127 Tensor computeThreeOperand(
128 const std::string& name,
129 const std::vector<ArgValue>& inputValues,
130 const std::vector<ExprHandle>& outputShape,
131 const std::vector<ExprHandle>& outputStrides,
132 const std::optional<ScalarType>& outputType,
133 const std::function<
134 ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
135 innerExpr,
136 bool promote_inputs) {
137 return Compute(
138 name,
139 outputShape,
140 outputStrides,
141 [inputValues, outputType, innerExpr, promote_inputs](
142 const std::vector<VarHandle>& axes) {
143 std::vector<ExprHandle> indices(axes.begin(), axes.end());
144 std::vector<ExprHandle> inputs = {
145 tensorOrConstant(inputValues[0], indices),
146 tensorOrConstant(inputValues[1], indices),
147 tensorOrConstant(inputValues[2], indices),
148 };
149
150 if (promote_inputs) {
151 promoteInputs(inputs);
152 }
153 ExprHandle compute = innerExpr(inputs[0], inputs[1], inputs[2]);
154 return demoteOutput(compute, outputType);
155 });
156 }
computeFourOperand(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &,const ExprHandle &,const ExprHandle &)> & innerExpr)157 Tensor computeFourOperand(
158 const std::string& name,
159 const std::vector<ArgValue>& inputValues,
160 const std::vector<ExprHandle>& outputShape,
161 const std::vector<ExprHandle>& outputStrides,
162 const std::optional<ScalarType>& outputType,
163 const std::function<ExprHandle(
164 const ExprHandle&,
165 const ExprHandle&,
166 const ExprHandle&,
167 const ExprHandle&)>& innerExpr) {
168 return Compute(
169 name,
170 outputShape,
171 outputStrides,
172 [inputValues, outputType, innerExpr](const std::vector<VarHandle>& axes) {
173 std::vector<ExprHandle> indices(axes.begin(), axes.end());
174 std::vector<ExprHandle> inputs = {
175 tensorOrConstant(inputValues[0], indices),
176 tensorOrConstant(inputValues[1], indices),
177 tensorOrConstant(inputValues[2], indices),
178 tensorOrConstant(inputValues[3], indices),
179 };
180
181 promoteInputs(inputs);
182 ExprHandle compute =
183 innerExpr(inputs[0], inputs[1], inputs[2], inputs[3]);
184 return demoteOutput(compute, outputType);
185 });
186 }
187
computeNoop(const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)188 Tensor computeNoop(
189 const std::vector<ArgValue>& inputValues,
190 const std::vector<ExprHandle>& outputShape,
191 const std::vector<ExprHandle>& outputStrides,
192 const std::optional<ScalarType>& outputType,
193 at::Device device) {
194 return computeOneOperand(
195 "copy",
196 inputValues,
197 outputShape,
198 outputStrides,
199 outputType,
200 [](const ExprHandle& a) { return a; });
201 }
202
computeScalar(const std::string & name,const std::vector<ArgValue> & inputValues,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,const std::function<ExprHandle (const ExprHandle &,const ExprHandle &)> & innerExpr)203 Tensor computeScalar(
204 const std::string& name,
205 const std::vector<ArgValue>& inputValues,
206 const std::vector<ExprHandle>& outputShape,
207 const std::vector<ExprHandle>& outputStrides,
208 const std::optional<ScalarType>& outputType,
209 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
210 innerExpr) {
211 auto dt = Dtype(*outputType);
212 VarPtr let_var = alloc<Var>(name + "_var", dt);
213 std::vector<ExprHandle> inputs = {
214 scalarOrConstant(inputValues[0]), scalarOrConstant(inputValues[1])};
215 promoteInputs(inputs);
216 ExprHandle compute = innerExpr(inputs[0], inputs[1]);
217 StmtPtr let_stmt =
218 Let::make(VarHandle(let_var), demoteOutput(compute, outputType));
219 std::vector<ExprPtr> dims;
220 BufPtr buf = alloc<Buf>(let_var, dims, dt);
221 return Tensor(buf, let_stmt);
222 }
223
224 } // namespace torch::jit::tensorexpr
225