1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/codegen.h> 4 #include <torch/csrc/jit/tensorexpr/ir.h> 5 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h> 6 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h> 7 #include <torch/csrc/jit/tensorexpr/loopnest.h> 8 9 namespace torch::jit { 10 11 class TEWrapper { 12 public: 13 TEWrapper() = default; 14 void call(const std::vector<void*>& args); 15 16 template <typename ExpectedType> checkInput(const at::Tensor & t)17 bool checkInput(const at::Tensor& t) { 18 #ifdef TORCH_ENABLE_LLVM 19 return t.is_contiguous() && t.dtype().Match<ExpectedType>(); 20 #else 21 return false; 22 #endif 23 } 24 25 #ifdef TORCH_ENABLE_LLVM 26 void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_); 27 #endif 28 29 private: 30 #ifdef TORCH_ENABLE_LLVM 31 std::unique_ptr<tensorexpr::LLVMCodeGen> cg; 32 #endif 33 }; 34 35 std::shared_ptr<TEWrapper> createDiv(); 36 std::shared_ptr<TEWrapper> createLogit(); 37 std::shared_ptr<TEWrapper> createRelu(); 38 std::shared_ptr<TEWrapper> createTanh(); 39 std::shared_ptr<TEWrapper> createSigmoid(); 40 std::shared_ptr<TEWrapper> createSignedLog1p(); 41 std::shared_ptr<TEWrapper> createClamp(); 42 std::shared_ptr<TEWrapper> createClampNanToNum(); 43 44 } // namespace torch::jit 45