xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/builtin_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/builtin_functions.h>
2 
3 #include <ATen/code_template.h>
4 #include <caffe2/serialize/versions.h>
5 #include <torch/csrc/api/include/torch/jit.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7 
8 namespace torch::jit {
9 
10 auto scalar_operators_source = at::jit::CodeTemplate(
11     R"SCRIPT(
12 def mul(a : ${Scalar}, b : Tensor) -> Tensor:
13   return b * a
14 def add(a : ${Scalar}, b : Tensor) -> Tensor:
15   return b + a
16 def ne(a : ${Scalar}, b : Tensor) -> Tensor:
17   return b != a
18 def eq(a : ${Scalar}, b : Tensor) -> Tensor:
19   return b == a
20 def sub(a : ${Scalar}, b : Tensor) -> Tensor:
21   return torch.neg(b) + a
22 def div(a : ${Scalar}, b : Tensor) -> Tensor:
23   return torch.reciprocal(b) * a
24 )SCRIPT");
25 
26 auto scalar_operators_no_complex_source = at::jit::CodeTemplate(
27     R"SCRIPT(
28 def lt(a : ${Scalar}, b : Tensor) -> Tensor:
29   return b > a
30 def le(a : ${Scalar}, b : Tensor) -> Tensor:
31   return b >= a
32 def gt(a : ${Scalar}, b : Tensor) -> Tensor:
33   return b < a
34 def ge(a : ${Scalar}, b : Tensor) -> Tensor:
35   return b <= a
36 )SCRIPT");
37 
38 auto _ntuple_ops = at::jit::CodeTemplate(
39     R"SCRIPT(
40 def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
41   return x
42 )SCRIPT");
43 
44 auto floordiv = at::jit::CodeTemplate(
45     R"SCRIPT(
46 def floordiv(self : Tensor, other : ${Rhs_Type}) -> Tensor:
47   return torch.floor_divide(self, other)
48 )SCRIPT");
49 
// TorchScript shims that bind Tensor property getters (`t.ndim`, `t.T`,
// `t.shape`, ...) to method calls. They are registered under `prim` rather
// than `aten`.
// `const auto` deduces `const char* const`, matching aten_ops_additional. The
// previous plain `auto` produced a reassignable, externally visible pointer.
const auto tensor_properties =
    R"SCRIPT(
def ndim(a : Tensor) -> int:
  return a.dim()
def T(a : Tensor) -> Tensor:
  return a.numpy_T()
def H(a : Tensor) -> Tensor:
  return a.matrix_H()
def mT(a : Tensor) -> Tensor:
  return a.mT
def mH(a : Tensor) -> Tensor:
  return a.mH
def shape(a : Tensor) -> List[int]:
  return a.size()
)SCRIPT";
65 
// Miscellaneous `aten` builtins kept mostly for backward compatibility:
//  - _assert_int_or_pair is only here for backwards-compatibility with the
//    aten::_assert_int_or_pair op, which was removed once we were able to
//    compile torch.nn.functional.assert_int_or_pair. It is a no-op.
//  - list_with_default also needs to be here for BC.
//  - _assert(bool, str), a `device(type=...)` shim and a Tensor.type(dtype)
//    overload that forwards to Tensor.to.
// `const auto` deduces `const char* const`, matching aten_ops_additional. The
// previous plain `auto` produced a reassignable, externally visible pointer.
const auto aten_ops =
    R"SCRIPT(
def _assert_int_or_pair(vals: List[int], name: str, message: str):
  pass
def list_with_default(out_size: List[int], defaults: List[int]):
  assert len(defaults) > len(out_size)
  return out_size
def _assert(condition : bool, message : str):
  assert condition, message
# existing device operator is registered with input name `a`, which prevents
# torch.device(type="cuda") from working. add shim-layer here
def device(type: str):
  return torch.device(type)
