xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/decomposition_registry_util.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 /**
3  * @generated
4  * This is an auto-generated file. Please do not modify it by hand.
5  * To re-generate, please run:
6  * cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
7  */
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/inliner.h>
10 #include <torch/csrc/jit/runtime/decomposition_registry_util.h>
11 #include <torch/csrc/jit/runtime/operator.h>
12 
13 namespace torch::jit {
14 
// TorchScript source for the JIT decompositions registered in this file.
//
// This text is program input, not documentation. GetSerializedDecompositions()
// returns it verbatim. The function names it defines (`var_decomposition` and
// `var`) are the values that GetDecompositionMapping() associates with ATen
// operator schemas. Do not edit it by hand; regenerate it with
// torchgen/decompositions/gen_jit_decompositions.py.
//
// Both functions compute
//   sum((input - mean(input, dims, keepdim=True))^2, dims) / max(0, n - correction)
// where n is the number of reduced elements. n is numel(input) when the dim
// list is None or empty, and otherwise the product of the sizes of the listed
// dims.
//   * var_decomposition: correction defaults to 1 when `correction` is None.
//     The final sum honours `keepdim`. A non-None correction that is neither
//     int nor float (i.e. the bool arm of the Union) raises
//     "correction must be int or float".
//   * var: uses correction = 1 if `unbiased` else 0. It passes an empty dim
//     list with n = numel(input), i.e. a full reduction.
//
// NOTE(review): uninitialized(), unchecked_cast() and ops.prim.RaiseException
// appear to come from TorchScript's own code printer (this looks like printed,
// already-compiled code rather than hand-written script). Confirm this before
// adding functions written in the same style.
const std::string decomp_funcs =
    R"(def var_decomposition(input: Tensor,
    dim: Optional[List[int]]=None,
    correction: Union[float, int, NoneType, bool]=None,
    keepdim: bool=False) -> Tensor:
  _0 = uninitialized(float)
  if torch.__is__(dim, None):
    dim0 = annotate(List[int], [])
  else:
    dim0 = unchecked_cast(List[int], dim)
  if torch.eq(torch.len(dim0), 0):
    n = torch.numel(input)
  else:
    n0 = 1
    for _1 in range(torch.len(dim0)):
      dim_i = dim0[_1]
      n1 = torch.mul(n0, (torch.size(input))[dim_i])
      n0 = n1
    n = n0
  mean = torch.mean(input, dim0, True)
  sub = torch.sub(input, mean)
  sq = torch.mul(sub, sub)
  sum = torch.sum(sq, dim0, keepdim)
  if torch.__is__(correction, None):
    denom = float(torch.sub(n, 1))
  else:
    correction0 = unchecked_cast(Union[float, int, bool], correction)
    _2 = isinstance(correction0, int)
    if _2:
      correction1 = unchecked_cast(int, correction0)
      denom0 = float(torch.sub(n, correction1))
    else:
      correction2 = unchecked_cast(Union[float, bool], correction0)
      _3 = isinstance(correction2, float)
      if _3:
        correction3 = unchecked_cast(float, correction2)
        denom2 = torch.sub(float(n), correction3)
        denom1 = denom2
      else:
        ops.prim.RaiseException("correction must be int or float", "builtins.RuntimeError")
        denom1 = _0
      denom0 = denom1
    denom = denom0
  _4 = torch.div(sum, ops.prim.max(0, denom))
  return _4

def var(input: Tensor,
    unbiased: bool=True) -> Tensor:
  if unbiased:
    _0 = 1
  else:
    _0 = 0
  _1 = uninitialized(float)
  n = torch.numel(input)
  mean = torch.mean(input, annotate(List[int], []), True)
  sub = torch.sub(input, mean)
  sq = torch.mul(sub, sub)
  sum = torch.sum(sq, annotate(List[int], []))
  _2 = isinstance(_0, int)
  if _2:
    denom = float(torch.sub(n, _0))
  else:
    correction = unchecked_cast(Union[float, bool], _0)
    _3 = isinstance(correction, float)
    if _3:
      correction0 = unchecked_cast(float, correction)
      denom0 = torch.sub(float(n), correction0)
    else:
      ops.prim.RaiseException("correction must be int or float", "builtins.RuntimeError")
      denom0 = _1
    denom = denom0
  _4 = torch.div(sum, ops.prim.max(0, denom))
  return _4

)";
90 
GetSerializedDecompositions()91 const std::string& GetSerializedDecompositions() {
92   return decomp_funcs;
93 }
94 
GetDecompositionMapping()95 const OperatorMap<std::string>& GetDecompositionMapping() {
96   // clang-format off
97  static const OperatorMap<std::string> decomposition_mapping {
98     {"aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "var_decomposition"},
99     {"aten::var(Tensor self, bool unbiased=True) -> Tensor", "var"},
100   };
101   // clang-format on
102 
103   return decomposition_mapping;
104 }
105 
106 } // namespace torch::jit
107