xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/lowerings.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // This file defines classes for registering standard lowerings from JIT to TE
2 // IR.
3 #pragma once
4 
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/runtime/interpreter.h>
7 #include <torch/csrc/jit/tensorexpr/analysis.h>
8 #include <torch/csrc/jit/tensorexpr/codegen.h>
9 #include <torch/csrc/jit/tensorexpr/tensor.h>
10 
11 namespace torch::jit::tensorexpr {
12 
// Empty alternative (std::monostate) of ArgValue. Presumably stands for an
// absent / None argument -- TODO confirm against the lowering implementations.
13 using ArgNone = std::monostate;
// List-valued argument payloads: buffer handles, doubles, and 64-bit integers.
14 using BufList = std::vector<tensorexpr::BufHandle>;
15 using DoubleList = std::vector<double>;
16 using IntList = std::vector<int64_t>;
// Tagged union of every argument kind that can be passed to a lowering
// function: a buffer or variable handle, a scalar (double / int64_t / bool),
// a list of buffers / doubles / integers, a string, or ArgNone.
//
// NOTE(review): BufHandle is the first alternative, so a default-constructed
// ArgValue would hold a BufHandle (if BufHandle is default-constructible),
// not ArgNone. Construct ArgNone explicitly when "no value" is intended.
// NOTE(review): because bool, int64_t, double and std::string are all
// alternatives, converting construction from a plain `int` or a string
// literal may be ambiguous or select an unexpected alternative depending on
// the standard library's variant rules; prefer exactly-typed values.
17 using ArgValue = std::variant<
18     tensorexpr::BufHandle,
19     tensorexpr::VarHandle,
20     double,
21     int64_t,
22     bool,
23     BufList,
24     DoubleList,
25     IntList,
26     std::string,
27     ArgNone>;
28 
// Type-erased callable signature shared by all lowering functions: produces a
// Tensor from a list of ArgValue arguments, two lists of ExprHandle, an
// optional ScalarType, and an at::Device (taken by value).
//
// The parameters are unnamed in this header. Presumably, in order: the op
// inputs, the output shape, the output strides, the output dtype, and the
// target device -- TODO confirm against the lowering definitions.
29 using NNCLoweringFunction = std::function<Tensor(
30     const std::vector<ArgValue>&,
31     const std::vector<ExprHandle>&,
32     const std::vector<ExprHandle>&,
33     const std::optional<ScalarType>&,
34     at::Device)>;
35 
// Returns a mutable reference to the FunctionSchemaMap that holds
// NNCLoweringFunction entries. Because the reference is non-const, callers
// can modify the map through it. Storage duration and thread-safety are not
// visible from this header -- TODO confirm in the definition.
36 TORCH_API FunctionSchemaMap<NNCLoweringFunction>& getNNCLoweringRegistry();
// Returns, by value, the lowering function for `op`. Presumably `op` is a
// schema string looked up in the registry above. The behaviour for an
// unknown `op` is not visible here; callers should likely check the returned
// std::function for emptiness before invoking it -- TODO confirm.
37 TORCH_API NNCLoweringFunction getStandardLoweringFor(const std::string& op);
38 
// Helper type whose only declared member is a constructor taking a list of
// schema strings and a lowering function. The struct has no data members, so
// any effect comes from running the constructor (defined outside this
// header). Presumably it associates `fn` with each entry of `schemas` in the
// lowering registry -- TODO confirm in the definition.
//
// NOTE(review): unlike the free functions declared above, this constructor is
// not marked TORCH_API; verify that is intentional if it must be usable from
// other shared libraries.
39 struct RegisterNNCLoweringsFunction {
  // schemas: schema strings the lowering applies to (format not shown here).
  // fn: the lowering callable; taken by const reference.
40   RegisterNNCLoweringsFunction(
41       const std::vector<std::string>& schemas,
42       const NNCLoweringFunction& fn);
43 };
44 
45 } // namespace torch::jit::tensorexpr
46