1 // This file defines classes for registering standard lowerings from JIT to TE 2 // IR. 3 #pragma once 4 5 #include <torch/csrc/jit/ir/ir.h> 6 #include <torch/csrc/jit/runtime/interpreter.h> 7 #include <torch/csrc/jit/tensorexpr/analysis.h> 8 #include <torch/csrc/jit/tensorexpr/codegen.h> 9 #include <torch/csrc/jit/tensorexpr/tensor.h> 10 11 namespace torch::jit::tensorexpr { 12 13 using ArgNone = std::monostate; 14 using BufList = std::vector<tensorexpr::BufHandle>; 15 using DoubleList = std::vector<double>; 16 using IntList = std::vector<int64_t>; 17 using ArgValue = std::variant< 18 tensorexpr::BufHandle, 19 tensorexpr::VarHandle, 20 double, 21 int64_t, 22 bool, 23 BufList, 24 DoubleList, 25 IntList, 26 std::string, 27 ArgNone>; 28 29 using NNCLoweringFunction = std::function<Tensor( 30 const std::vector<ArgValue>&, 31 const std::vector<ExprHandle>&, 32 const std::vector<ExprHandle>&, 33 const std::optional<ScalarType>&, 34 at::Device)>; 35 36 TORCH_API FunctionSchemaMap<NNCLoweringFunction>& getNNCLoweringRegistry(); 37 TORCH_API NNCLoweringFunction getStandardLoweringFor(const std::string& op); 38 39 struct RegisterNNCLoweringsFunction { 40 RegisterNNCLoweringsFunction( 41 const std::vector<std::string>& schemas, 42 const NNCLoweringFunction& fn); 43 }; 44 45 } // namespace torch::jit::tensorexpr 46