xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/tensor.h>
2 
3 #include <c10/util/Logging.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/tensorexpr/reduction.h>
6 
7 namespace torch::jit::tensorexpr {
8 
constructStmt(const std::vector<VarPtr> & args,const ExprPtr & body,const std::vector<ExprPtr> & reduce_dims,const std::vector<VarPtr> & reduce_args) const9 StmtPtr Tensor::constructStmt(
10     const std::vector<VarPtr>& args,
11     const ExprPtr& body,
12     const std::vector<ExprPtr>& reduce_dims,
13     const std::vector<VarPtr>& reduce_args) const {
14   std::vector<ExprPtr> indices(args.begin(), args.end());
15 
16   size_t ndim = buf()->ndim();
17   size_t reduce_ndim = reduce_dims.size();
18   auto reduce_op = to<ReduceOp>(body);
19   auto acc_buf = reduce_ndim > 0 ? reduce_op->getAccBuf() : nullptr;
20 
21   StmtPtr s = alloc<Store>(buf_, indices, body);
22   if (reduce_ndim > 0) {
23     TORCH_INTERNAL_ASSERT(reduce_op != nullptr);
24     if (acc_buf != nullptr) {
25       auto reducer = reduce_op->reducer();
26       std::vector<ExprPtr> output_args(args.begin(), args.end());
27       ExprPtr new_reduce_op = reducer(
28           to<Buf>(acc_buf),
29           alloc<Cast>(acc_buf->dtype(), reduce_op->getRiOperand()),
30           output_args,
31           reduce_args);
32       new_reduce_op->set_dtype(acc_buf->dtype());
33       s = alloc<Store>(to<Buf>(acc_buf), indices, new_reduce_op);
34     }
35   }
36 
37   if (ndim == 0 && reduce_ndim == 0) {
38     return s;
39   }
40 
41   if (reduce_ndim > 0) {
42     TORCH_INTERNAL_ASSERT(reduce_op != nullptr);
43 
44     for (const auto i : c10::irange(reduce_ndim)) {
45       // Going in reverse order: from innermost loop to the outermost
46       size_t dim_index = reduce_ndim - i - 1;
47       auto const& dim = reduce_dims[dim_index];
48       s = alloc<For>(reduce_args[dim_index], immLike(dim, 0), dim, s);
49     }
50     s = alloc<Block>(std::vector<StmtPtr>({s}));
51 
52     BufPtr init_buf = acc_buf ? to<Buf>(acc_buf) : buf();
53     ExprPtr init_expr =
54         acc_buf ? to<Buf>(acc_buf)->initializer() : buf()->initializer();
55     if (init_expr) {
56       StorePtr init_stmt = alloc<Store>(init_buf, indices, init_expr);
57       to<Block>(s)->prepend_stmt(init_stmt);
58     }
59 
60     if (acc_buf != nullptr) {
61       LoadPtr load_acc = alloc<Load>(acc_buf, indices);
62       auto cast = alloc<Cast>(buf()->dtype(), load_acc);
63       StorePtr post_stmt = alloc<Store>(buf(), indices, cast);
64       to<Block>(s)->append_stmt(post_stmt);
65     }
66   }
67 
68   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
69       buf_->is_contiguous() ||
70       buf_->is_contiguous(at::MemoryFormat::ChannelsLast) ||
71       buf_->is_contiguous(at::MemoryFormat::ChannelsLast3d) ||
72       buf_->is_channels_last_1d_contiguous());
73 
74   auto loop_order_fn = [&]() {
75     std::vector<int32_t> loop_order;
76     if (buf_->is_contiguous()) {
77       for (int32_t i = args.size() - 1; i >= 0; i--) {
78         loop_order.push_back(i);
79       }
80     } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast)) {
81       loop_order = {1, 3, 2, 0};
82     } else if (buf_->is_contiguous(c10::MemoryFormat::ChannelsLast3d)) {
83       loop_order = {1, 4, 3, 2, 0};
84     } else {
85       loop_order = {1, 2, 0};
86     }
87 
88     return loop_order;
89   };
90 
91   auto loop_order = loop_order_fn();
92   for (auto dim_index : loop_order) {
93     auto const& dim = buf()->dim(dim_index);
94     s = alloc<For>(args[dim_index], immLike(dim, 0), dim, s);
95   }
96   return s;
97 }
98 
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const std::vector<VarHandle> &)> & body_func)99 Tensor Compute(
100     const std::string& name,
101     const std::vector<ExprHandle>& dims,
102     const std::optional<std::vector<ExprHandle>>& strides,
103     const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
104   std::vector<VarHandle> args = create_index_vars(dims);
105   ExprHandle body = body_func(args);
106   BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
107   return Tensor(buf, args, body);
108 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const std::vector<VarHandle> &)> & body_func)109 Tensor Compute(
110     const std::string& name,
111     const std::vector<ExprHandle>& dims,
112     const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
113   return Compute(name, dims, std::nullopt, body_func);
114 }
115 
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &)> & body_func)116 Tensor Compute(
117     const std::string& name,
118     const std::vector<ExprHandle>& dims,
119     const std::optional<std::vector<ExprHandle>>& strides,
120     const std::function<ExprHandle(const VarHandle&)>& body_func) {
121   if (dims.size() != 1) {
122     throw malformed_input("mismatch between body and arg size (1)");
123   }
124 
125   std::vector<VarHandle> args = create_index_vars(dims);
126   ExprHandle body = body_func(args[0]);
127   BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
128   return Tensor(buf, args, body);
129 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &)> & body_func)130 Tensor Compute(
131     const std::string& name,
132     const std::vector<ExprHandle>& dims,
133     const std::function<ExprHandle(const VarHandle&)>& body_func) {
134   return Compute(name, dims, std::nullopt, body_func);
135 }
136 
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & body_func)137 Tensor Compute(
138     const std::string& name,
139     const std::vector<ExprHandle>& dims,
140     const std::optional<std::vector<ExprHandle>>& strides,
141     const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
142         body_func) {
143   if (dims.size() != 2) {
144     throw malformed_input("mismatch between body and arg size (2)");
145   }
146   std::vector<VarHandle> args = create_index_vars(dims);
147   ExprHandle body = body_func(args[0], args[1]);
148   BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
149   return Tensor(buf, args, body);
150 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &)> & body_func)151 Tensor Compute(
152     const std::string& name,
153     const std::vector<ExprHandle>& dims,
154     const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
155         body_func) {
156   return Compute(name, dims, std::nullopt, body_func);
157 }
158 
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)159 Tensor Compute(
160     const std::string& name,
161     const std::vector<ExprHandle>& dims,
162     const std::optional<std::vector<ExprHandle>>& strides,
163     const std::function<
164         ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
165         body_func) {
166   if (dims.size() != 3) {
167     throw malformed_input("mismatch between body and arg size (3)");
168   }
169   std::vector<VarHandle> args = create_index_vars(dims);
170   ExprHandle body = body_func(args[0], args[1], args[2]);
171   BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
172   return Tensor(buf, args, body);
173 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)174 Tensor Compute(
175     const std::string& name,
176     const std::vector<ExprHandle>& dims,
177     const std::function<
178         ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
179         body_func) {
180   return Compute(name, dims, std::nullopt, body_func);
181 }
182 
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)183 Tensor Compute(
184     const std::string& name,
185     const std::vector<ExprHandle>& dims,
186     const std::optional<std::vector<ExprHandle>>& strides,
187     const std::function<ExprHandle(
188         const VarHandle&,
189         const VarHandle&,
190         const VarHandle&,
191         const VarHandle&)>& body_func) {
192   if (dims.size() != 4) {
193     throw malformed_input("mismatch between body and arg size (4)");
194   }
195   std::vector<VarHandle> args = create_index_vars(dims);
196   ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
197   BufHandle buf = Buf::make(name, dims, body.dtype(), std::nullopt, strides);
198   return Tensor(buf, args, body);
199 }
Compute(const std::string & name,const std::vector<ExprHandle> & dims,const std::function<ExprHandle (const VarHandle &,const VarHandle &,const VarHandle &,const VarHandle &)> & body_func)200 Tensor Compute(
201     const std::string& name,
202     const std::vector<ExprHandle>& dims,
203     const std::function<ExprHandle(
204         const VarHandle&,
205         const VarHandle&,
206         const VarHandle&,
207         const VarHandle&)>& body_func) {
208   return Compute(name, dims, std::nullopt, body_func);
209 }
210 
Reduce(const std::string & name,const std::vector<ExprHandle> & dims,const std::optional<std::vector<ExprHandle>> & strides,const Reducer & reducer,const BufHandle & buffer,const std::vector<ExprHandle> & reduce_dims)211 Tensor Reduce(
212     const std::string& name,
213     const std::vector<ExprHandle>& dims,
214     const std::optional<std::vector<ExprHandle>>& strides,
215     const Reducer& reducer,
216     const BufHandle& buffer,
217     const std::vector<ExprHandle>& reduce_dims) {
218   return Reduce(
219       name,
220       dims,
221       strides,
222       reducer,
223       [&](ParameterList& p) { return buffer.load(p); },
224       reduce_dims);
225 }
Reduce(const std::string & name,const std::vector<ExprHandle> & dims,const Reducer & reducer,const BufHandle & buffer,const std::vector<ExprHandle> & reduce_dims)226 Tensor Reduce(
227     const std::string& name,
228     const std::vector<ExprHandle>& dims,
229     const Reducer& reducer,
230     const BufHandle& buffer,
231     const std::vector<ExprHandle>& reduce_dims) {
232   return Reduce(name, dims, std::nullopt, reducer, buffer, reduce_dims);
233 }
234 
// Reduces the values of `tensor` with `reducer` over `reduce_dims`. The result
// is a tensor `name` with extents `dims` and optional explicit `strides`. The
// body callback loads from `tensor` at whatever parameters the generic Reduce
// supplies.
Tensor Reduce(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    const std::optional<std::vector<ExprHandle>>& strides,
    const Reducer& reducer,
    const Tensor& tensor,
    const std::vector<ExprHandle>& reduce_dims) {
  return Reduce(
      name,
      dims,
      strides,
      reducer,
      [&tensor](ParameterList& params) { return tensor.load(params); },
      reduce_dims);
}
// Tensor reduction without explicit strides: forwards to the strided overload
// with an empty strides optional.
Tensor Reduce(
    const std::string& name,
    const std::vector<ExprHandle>& dims,
    const Reducer& reducer,
    const Tensor& tensor,
    const std::vector<ExprHandle>& reduce_dims) {
  const std::optional<std::vector<ExprHandle>> no_strides = std::nullopt;
  return Reduce(name, dims, no_strides, reducer, tensor, reduce_dims);
}
258 
259 } // namespace torch::jit::tensorexpr
260