#pragma once

#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

// Static shape and dtype of a buffer, when they can be determined.
struct TensorInfo {
  std::vector<int64_t> dims;
  c10::ScalarType dtype;
};
std::optional<TensorInfo> getTensorInfo(const BufHandle& b);

// Normalize a (possibly negative) index into [0, list_size) and check that it
// is in range.
int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);

// Convert boolean to integer, if needed.
ExprHandle boolToInteger(const ExprHandle& x);

// Dtype-promotion helpers for expressions.
ExprHandle promoteToDtype(ExprHandle e, ScalarType dt);
void promoteInputs(
    std::vector<ExprHandle>& inputs,
    const int typeConstraints = kAllTypes);
ExprHandle promoteIntegerToDefaultType(const ExprHandle& e);
ExprHandle promoteHalfToFloat(const ExprHandle& e);
ExprHandle demoteOutput(
    const ExprHandle& e,
    const std::optional<ScalarType> type);

// Compute the broadcasted output shape of the given input shapes.
std::vector<ExprHandle> broadcastShapes(
    std::vector<std::vector<ExprHandle>> shapes);
std::vector<ExprHandle> broadcastShapes(
    const std::vector<ExprHandle>& a,
    const std::vector<ExprHandle>& b);

// Helpers for turning ArgValues into expressions.
std::vector<ExprHandle> valueShape(const ArgValue& v);
ExprHandle tensorOrConstant(
    const ArgValue& v,
    const std::vector<ExprHandle>& axes);
ExprHandle scalarOrConstant(const ArgValue& v);
ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);
ExprHandle constant(const ArgValue& v);

// Clamp `input` to the range [cmin, cmax].
ExprHandle clamp(
    const ExprHandle& cmin,
    const ExprHandle& cmax,
    const ExprHandle& input);

// Lowering functions: each builds a Tensor that computes the corresponding op
// for the given output shape, strides, dtype, and device.
Tensor computeChunk(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeTranspose(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeExpand(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeReshape(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeFlatten(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeCatWoConditionals(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape);
Tensor computeCat(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);
Tensor computeEmbedding(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device);

} // namespace torch::jit::tensorexpr
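
// Example: a minimal sketch of how these helpers are typically combined when
// lowering a pointwise binary op. Here `lhs` and `rhs` are assumed to be
// ArgValues supplied by the caller, `outputType` an optional ScalarType, and
// the buffer name is hypothetical; Compute is declared in tensor.h.
//
//   std::vector<ExprHandle> outShape =
//       broadcastShapes(valueShape(lhs), valueShape(rhs));
//   Tensor result = Compute(
//       "hypothetical_add", outShape, [&](const std::vector<VarHandle>& axes) {
//         // Index both operands at the output axes (broadcasting as needed),
//         // promote them to a common dtype, then cast to the requested output.
//         std::vector<ExprHandle> idx(axes.begin(), axes.end());
//         std::vector<ExprHandle> operands = {
//             tensorOrConstant(lhs, idx), tensorOrConstant(rhs, idx)};
//         promoteInputs(operands);
//         return demoteOutput(operands[0] + operands[1], outputType);
//       });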