#pragma once

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

struct SmallSizeTPairHash {
 public:
  std::size_t operator()(const std::pair<size_t, size_t>& x) const {
    // hashing input index and then dim index
    return x.first * 128 + x.second;
  }
};

// Returns true if the TE fuser supports this conv2d.
bool conv2dIsSupportedJit(const Node* node);
// Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv.
bool mkldnnPrepackedConvIsSupportedJit(const Node* node);
// Returns true if the TE _convolution node is Conv2d.
bool isConv2d(const Node* node);
// Returns true if the TE fuser supports this matmul.
bool matmulIsSupported(const Node* node);
template <typename T>
inline std::vector<int64_t> bufferSizes(const T& t) {
  std::vector<int64_t> sizes;
  for (size_t i = 0; i < t->ndim(); i++) {
    sizes.push_back(*intValue(t->dim(i)));
  }
  return sizes;
}
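
// For example (illustrative): for a buffer handle `b` created with static
// dims {2, 3}, bufferSizes(b) returns {2, 3}. Note that *intValue(...)
// assumes every dim is a constant expression.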

// Get the dimensions of a value.
std::vector<ExprHandle> valueShape(const ArgValue& v);

// If v is a tensor, broadcast it to match the shape of axes, or return
// directly if v is a constant.
ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes);

int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);

ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);

ExprHandle constant(const ArgValue& v);

std::vector<ExprHandle> computeIndicesToBroadcast(
    const std::vector<ExprHandle>& outputAxes,
    const std::vector<ExprHandle>& inputSizes);

inline std::string getArgValueName(const ArgValue& a) {
  if (std::holds_alternative<tensorexpr::BufHandle>(a)) {
    return "BufHandle";
  } else if (std::holds_alternative<tensorexpr::VarHandle>(a)) {
    return "VarHandle";
  } else if (std::holds_alternative<double>(a)) {
    return "double";
  } else if (std::holds_alternative<int64_t>(a)) {
    return "int64_t";
  } else if (std::holds_alternative<bool>(a)) {
    return "bool";
  } else if (std::holds_alternative<BufList>(a)) {
    return "BufList";
  } else if (std::holds_alternative<DoubleList>(a)) {
    return "DoubleList";
  } else if (std::holds_alternative<IntList>(a)) {
    return "IntList";
  } else if (std::holds_alternative<ArgNone>(a)) {
    return "None";
  } else {
    throw std::runtime_error("ArgValue type not handled in string conversion");
  }
}

template <class T>
std::vector<T> convertVecArgValue(const std::vector<ArgValue>& v) {
  std::vector<T> res;
  for (auto& x : v) {
    auto val = std::get_if<T>(&x);
    if (val) {
      res.push_back(*val);
    } else {
      throw std::runtime_error(
          "vector type not homogeneous - found " + getArgValueName(x) +
          ", expected " + getArgValueName(v[0]));
    }
  }
  return res;
}
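
// For example (illustrative): given a std::vector<ArgValue> holding
// {int64_t(2), int64_t(3)}, convertVecArgValue<int64_t> returns {2, 3};
// if any element held a different alternative (say, a double), the call
// would throw the "vector type not homogeneous" error above.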

class TORCH_API TensorExprKernel {
  struct ConstantDescr {
    BufPtr buf;
    // Only one of ptr and node is used at a time
    // 1) ptr for the constant tensors
    // 2) node for the constant custom class objects
    void* ptr = nullptr;
    Node* node = nullptr;
  };

 public:
  // Constructor Params:
  //  * subgraph
  //      - the graph that needs to be compiled.
  //  * kernel_func_name
  //      - the name that should be used for the generated kernel.
  //  * custom_lowerings
  //      - map that represents custom lowering definitions for a set of ops.
  //  * symbolic_shape_inputs
  //      - a list of symbolic graph inputs that represent the symbolic dims of
  //        the input tensors.
  //  * pre_alloc
  //      - a flag to control pre-allocation of buffers.
  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::string kernel_func_name,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {});

  explicit TensorExprKernel(
      const std::shared_ptr<Graph>& subgraph,
      std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
          {},
      std::vector<int64_t> symbolic_shape_inputs = {},
      bool pre_alloc = false,
      std::unordered_map<
          const torch::jit::Value*,
          std::vector<torch::jit::StrideInput>> symbolic_strides = {})
      : TensorExprKernel(
            subgraph,
            SubgraphUtils::generateNameForGraph(subgraph),
            std::move(custom_lowerings),
            std::move(symbolic_shape_inputs),
            pre_alloc,
            std::move(symbolic_strides)) {}
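
  // A minimal usage sketch (illustrative only; `subgraph`, `input_tensor`, and
  // the single-input/single-output shape of the example are assumptions, not
  // part of this API): compile a fusion subgraph and run it via the IValue
  // stack.
  //
  //   TensorExprKernel kernel(subgraph);
  //   Stack stack;
  //   push(stack, input_tensor);  // push inputs in graph-input order
  //   kernel.run(stack);          // pops the inputs, pushes the outputs
  //   at::Tensor result = pop(stack).toTensor();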

  void run(Stack& stack) const;
  void runFast(
      const std::vector<void*>& inputs,
      const std::vector<void*>& outputs) const;
  // Expected format of stack:
  //  ... <outputs> <inputs>
  // i.e., output IValues must be below the input IValues in the stack.
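  // For example (an illustrative sketch; `out_tensor`, `in0`, and `in1` are
  // placeholders), with one preallocated output and two inputs:
  //   Stack stack;
  //   push(stack, out_tensor);  // output IValues first (deeper in the stack)
  //   push(stack, in0, in1);    // input IValues on top
  //   runWithAllocatedOutputs(stack);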
  void runWithAllocatedOutputs(Stack& stack) const;

  void fallback(Stack& stack) const {
    InterpreterState(code_).run(stack);
  }
  void recompile();

  StmtPtr getCodeGenStmt();

  std::string getCodeText(const std::string& attr = "") {
    return codegen_->getCodeText(attr);
  }

  const std::shared_ptr<Graph> graph() {
    return graph_;
  }

  const std::vector<ConstantDescr>& getConstantDescriptors() const {
    return constants_;
  }

  const std::vector<CodeGen::BufferArg>& getBufferArgs() const {
    return bufferArgs_;
  }

  const std::string& getKernelName() const {
    return (codegen_ ? codegen_->kernel_func_name() : kernel_func_name_);
  }

  const std::vector<int64_t>& getSymbolicShapeInputs() const {
    return symbolic_shape_inputs_;
  }

 private:
  enum BackendType {
    kUninitialized,
    kSimpleIREval,
    kLLVMCodeGen,
    kCudaCodeGen,
    kBlockCodeGen,
  };

  enum MemoryLayoutPolicy {
    kContiguous,
    kChannelsLastNdContiguous,
  };

  void compile();
  void genInputDebugNames();
  void runKernel(Stack& stack) const;

  std::vector<ExprHandle> sizesForValue(const torch::jit::Value* v);

  // These functions broadcast shapes and also set the `hasBroadcast_` flag.
  std::vector<ExprHandle> broadcastShapesMut(
      const std::vector<ExprHandle>& a,
      const std::vector<ExprHandle>& b);
  std::vector<ExprHandle> broadcastShapesMut(
      std::vector<std::vector<ExprHandle>> shapes);

  ArgValue toArg(const torch::jit::Value* v) const;
  ExprHandle constant(const torch::jit::Value* v);

  Tensor computeValue(const torch::jit::Value* v);

  void bindConstant(const torch::jit::Value* v);

  StmtPtr transformLoops(BackendType backendType, StmtPtr st);

  std::string getCodeGenName(BackendType backendType);

  void getStaticOutputSizesAndStrides(
      const at::ArrayRef<IValue>& inputs,
      std::vector<std::vector<int64_t>>* static_sizes,
      std::vector<std::vector<int64_t>>* static_strides) const;

  std::vector<CodeGen::CallArg> prepareRunArgs(
      const at::ArrayRef<IValue>& inputs,
      std::vector<at::Tensor>& outputs) const;
  BackendType inferBackendTypeFromDevice(at::Device device);

  Tensor bindInput(const torch::jit::Value* input);
  BlockPtr bindAllInputs();

  // Deduce the memory layout policy to be propagated within the
  // NNC fusion group. The memory layout policy can be `kContiguous`
  // or `kChannelsLastNdContiguous`.
  //    `kContiguous`: always convert non-contiguous input tensors and
  //        internal buffers to contiguous.
  //    `kChannelsLastNdContiguous`: always convert the input tensors and
  //        internal buffers to channels-last contiguous.
  // Currently, the rule is simple:
  //    if all the input and output tensors of the NNC fusion group are
  //    channels-last contiguous, the policy is `kChannelsLastNdContiguous`.
  //    Otherwise, it is always `kContiguous`.
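  // For example (illustrative): if every input and output of the fused group
  // is a channels-last (NHWC) contiguous 4-d tensor, the deduced policy is
  // `kChannelsLastNdContiguous`; if any of them is not channels-last
  // contiguous (e.g., a plain NCHW-contiguous tensor), the policy falls back
  // to `kContiguous`.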
  void deduceMemoryLayoutPolicy();

  Tensor convertSymbolicOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertStaticShapeOutputToCorrectStrides(torch::jit::Value* v);
  Tensor convertSymbolicOutputToCorrectStrides(
      const std::vector<ExprHandle>& sizes,
      const std::vector<size_t>& sorted_stride_indices_descending,
      const std::vector<ExprPtr>& strides,
      BufPtr& buf);

  NNCLoweringFunction getCustomLoweringFor(c10::Symbol op) const;
  std::unordered_map<c10::Symbol, NNCLoweringFunction> getCustomLowerings()
      const {
    return custom_lowerings_;
  }

  // Allocate memory for intermediate buffers at compile time.
  // Specifically, we pre-allocate memory for intermediate buffers with static
  // size and manage these buffers in the way we manage JIT constant tensors:
  // push the buf args into the stack so NNC IR can access them at runtime.
  std::vector<BufPtr> preAllocIntermediateBufs(
      const std::vector<BufPtr>& interm_bufs);

  struct UnpackedTensorOptions {
    std::optional<c10::ScalarType> dtype;
    std::optional<c10::Layout> layout;
    std::optional<c10::Device> device;
    std::optional<bool> pinned_memory;

    UnpackedTensorOptions(const c10::TensorOptions& opts)
        : dtype(c10::optTypeMetaToScalarType(opts.dtype_opt())),
          layout(opts.layout_opt()),
          device(opts.device_opt()),
          pinned_memory(opts.pinned_memory_opt()) {}
  };

  ExprHandle getVarForShape(const c10::ShapeSymbol& ss);
  std::vector<ExprHandle> computeInputTensorDims(
      const torch::jit::Value* input);
  ExprHandle getStrideArg(size_t tensor_input, size_t stride_index);
  std::vector<ExprHandle> sizesFromSymbolicShape(
      const c10::SymbolicShape& shape);
  std::vector<ExprHandle> getInputStrides(
      const torch::jit::Value* input,
      const std::vector<ExprHandle>& inputTensorDims);
  std::vector<torch::jit::StrideInput>& getSymbolicStrideDesc(
      const torch::jit::Value* value);

  // Apply optimizations to the graph owned by the current fusion group, such
  // as concatenation optimization, post-op fusion, and other graph-level
  // optimizations.
  void optimizeOwningGraph();

  int64_t nInputs_ = 0;
  int64_t nOutputs_ = 0;
  std::vector<CodeGen::BufferArg> bufferArgs_;
  std::vector<std::vector<int64_t>> tensorOutputSizes_;
  std::vector<std::vector<int64_t>> tensorOutputStrides_;
  std::vector<torch::jit::StrideInput> tensorOutputStrideDesc_;
  std::vector<bool> isOutputScalar_;
  std::vector<UnpackedTensorOptions> tensorOutputTensorOptions_;
  std::unordered_set<BufPtr> bufOutputs_;
  std::unordered_set<BufPtr> bufsToBeParallelized_;
  std::unordered_map<const torch::jit::Value*, BufPtr> bufs_;
  std::unordered_map<const torch::jit::Value*, VarHandle> scalars_;
  std::unordered_map<const torch::jit::Value*, std::string> input_name_map_;
  std::unique_ptr<CodeGen> codegen_;
  at::Device device_ = at::kCPU;
  std::shared_ptr<Graph> graph_;
  Code code_;
  bool allow_fallback_{false};
  bool use_fallback_{false};
  bool hasRandom_{false};
  bool hasBroadcast_{false};
  std::unordered_map<const torch::jit::Value*, std::vector<ExprHandle>>
      known_sizes_;

  std::vector<std::vector<ExprHandle>> tensorOutputSymbolicSizes_;
  // A map from ShapeSymbol.value() to the corresponding Var.
  std::unordered_map<int64_t, VarHandle> shapeSymbolToVar_;
  std::unordered_map<ExprPtr, size_t> shapeSymbolInputPos_;
  // List of values corresponding to the ShapeSymbols that are inputs to the
  // kernel being compiled. The order of these values corresponds to the order
  // of the symbolic inputs at the end of the list of inputs to the kernel.
  std::vector<int64_t> symbolic_shape_inputs_;
  bool has_symbolic_shapes_{false};

  std::vector<at::Tensor> unpacked_constant_tensors_;
  std::vector<ConstantDescr> constants_;

  std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings_;
  StmtPtr stmt_ = nullptr;
  bool pre_alloc_{false};
  std::string kernel_func_name_;

  // (input stack index, stride index) pairs identifying the tensor strides
  // that will be appended as codegen args
  std::vector<std::pair<size_t, size_t>> input_stride_args_;
  // map from <input index, tensor dimension> to stride as arg VarHandle
  std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
      strideArgToVar_;
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides_;

  // Memory layout to be propagated within the fusion group
  MemoryLayoutPolicy memory_layout_policy_ = MemoryLayoutPolicy::kContiguous;
};

TORCH_API int& getTECudaPointwiseLoopLevels();
TORCH_API int& getTECudaPointwiseBlockCount();
TORCH_API int& getTECudaPointwiseBlockSize();
TORCH_API bool& getTEGenerateBlockCode();
TORCH_API bool& getTEMustUseLLVMOnCPU();
TORCH_API bool fallbackAllowed();
TORCH_API bool setFallbackAllowed(bool value);
TORCH_API bool& getCatWoConditionals();
TORCH_API bool& getOptConditionals();

TORCH_API std::optional<at::Device> pickDeviceType(
    const at::ArrayRef<torch::jit::Value*>& inputs);

bool isContiguous(
    const torch::jit::Value* v,
    at::MemoryFormat memory_format = at::MemoryFormat::Contiguous);

} // namespace torch::jit::tensorexpr