#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

struct SmallSizeTPairHash {
 public:
  std::size_t operator()(const std::pair<size_t, size_t>& x) const {
    // hashing input index and then dim index
    return x.first * 128 + x.second;
  }
};

// Returns true if the TE fuser supports this conv2d.
bool conv2dIsSupportedJit(const Node* node);
// Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv.
bool mkldnnPrepackedConvIsSupportedJit(const Node* node);
// Returns true if the TE _convolution node is Conv2d.
bool isConv2d(const Node* node);
// Returns true if the TE fuser supports this matmul.
bool matmulIsSupported(const Node* node);
template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
  std::vector<int64_t> sizes;
  for (size_t i = 0; i < t->ndim(); i++) {
    sizes.push_back(*intValue(t->dim(i)));
  }
  return sizes;
}
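
// Usage sketch (illustrative, not from the original header): for a BufPtr `b`
// whose dims are all constant, e.g. a buffer created with sizes {2, 3},
// bufferSizes(b) returns {2, 3}. Each dim must resolve to a constant for the
// intValue() dereference above to be valid.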

// Get the dimensions of a value.
std::vector<ExprHandle> valueShape(const ArgValue& v);

// If v is a tensor, broadcast it to match the shape of axes; if v is a
// constant, return it directly.
ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes);

int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);

ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);

ExprHandle constant(const ArgValue& v);

std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes);

inline std::string getArgValueName(const ArgValue& a) {
  if (std::holds_alternative<tensorexpr::BufHandle>(a)) {
    return "BufHandle";
  } else if (std::holds_alternative<tensorexpr::VarHandle>(a)) {
    return "VarHandle";
  } else if (std::holds_alternative<double>(a)) {
    return "double";
  } else if (std::holds_alternative<int64_t>(a)) {
    return "int64_t";
  } else if (std::holds_alternative<bool>(a)) {
    return "bool";
  } else if (std::holds_alternative<BufList>(a)) {
    return "BufList";
  } else if (std::holds_alternative<DoubleList>(a)) {
    return "DoubleList";
  } else if (std::holds_alternative<IntList>(a)) {
    return "IntList";
  } else if (std::holds_alternative<ArgNone>(a)) {
    return "None";
  } else {
    throw std::runtime_error("ArgValue type not handled in string conversion");
  }
}

template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
  std::vector<T> res;
  for (auto& x : v) {
    auto val = std::get_if<T>(&x);
    if (val) {
      res.push_back(*val);
    } else {
      throw std::runtime_error(
          "vector type not homogeneous - found " + getArgValueName(x) +
          ", expected " + getArgValueName(v[0]));
    }
  }
  return res;
}
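
// Usage sketch (illustrative, not from the original header): given a
// std::vector<ArgValue> `v` where every element holds an int64_t,
// convertVecArgValue<int64_t>(v) returns the unwrapped integers; if any
// element holds a different alternative, it throws with both type names.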

class TORCH_API TensorExprKernel {
  struct ConstantDescr {
    BufPtr buf;
    // Only one of ptr and node is used at a time
    // 1) ptr for the constant tensors
    // 2) node for the constant custom class objects
    void* ptr = nullptr;
    Node* node = nullptr;
  };

 public:
  // Constructor Params:
  //  * subgraph
  //      - the graph that needs to be compiled.
  //  * kernel_func_name
  //      - the name that should be used for the generated kernel.
  //  * custom_lowerings
  //      - map that represents custom lowering definitions for a set of ops.
  //  * symbolic_shape_inputs
  //      - a list of symbolic graph inputs that represent the symbolic dims of
  //        the input tensors.
  //  * pre_alloc
  //      - a flag to control pre-allocation of buffers.
  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::string kernel_func_name,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {});

  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {})
      : TensorExprKernel(
            subgraph,
            SubgraphUtils::generateNameForGraph(subgraph),
            std::move(custom_lowerings),
            std::move(symbolic_shape_inputs),
            pre_alloc,
            std::move(symbolic_strides)) {}
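
  // Usage sketch (illustrative, not from the original header), assuming
  // `subgraph` is a statically shaped fusion subgraph produced by the TE
  // fuser and "fused_kernel" is an arbitrary kernel name:
  //
  //   TensorExprKernel kernel(subgraph, "fused_kernel");  // compiles here
  //   Stack stack;
  //   // ... push the input IValues onto the stack ...
  //   kernel.run(stack);  // inputs are popped, outputs are pushed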

  void run(Stack& stack) const;
  void runFast(
      const std::vector<void*>& inputs,
      const std::vector<void*>& outputs) const;
  // Expected format of stack:
  //  ... <outputs> <inputs>
  // i.e., output IValues must be below the input IValues in the stack.
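  // For example, with two inputs and one pre-allocated output, the stack is
  // (bottom to top): ..., out0, in0, in1.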
  void runWithAllocatedOutputs(Stack& stack) const;

  void fallback(Stack& stack) const {
    InterpreterState(code_).run(stack);
  }
  void recompile();

  StmtPtr getCodeGenStmt();

  std::string getCodeText(const std::string& attr = "") {
    return codegen_->getCodeText(attr);
  }

  const std::shared_ptr<Graph> graph() {
    return graph_;
  }

  const std::vector<ConstantDescr>& getConstantDescriptors() const {
    return constants_;
  }

  const std::vector<CodeGen::BufferArg>& getBufferArgs() const {
    return bufferArgs_;
  }

  const std::string& getKernelName() const {
    return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_);
  }

  const std::vector<int64_t>& getSymbolicShapeInputs() const {
    return symbolic_shape_inputs_;
  }

 private:
  enum BackendType {
    kUninitialized,
    kSimpleIREval,
    kLLVMCodeGen,
    kCudaCodeGen,
    kBlockCodeGen,
  };

  enum MemoryLayoutPolicy {
    kContiguous,
    kChannelsLastNdContiguous,
  };

  void compile();
  void genInputDebugNames();
  void runKernel(Stack& stack) const;

  std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);

  // These functions broadcast the given shapes and also set the
  // `hasBroadcast_` flag.
  std::vector<ExprHandle> broadcastShapesMut(
      const std::vector<ExprHandle>& a,
      const std::vector<ExprHandle>& b);
  std::vector<ExprHandle> broadcastShapesMut(
      std::vector<std::vector<ExprHandle>> shapes);

  ArgValue toArg(const torch::jit::Value* v) const;
  ExprHandle constant(const torch::jit::Value* v);

  Tensor computeValue(const torch::jit::Value* v);

  void bindConstant(const torch::jit::Value* v);

  StmtPtr transformLoops(BackendType backendType, StmtPtr st);

  std::string getCodeGenName(BackendType backendType);

  void getStaticOutputSizesAndStrides(
      const at::ArrayRef<IValue>& inputs,
      std::vector<std::vector<int64_t>>* static_sizes,
      std::vector<std::vector<int64_t>>* static_strides) const;

  std::vector<CodeGen::CallArg> prepareRunArgs(
      const at::ArrayRef<IValue>& inputs,
      std::vector<at::Tensor>& outputs) const;
  BackendType inferBackendTypeFromDevice(at::Device device);

  Tensor bindInput(const torch::jit::Value* input);
  BlockPtr bindAllInputs();

  // Deduce the memory layout policy to be propagated within the
  // NNC fusion group. The memory layout policy could be `kContiguous`
  // or `kChannelsLastNdContiguous`.
  //    `kContiguous`: Always convert the non-contiguous input tensors and
  //                   internal buffers to contiguous.
  //    `kChannelsLastNdContiguous`: Always convert the input tensors and
  //                                 internal buffers to channels-last
  //                                 contiguous.
  // Currently, the rule is simple:
  //    If all the input and output tensors of the NNC fusion group are
  //    channels-last contiguous, the policy is `kChannelsLastNdContiguous`.
  //    Otherwise, it is always `kContiguous`.
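  // For example, if every input and output of the fusion group is a 4-D
  // tensor with channels-last (NHWC) strides, the deduced policy is
  // `kChannelsLastNdContiguous`; if any of them is only plain (NCHW)
  // contiguous, the policy falls back to `kContiguous`.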
  void deduceMemoryLayoutPolicy();

  Tensor convertSymbolicOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertSymbolicOutputToCorrectStrides(
      const std::vector<ExprHandle>& sizes,
      const std::vector<size_t>& sorted_stride_indices_descending,
      const std::vector<ExprPtr>& strides,
      BufPtr& buf);

  NNCLoweringFunction getCustomLoweringFor(c10::Symbol op) const;
  std::unordered_map<c10::Symbol, NNCLoweringFunction> getCustomLowerings()
      const {
    return custom_lowerings_;
  }

  // Allocate memory for intermediate buffers at compile time.
  // Specifically, we pre-allocate memory for intermediate buffers with static
  // size and manage these buffers in the way we manage JIT constant tensors:
  // push the buf args onto the stack so the NNC IR can access them at runtime.
  std::vector<BufPtr> preAllocIntermediateBufs(
      const std::vector<BufPtr>& interm_bufs);

  struct UnpackedTensorOptions {
    std::optional<c10::ScalarType> dtype;
    std::optional<c10::Layout> layout;
    std::optional<c10::Device> device;
    std::optional<bool> pinned_memory;

    UnpackedTensorOptions(const c10::TensorOptions& opts)
        : dtype(c10::optTypeMetaToScalarType(opts.dtype_opt())),
          layout(opts.layout_opt()),
          device(opts.device_opt()),
          pinned_memory(opts.pinned_memory_opt()) {}
  };

  ExprHandle getVarForShape(const c10::ShapeSymbol& ss);
  std::vector<ExprHandle> computeInputTensorDims(
      const torch::jit::Value* input);
  ExprHandle getStrideArg(size_t tensor_input, size_t stride_index);
  std::vector<ExprHandle> sizesFromSymbolicShape(
      const c10::SymbolicShape& shape);
  std::vector<ExprHandle> getInputStrides(
      const torch::jit::Value* input,
      const std::vector<ExprHandle>& inputTensorDims);
  std::vector<torch::jit::StrideInput>& getSymbolicStrideDesc(
      const torch::jit::Value* value);

  // Apply graph-level optimizations to the graph owned by the current fusion
  // group, such as concatenation optimization and post-op fusion.
  void optimizeOwningGraph();

  int64_t nInputs_ = 0;
  int64_t nOutputs_ = 0;
  std::vector<CodeGen::BufferArg> bufferArgs_;
  std::vector<std::vector<int64_t>> tensorOutputSizes_;
  std::vector<std::vector<int64_t>> tensorOutputStrides_;
  std::vector<torch::jit::StrideInput> tensorOutputStrideDesc_;
  std::vector<bool> isOutputScalar_;
  std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
  std::unordered_set<BufPtr> bufOutputs_;
  std::unordered_set<BufPtr> bufsToBeParallelized_;
  std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;
  std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
  std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
  std::unique_ptr<CodeGen> codegen_;
  at::Device device_ = at::kCPU;
  std::shared_ptr<Graph> graph_;
  Code code_;
  bool allow_fallback_{false};
  bool use_fallback_{false};
  bool hasRandom_{false};
  bool hasBroadcast_{false};
  std::unordered_map<const torch::jit::Value*, std::vector<ExprHandle>>
      known_sizes_;

  std::vector<std::vector<ExprHandle>> tensorOutputSymbolicSizes_;
  // A map from ShapeSymbol.value() to the corresponding Var.
  std::unordered_map<int64_t, VarHandle> shapeSymbolToVar_;
  std::unordered_map<ExprPtr, size_t> shapeSymbolInputPos_;
  // List of values corresponding to the ShapeSymbols that are inputs to the
  // kernel being compiled. The order of these values corresponds to the order
  // of the symbolic inputs at the end of the kernel's input list.
  std::vector<int64_t> symbolic_shape_inputs_;
  bool has_symbolic_shapes_{false};

  std::vector<at::Tensor> unpacked_constant_tensors_;
  std::vector<ConstantDescr> constants_;

  std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
  StmtPtr stmt_ = nullptr;
  bool pre_alloc_{false};
  std::string kernel_func_name_;

  // Pairs of (input index in the stack, stride index of that tensor) for the
  // strides that will be appended as codegen args.
  std::vector<std::pair<size_t, size_t>> input_stride_args_;
  // Map from <input index, tensor dimension> to the VarHandle used as the
  // stride arg.
  std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
      strideArgToVar_;
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides_;

  // Memory layout to be propagated within the fusion group.
  MemoryLayoutPolicy memory_layout_policy_ = MemoryLayoutPolicy::kContiguous;
};

TORCH_API int& getTECudaPointwiseLoopLevels();
TORCH_API int& getTECudaPointwiseBlockCount();
TORCH_API int& getTECudaPointwiseBlockSize();
TORCH_API bool& getTEGenerateBlockCode();
TORCH_API bool& getTEMustUseLLVMOnCPU();
TORCH_API bool fallbackAllowed();
TORCH_API bool setFallbackAllowed(bool value);
TORCH_API bool& getCatWoConditionals();
TORCH_API bool& getOptConditionals();

TORCH_API std::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs);

bool isContiguous(
    const torch::jit::Value* v,
    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous);

} // namespace torch::jit::tensorexpr