xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/misc.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
4 #include <torch/csrc/jit/tensorexpr/lowerings.h>
5 #include <torch/csrc/jit/tensorexpr/tensor.h>
6 
7 namespace torch::jit::tensorexpr {
8 
9 struct TensorInfo {
10   std::vector<int64_t> dims;
11   c10::ScalarType dtype;
12 };
13 std::optional<TensorInfo> getTensorInfo(const BufHandle& b);
14 
15 int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size);
16 
17 // Convert boolean to integer, if needed.
18 ExprHandle boolToInteger(const ExprHandle& x);
19 ExprHandle promoteToDtype(ExprHandle e, ScalarType dt);
20 void promoteInputs(
21     std::vector<ExprHandle>& inputs,
22     const int typeConstraints = kAllTypes);
23 ExprHandle promoteIntegerToDefaultType(const ExprHandle& e);
24 ExprHandle promoteHalfToFloat(const ExprHandle& e);
25 ExprHandle demoteOutput(
26     const ExprHandle& e,
27     const std::optional<ScalarType> type);
28 
29 std::vector<ExprHandle> broadcastShapes(
30     std::vector<std::vector<ExprHandle>> shapes);
31 std::vector<ExprHandle> broadcastShapes(
32     const std::vector<ExprHandle>& a,
33     const std::vector<ExprHandle>& b);
34 
35 std::vector<ExprHandle> valueShape(const ArgValue& v);
36 ExprHandle tensorOrConstant(
37     const ArgValue& v,
38     const std::vector<ExprHandle>& axes);
39 ExprHandle scalarOrConstant(const ArgValue& v);
40 ExprHandle broadcast(const BufHandle& b, const std::vector<ExprHandle>& axes);
41 ExprHandle constant(const ArgValue& v);
42 
43 ExprHandle clamp(
44     const ExprHandle& cmin,
45     const ExprHandle& cmax,
46     const ExprHandle& input);
47 
48 Tensor computeChunk(
49     const std::vector<ArgValue>& inputs,
50     const std::vector<ExprHandle>& outputShape,
51     const std::vector<ExprHandle>& outputStrides,
52     const std::optional<ScalarType>& outputType,
53     at::Device device);
54 Tensor computeTranspose(
55     const std::vector<ArgValue>& inputs,
56     const std::vector<ExprHandle>& outputShape,
57     const std::vector<ExprHandle>& outputStrides,
58     const std::optional<ScalarType>& outputType,
59     at::Device device);
60 Tensor computeExpand(
61     const std::vector<ArgValue>& inputs,
62     const std::vector<ExprHandle>& outputShape,
63     const std::vector<ExprHandle>& outputStrides,
64     const std::optional<ScalarType>& outputType,
65     at::Device device);
66 Tensor computeReshape(
67     const std::vector<ArgValue>& inputs,
68     const std::vector<ExprHandle>& outputShape,
69     const std::vector<ExprHandle>& outputStrides,
70     const std::optional<ScalarType>& outputType,
71     at::Device device);
72 Tensor computeFlatten(
73     const std::vector<ArgValue>& inputs,
74     const std::vector<ExprHandle>& outputShape,
75     const std::vector<ExprHandle>& outputStrides,
76     const std::optional<ScalarType>& outputType,
77     at::Device device);
78 Tensor computeCatWoConditionals(
79     const std::vector<ArgValue>& inputs,
80     const std::vector<ExprHandle>& outputShape);
81 Tensor computeCat(
82     const std::vector<ArgValue>& inputs,
83     const std::vector<ExprHandle>& outputShape,
84     const std::vector<ExprHandle>& outputStrides,
85     const std::optional<ScalarType>& outputType,
86     at::Device device);
87 Tensor computeEmbedding(
88     const std::vector<ArgValue>& inputs,
89     const std::vector<ExprHandle>& outputShape,
90     const std::vector<ExprHandle>& outputStrides,
91     const std::optional<ScalarType>& outputType,
92     at::Device device);
93 
94 } // namespace torch::jit::tensorexpr
95