xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/te_wrapper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/codegen.h>
4 #include <torch/csrc/jit/tensorexpr/ir.h>
5 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
6 #include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
7 #include <torch/csrc/jit/tensorexpr/loopnest.h>
8 
9 namespace torch::jit {
10 
11 class TEWrapper {
12  public:
13   TEWrapper() = default;
14   void call(const std::vector<void*>& args);
15 
16   template <typename ExpectedType>
checkInput(const at::Tensor & t)17   bool checkInput(const at::Tensor& t) {
18 #ifdef TORCH_ENABLE_LLVM
19     return t.is_contiguous() && t.dtype().Match<ExpectedType>();
20 #else
21     return false;
22 #endif
23   }
24 
25 #ifdef TORCH_ENABLE_LLVM
26   void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_);
27 #endif
28 
29  private:
30 #ifdef TORCH_ENABLE_LLVM
31   std::unique_ptr<tensorexpr::LLVMCodeGen> cg;
32 #endif
33 };
34 
35 std::shared_ptr<TEWrapper> createDiv();
36 std::shared_ptr<TEWrapper> createLogit();
37 std::shared_ptr<TEWrapper> createRelu();
38 std::shared_ptr<TEWrapper> createTanh();
39 std::shared_ptr<TEWrapper> createSigmoid();
40 std::shared_ptr<TEWrapper> createSignedLog1p();
41 std::shared_ptr<TEWrapper> createClamp();
42 std::shared_ptr<TEWrapper> createClampNanToNum();
43 
44 } // namespace torch::jit
45