def type(self: Tensor, dtype: int, non_blocking: bool=False, copy: bool=False) -> Tensor:
  return self.to(dtype, non_blocking, copy)
)SCRIPT";
86 
// Extra `aten` builtins compiled as their own source:
//  - an `_assert` overload taking a Tensor condition, which is converted via
//    bool() before asserting;
//  - `str.__contains__`, implemented with str.find over the whole string.
// `constexpr const char*` has the same type (`const char* const`) and linkage
// as a `const auto` initialised from a string literal.
constexpr const char* aten_ops_additional = R"SCRIPT(
def _assert(condition : Tensor, message : str):
  assert bool(condition), message
def __contains__(self: str, key: str):
    return self.find(key, 0, len(self)) != -1
)SCRIPT";
95 
96 struct BuiltinFunctionRegistry {
getAllBuiltinFunctionsFortorch::jit::BuiltinFunctionRegistry97   const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
98     const static std::vector<Function*> empty;
99     // when initializing the builtin function library, we will re-enter
100     // getAllBuiltinFunctionsFor since it is called in the compiler to
101     // lookup builtins and initializing the builtin functions calls the
102     // compiler. To avoid deadlocking, we use a recursive mutex (same thread can
103     // re-lock, the mutex without waiting), and report no loaded builtins during
104     // init.
105     std::lock_guard<std::recursive_mutex> guard(mutex);
106     if (state == INTIIALIZING) {
107       return empty;
108     } else if (state == UNINITIALIZED) {
109       state = INTIIALIZING;
110       loadBuiltinFunctions();
111       state = INITIALIZED;
112     }
113     AT_ASSERT(state == INITIALIZED);
114     auto it = builtins_by_name_.find(name);
115     if (it == builtins_by_name_.end())
116       return empty;
117     return it->second;
118   }
119 
120  private:
loadSourcetorch::jit::BuiltinFunctionRegistry121   void loadSource(const std::string& source, const std::string& the_namespace) {
122     std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
123     modules.emplace_back(cu);
124     cu->define(std::nullopt, source, nativeResolver(), /*self=*/nullptr);
125     for (auto& method : cu->get_functions()) {
126       builtins_by_name_[Symbol::fromQualString(
127                             the_namespace + "::" + method->name())]
128           .push_back(method);
129     }
130   }
131 
loadBuiltinFunctionstorch::jit::BuiltinFunctionRegistry132   void loadBuiltinFunctions() {
133     for (auto scalar : {"float", "int", "complex"}) {
134       at::jit::TemplateEnv env;
135       env.s("Scalar", scalar);
136       loadSource(scalar_operators_source.format(env), "aten");
137     }
138 
139     for (auto scalar : {"float", "int"}) {
140       at::jit::TemplateEnv env;
141       env.s("Scalar", scalar);
142       loadSource(scalar_operators_no_complex_source.format(env), "aten");
143     }
144 
145     using str_pair = std::pair<std::string, std::string>;
146     const std::vector<str_pair> name_len = {
147         str_pair("single", "1"),
148         str_pair("pair", "2"),
149         str_pair("triple", "3"),
150         str_pair("quadruple", "4"),
151     };
152     for (const auto scalar : {"float", "int"}) {
153       for (const auto& pair : name_len) {
154         at::jit::TemplateEnv env;
155         env.s("Scalar", scalar);
156         env.s("name", pair.first);
157         env.s("Length", pair.second);
158         loadSource(_ntuple_ops.format(env), "aten");
159       }
160     }
161     for (auto rhs : {"number", "Tensor"}) {
162       at::jit::TemplateEnv env;
163       env.s("Rhs_Type", rhs);
164       loadSource(floordiv.format(env), "aten");
165     }
166 
167     loadSource(aten_ops, "aten");
168     loadSource(aten_ops_additional, "aten");
169 
170     // These are under `prim` instead of `aten` since they exist to bind certain
171     // tensor property getters to correpsonding methods
172     loadSource(tensor_properties, "prim");
173   }
174   enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
175   std::recursive_mutex mutex;
176   std::vector<std::shared_ptr<CompilationUnit>> modules;
177   std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name_;
178 };
179 
getAllBuiltinFunctionsFor(Symbol name)180 const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
181   static BuiltinFunctionRegistry registry;
182   return registry.getAllBuiltinFunctionsFor(name);
183 }
184 
185 } // namespace torch::jit
186