1 #include <torch/csrc/jit/tensorexpr/tensor.h>
2
3 #include <c10/util/Logging.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6
7 namespace torch::jit::tensorexpr {
8
constructStmt(const std::vector<VarPtr> & args,const ExprPtr & body,const std::vector<ExprPtr> & reduce_dims,const std::vector<VarPtr> & reduce_args) const9 StmtPtr Tensor::constructStmt(
10 const std::vector<VarPtr>& args,
11 const ExprPtr& body,
12 const std::vector<ExprPtr>& reduce_dims,
13 const std::vector<VarPtr>& reduce_args) const {
14 std::vector<ExprPtr> indices(args.begin(), args.end());
15
16 size_t ndim = buf()->ndim();
17 size_t reduce_ndim = reduce_dims.size();
18 auto reduce_op = to<ReduceOp>(body);
19 auto acc_buf = reduce_ndim > 0 ? reduce_op->getAccBuf() : nullptr;
20
21 StmtPtr s = alloc<Store>(buf_, indices, body);
22 if (reduce_ndim > 0) {
23 TORCH_INTERNAL_ASSERT(reduce_op != nullptr);
24 if (acc_buf != nullptr) {
25 auto reducer = reduce_op->reducer();
26 std::vector<ExprPtr> output_args(args.begin(), args.end());
27 ExprPtr new_reduce_op = reducer(
28 to<Buf>(acc_buf),
29 alloc<Cast>(acc_buf->dtype(), reduce_op->getRiOperand()),
30 output_args,
31 reduce_args);
32 new_reduce_op->set_dtype(acc_buf->dtype());
33 s = alloc<Store>(to<Buf>(acc_buf), indices, new_reduce_op);
34 }
35 }
36
37 if (ndim == 0 && reduce_ndim == 0) {
38 return s;
39 }
40
41 if (reduce_ndim > 0) {
42 TORCH_INTERNAL_ASSERT(reduce_op != nullptr);
43
44 for (const auto i : c10::irange(reduce_ndim)) {
45 // Going in reverse order: from innermost loop to the outermost
46 size_t dim_index = reduce_ndim - i - 1;
47 auto const& dim = reduce_dims[dim_index];
48 s = alloc<For>(reduce_args[dim_index], immLike(dim, 0), dim, s);
49 }
50 s = alloc<Block>(std::vector<StmtPtr>({s}));
51
52 BufPtr init_buf = acc_buf ? to<Buf>(acc_buf) : buf();
53 ExprPtr init_expr =
54 acc_buf ? to<Buf>(acc_buf)->initializer() : buf()->initializer();
55 if (init_expr) {
56 StorePtr init_stmt = alloc<Store>(init_buf, indices, init_expr);
57 to<Block>(s)->prepend_stmt(init_stmt);
58 }
59
60 if (acc_buf != nullptr) {
61 LoadPtr load_acc = alloc<Load>(acc_buf, indices);
62 auto cast = alloc<Cast>(buf()->dtype(), load_acc);
63 StorePtr post_stmt = alloc<Store>(buf(), indices, cast);
64 to<Block>(s)->append_stmt(post_stmt);
65 }
66 }
67
68 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
69 buf_->is_contiguous() ||
70 buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
71 buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
72 buf_->is_channels_last_1d_contiguous());
73
74 auto loop_order_fn = [&]() {
75 std::vector<int32_t> loop_order;
76 if (buf_->is_contiguous()) {
77 for (int32_t i = args.size() - 1; i >= 0; i--) {
78 loop_order.push_back(i);
79 }
80 } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
81 loop_order = {1, 3, 2, 0};
82 } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
83 loop_order = {1, 4, 3, 2, 0};
84 } else {
85 loop_order = {1, 2, 0};
86 }
87
88 return loop_order;
89 };
90
91 auto loop_order = loop_order_fn();
92 for (auto dim_index : loop_order) {
93 auto const& dim = buf()->dim(dim_index);
94 s = alloc<For>(args[dim_index], immLike(dim, 0), dim, s);
95 }
96 return s;
97 }
98
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const std::vector<VarHandle> &)> & body_func)99 Tensor Compute(
100 const std::string& name,
101 const std::vector<ExprHandle>& dims,
102 const std::optional<std::vector<ExprHandle>>& strides,
103 const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
104 std::vector<VarHandle> args = create_index_vars(dims);
105 ExprHandle body = body_func(args);
106 BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
107 return Tensor(buf, args, body);
108 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const std::vector<VarHandle> &)> & body_func)109 Tensor Compute(
110 const std::string& name,
111 const std::vector<ExprHandle>& dims,
112 const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
113 return Compute(name, dims, std::nullopt, body_func);
114 }
115
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &)> & body_func)116 Tensor Compute(
117 const std::string& name,
118 const std::vector<ExprHandle>& dims,
119 const std::optional<std::vector<ExprHandle>>& strides,
120 const std::function<ExprHandle(const VarHandle&)>& body_func) {
121 if (dims.size() != 1) {
122 throw malformed_input("mismatch between body and arg size (1)");
123 }
124
125 std::vector<VarHandle> args = create_index_vars(dims);
126 ExprHandle body = body_func(args[0]);
127 BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
128 return Tensor(buf, args, body);
129 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &)> & body_func)130 Tensor Compute(
131 const std::string& name,
132 const std::vector<ExprHandle>& dims,
133 const std::function<ExprHandle(const VarHandle&)>& body_func) {
134 return Compute(name, dims, std::nullopt, body_func);
135 }
136
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & body_func)137 Tensor Compute(
138 const std::string& name,
139 const std::vector<ExprHandle>& dims,
140 const std::optional<std::vector<ExprHandle>>& strides,
141 const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
142 body_func) {
143 if (dims.size() != 2) {
144 throw malformed_input("mismatch between body and arg size (2)");
145 }
146 std::vector<VarHandle> args = create_index_vars(dims);
147 ExprHandle body = body_func(args[0], args[1]);
148 BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
149 return Tensor(buf, args, body);
150 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & body_func)151 Tensor Compute(
152 const std::string& name,
153 const std::vector<ExprHandle>& dims,
154 const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
155 body_func) {
156 return Compute(name, dims, std::nullopt, body_func);
157 }
158
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)159 Tensor Compute(
160 const std::string& name,
161 const std::vector<ExprHandle>& dims,
162 const std::optional<std::vector<ExprHandle>>& strides,
163 const std::function<
164 ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
165 body_func) {
166 if (dims.size() != 3) {
167 throw malformed_input("mismatch between body and arg size (3)");
168 }
169 std::vector<VarHandle> args = create_index_vars(dims);
170 ExprHandle body = body_func(args[0], args[1], args[2]);
171 BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
172 return Tensor(buf, args, body);
173 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)174 Tensor Compute(
175 const std::string& name,
176 const std::vector<ExprHandle>& dims,
177 const std::function<
178 ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
179 body_func) {
180 return Compute(name, dims, std::nullopt, body_func);
181 }
182
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)183 Tensor Compute(
184 const std::string& name,
185 const std::vector<ExprHandle>& dims,
186 const std::optional<std::vector<ExprHandle>>& strides,
187 const std::function<ExprHandle(
188 const VarHandle&,
189 const VarHandle&,
190 const VarHandle&,
191 const VarHandle&)>& body_func) {
192 if (dims.size() != 4) {
193 throw malformed_input("mismatch between body and arg size (4)");
194 }
195 std::vector<VarHandle> args = create_index_vars(dims);
196 ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
197 BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
198 return Tensor(buf, args, body);
199 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)200 Tensor Compute(
201 const std::string& name,
202 const std::vector<ExprHandle>& dims,
203 const std::function<ExprHandle(
204 const VarHandle&,
205 const VarHandle&,
206 const VarHandle&,
207 const VarHandle&)>& body_func) {
208 return Compute(name, dims, std::nullopt, body_func);
209 }
210
Reduce(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const BufHandle & buffer,const std::vector<ExprHandle> & reduce_dims)211 Tensor Reduce(
212 const std::string& name,
213 const std::vector<ExprHandle>& dims,
214 const std::optional<std::vector<ExprHandle>>& strides,
215 const Reducer& reducer,
216 const BufHandle& buffer,
217 const std::vector<ExprHandle>& reduce_dims) {
218 return Reduce(
219 name,
220 dims,
221 strides,
222 reducer,
223 [&](ParameterList& p) { return buffer.load(p); },
224 reduce_dims);
225 }
Reduce(const std::string & name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const BufHandle & buffer,const std::vector<ExprHandle> & reduce_dims)226 Tensor Reduce(
227 const std::string& name,
228 const std::vector<ExprHandle>& dims,
229 const Reducer& reducer,
230 const BufHandle& buffer,
231 const std::vector<ExprHandle>& reduce_dims) {
232 return Reduce(name, dims, std::nullopt, reducer, buffer, reduce_dims);
233 }
234
// Reduction whose body reads `tensor` at the index list that the generic,
// body-function-taking Reduce overload passes in (presumably output indices
// followed by reduction indices — confirm against that overload).
Tensor Reduce(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    const std::optional<std::vector<ExprHandle>>& strides,
    const Reducer& reducer,
    const Tensor& tensor,
    const std::vector<ExprHandle>& reduce_dims) {
  auto read_tensor = [&tensor](ParameterList& indices) {
    return tensor.load(indices);
  };
  return Reduce(name, dims, strides, reducer, read_tensor, reduce_dims);
}
// Tensor-reading reduction without explicit strides.
Tensor Reduce(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    const Reducer& reducer,
    const Tensor& tensor,
    const std::vector<ExprHandle>& reduce_dims) {
  const std::optional<std::vector<ExprHandle>> unstrided = std::nullopt;
  return Reduce(name, dims, unstrided, reducer, tensor, reduce_dims);
}
258
259 } // namespace torch::jit::tensorexpr
260