1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/kernel.h> 4 5 namespace torch { 6 namespace jit { 7 namespace tensorexpr { 8 9 TORCH_API Tensor computeSign( 10 const std::vector<ArgValue>& inputs, 11 const std::vector<ExprHandle>& outputShape, 12 const std::optional<std::vector<ExprHandle>>& outputStrides = std::nullopt); 13 14 Tensor computeOneOperand( 15 const std::string& name, 16 const std::vector<ArgValue>& inputValues, 17 const std::vector<ExprHandle>& outputShape, 18 const std::vector<ExprHandle>& outputStrides, 19 const std::optional<ScalarType>& outputType, 20 const std::function<ExprHandle(const ExprHandle&)>& innerExpr, 21 const int checkParamTypes = kAllTypes); 22 Tensor computeTwoOperand( 23 const std::string& name, 24 const std::vector<ArgValue>& inputValues, 25 const std::vector<ExprHandle>& outputShape, 26 const std::vector<ExprHandle>& outputStrides, 27 const std::optional<ScalarType>& outputType, 28 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>& 29 innerExpr); 30 Tensor computeTwoOperandWithAlpha( 31 const std::string& name, 32 const std::vector<ArgValue>& inputValues, 33 const std::vector<ExprHandle>& outputShape, 34 const std::vector<ExprHandle>& outputStrides, 35 const std::optional<ScalarType>& outputType, 36 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>& 37 innerExpr); 38 Tensor computeConditionWithTwoOperand( 39 const std::string& name, 40 const std::vector<ArgValue>& inputValues, 41 const std::vector<ExprHandle>& outputShape, 42 const std::vector<ExprHandle>& outputStrides, 43 const std::optional<ScalarType>& outputType, 44 const std::function< 45 ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& 46 innerExpr); 47 Tensor computeThreeOperand( 48 const std::string& name, 49 const std::vector<ArgValue>& inputValues, 50 const std::vector<ExprHandle>& outputShape, 51 const std::vector<ExprHandle>& outputStrides, 52 const std::optional<ScalarType>& outputType, 53 const std::function< 54 ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>& 55 innerExpr, 56 bool promote_inputs = true); 57 Tensor computeFourOperand( 58 const std::string& name, 59 const std::vector<ArgValue>& inputValues, 60 const std::vector<ExprHandle>& outputShape, 61 const std::vector<ExprHandle>& outputStrides, 62 const std::optional<ScalarType>& outputType, 63 const std::function<ExprHandle( 64 const ExprHandle&, 65 const ExprHandle&, 66 const ExprHandle&, 67 const ExprHandle&)>& innerExpr); 68 Tensor computeNoop( 69 const std::vector<ArgValue>& inputs, 70 const std::vector<ExprHandle>& outputShape, 71 const std::vector<ExprHandle>& outputStrides, 72 const std::optional<ScalarType>& outputType, 73 at::Device device); 74 75 Tensor computeScalar( 76 const std::string& name, 77 const std::vector<ArgValue>& inputValues, 78 const std::vector<ExprHandle>& outputShape, 79 const std::vector<ExprHandle>& outputStrides, 80 const std::optional<ScalarType>& outputType, 81 const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>& 82 innerExpr); 83 84 } // namespace tensorexpr 85 } // namespace jit 86 } // namespace torch 87