// Source: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/pointwise.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <functional>
#include <optional>
#include <string>
#include <vector>

#include <torch/csrc/jit/tensorexpr/kernel.h>

5 namespace torch {
6 namespace jit {
7 namespace tensorexpr {
8 
9 TORCH_API Tensor computeSign(
10     const std::vector<ArgValue>& inputs,
11     const std::vector<ExprHandle>& outputShape,
12     const std::optional<std::vector<ExprHandle>>& outputStrides = std::nullopt);
13 
14 Tensor computeOneOperand(
15     const std::string& name,
16     const std::vector<ArgValue>& inputValues,
17     const std::vector<ExprHandle>& outputShape,
18     const std::vector<ExprHandle>& outputStrides,
19     const std::optional<ScalarType>& outputType,
20     const std::function<ExprHandle(const ExprHandle&)>& innerExpr,
21     const int checkParamTypes = kAllTypes);
22 Tensor computeTwoOperand(
23     const std::string& name,
24     const std::vector<ArgValue>& inputValues,
25     const std::vector<ExprHandle>& outputShape,
26     const std::vector<ExprHandle>& outputStrides,
27     const std::optional<ScalarType>& outputType,
28     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
29         innerExpr);
30 Tensor computeTwoOperandWithAlpha(
31     const std::string& name,
32     const std::vector<ArgValue>& inputValues,
33     const std::vector<ExprHandle>& outputShape,
34     const std::vector<ExprHandle>& outputStrides,
35     const std::optional<ScalarType>& outputType,
36     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
37         innerExpr);
38 Tensor computeConditionWithTwoOperand(
39     const std::string& name,
40     const std::vector<ArgValue>& inputValues,
41     const std::vector<ExprHandle>& outputShape,
42     const std::vector<ExprHandle>& outputStrides,
43     const std::optional<ScalarType>& outputType,
44     const std::function<
45         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
46         innerExpr);
47 Tensor computeThreeOperand(
48     const std::string& name,
49     const std::vector<ArgValue>& inputValues,
50     const std::vector<ExprHandle>& outputShape,
51     const std::vector<ExprHandle>& outputStrides,
52     const std::optional<ScalarType>& outputType,
53     const std::function<
54         ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)>&
55         innerExpr,
56     bool promote_inputs = true);
57 Tensor computeFourOperand(
58     const std::string& name,
59     const std::vector<ArgValue>& inputValues,
60     const std::vector<ExprHandle>& outputShape,
61     const std::vector<ExprHandle>& outputStrides,
62     const std::optional<ScalarType>& outputType,
63     const std::function<ExprHandle(
64         const ExprHandle&,
65         const ExprHandle&,
66         const ExprHandle&,
67         const ExprHandle&)>& innerExpr);
68 Tensor computeNoop(
69     const std::vector<ArgValue>& inputs,
70     const std::vector<ExprHandle>& outputShape,
71     const std::vector<ExprHandle>& outputStrides,
72     const std::optional<ScalarType>& outputType,
73     at::Device device);
74 
75 Tensor computeScalar(
76     const std::string& name,
77     const std::vector<ArgValue>& inputValues,
78     const std::vector<ExprHandle>& outputShape,
79     const std::vector<ExprHandle>& outputStrides,
80     const std::optional<ScalarType>& outputType,
81     const std::function<ExprHandle(const ExprHandle&, const ExprHandle&)>&
82         innerExpr);
83 
84 } // namespace tensorexpr
85 } // namespace jit
86 } // namespace torch
87