1 #include <torch/csrc/jit/frontend/builtin_functions.h>
2
3 #include <ATen/code_template.h>
4 #include <caffe2/serialize/versions.h>
5 #include <torch/csrc/api/include/torch/jit.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7
8 namespace torch::jit {
9
10 auto scalar_operators_source = at::jit::CodeTemplate(
11 R"SCRIPT(
12 def mul(a : ${Scalar}, b : Tensor) -> Tensor:
13 return b * a
14 def add(a : ${Scalar}, b : Tensor) -> Tensor:
15 return b + a
16 def ne(a : ${Scalar}, b : Tensor) -> Tensor:
17 return b != a
18 def eq(a : ${Scalar}, b : Tensor) -> Tensor:
19 return b == a
20 def sub(a : ${Scalar}, b : Tensor) -> Tensor:
21 return torch.neg(b) + a
22 def div(a : ${Scalar}, b : Tensor) -> Tensor:
23 return torch.reciprocal(b) * a
24 )SCRIPT");
25
26 auto scalar_operators_no_complex_source = at::jit::CodeTemplate(
27 R"SCRIPT(
28 def lt(a : ${Scalar}, b : Tensor) -> Tensor:
29 return b > a
30 def le(a : ${Scalar}, b : Tensor) -> Tensor:
31 return b >= a
32 def gt(a : ${Scalar}, b : Tensor) -> Tensor:
33 return b < a
34 def ge(a : ${Scalar}, b : Tensor) -> Tensor:
35 return b <= a
36 )SCRIPT");
37
38 auto _ntuple_ops = at::jit::CodeTemplate(
39 R"SCRIPT(
40 def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
41 return x
42 )SCRIPT");
43
44 auto floordiv = at::jit::CodeTemplate(
45 R"SCRIPT(
46 def floordiv(self : Tensor, other : ${Rhs_Type}) -> Tensor:
47 return torch.floor_divide(self, other)
48 )SCRIPT");
49
// TorchScript source that implements Tensor property getters (ndim, T, H, mT,
// mH, shape) as plain functions taking the tensor as their only argument.
// This source is loaded under the `prim` namespace by
// BuiltinFunctionRegistry::loadBuiltinFunctions.
//
// Declared const so the pointer cannot be reseated at runtime; this also gives
// the variable internal linkage.
const auto tensor_properties =
    R"SCRIPT(
def ndim(a : Tensor) -> int:
  return a.dim()
def T(a : Tensor) -> Tensor:
  return a.numpy_T()
def H(a : Tensor) -> Tensor:
  return a.matrix_H()
def mT(a : Tensor) -> Tensor:
  return a.mT
def mH(a : Tensor) -> Tensor:
  return a.mH
def shape(a : Tensor) -> List[int]:
  return a.size()
)SCRIPT";
65
// TorchScript source for miscellaneous `aten` builtins.
//
// _assert_int_or_pair is only here for backwards-compatibility with the
// aten::_assert_int_or_pair op, which was removed once we were able to compile
// torch.nn.functional.assert_int_or_pair. Its body is a no-op.
// list_with_default also needs to be here for BC.
//
// Declared const so the pointer cannot be reseated at runtime; this also gives
// the variable internal linkage.
const auto aten_ops =
    R"SCRIPT(
def _assert_int_or_pair(vals: List[int], name: str, message: str):
  pass
def list_with_default(out_size: List[int], defaults: List[int]):
  assert len(defaults) > len(out_size)
  return out_size
def _assert(condition : bool, message : str):
  assert condition, message
# existing device operator is registered with input name `a`, which prevents
# torch.device(type="cuda") from working. add shim-layer here
def device(type: str):
  return torch.device(type)
def type(self: Tensor, dtype: int, non_blocking: bool=False, copy: bool=False) -> Tensor:
  return self.to(dtype, non_blocking, copy)
)SCRIPT";
86
// Further `aten` builtins kept in a separate source string: an overload of
// _assert whose condition is a Tensor (converted with bool()), and a
// __contains__ for strings implemented on top of str.find.
const char* const aten_ops_additional =
    R"SCRIPT(
def _assert(condition : Tensor, message : str):
  assert bool(condition), message
def __contains__(self: str, key: str):
  return self.find(key, 0, len(self)) != -1
)SCRIPT";
95
96 struct BuiltinFunctionRegistry {
getAllBuiltinFunctionsFortorch::jit::BuiltinFunctionRegistry97 const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
98 const static std::vector<Function*> empty;
99 // when initializing the builtin function library, we will re-enter
100 // getAllBuiltinFunctionsFor since it is called in the compiler to
101 // lookup builtins and initializing the builtin functions calls the
102 // compiler. To avoid deadlocking, we use a recursive mutex (same thread can
103 // re-lock, the mutex without waiting), and report no loaded builtins during
104 // init.
105 std::lock_guard<std::recursive_mutex> guard(mutex);
106 if (state == INTIIALIZING) {
107 return empty;
108 } else if (state == UNINITIALIZED) {
109 state = INTIIALIZING;
110 loadBuiltinFunctions();
111 state = INITIALIZED;
112 }
113 AT_ASSERT(state == INITIALIZED);
114 auto it = builtins_by_name_.find(name);
115 if (it == builtins_by_name_.end())
116 return empty;
117 return it->second;
118 }
119
120 private:
loadSourcetorch::jit::BuiltinFunctionRegistry121 void loadSource(const std::string& source, const std::string& the_namespace) {
122 std::shared_ptr<CompilationUnit> cu = std::make_shared<CompilationUnit>();
123 modules.emplace_back(cu);
124 cu->define(std::nullopt, source, nativeResolver(), /*self=*/nullptr);
125 for (auto& method : cu->get_functions()) {
126 builtins_by_name_[Symbol::fromQualString(
127 the_namespace + "::" + method->name())]
128 .push_back(method);
129 }
130 }
131
loadBuiltinFunctionstorch::jit::BuiltinFunctionRegistry132 void loadBuiltinFunctions() {
133 for (auto scalar : {"float", "int", "complex"}) {
134 at::jit::TemplateEnv env;
135 env.s("Scalar", scalar);
136 loadSource(scalar_operators_source.format(env), "aten");
137 }
138
139 for (auto scalar : {"float", "int"}) {
140 at::jit::TemplateEnv env;
141 env.s("Scalar", scalar);
142 loadSource(scalar_operators_no_complex_source.format(env), "aten");
143 }
144
145 using str_pair = std::pair<std::string, std::string>;
146 const std::vector<str_pair> name_len = {
147 str_pair("single", "1"),
148 str_pair("pair", "2"),
149 str_pair("triple", "3"),
150 str_pair("quadruple", "4"),
151 };
152 for (const auto scalar : {"float", "int"}) {
153 for (const auto& pair : name_len) {
154 at::jit::TemplateEnv env;
155 env.s("Scalar", scalar);
156 env.s("name", pair.first);
157 env.s("Length", pair.second);
158 loadSource(_ntuple_ops.format(env), "aten");
159 }
160 }
161 for (auto rhs : {"number", "Tensor"}) {
162 at::jit::TemplateEnv env;
163 env.s("Rhs_Type", rhs);
164 loadSource(floordiv.format(env), "aten");
165 }
166
167 loadSource(aten_ops, "aten");
168 loadSource(aten_ops_additional, "aten");
169
170 // These are under `prim` instead of `aten` since they exist to bind certain
171 // tensor property getters to correpsonding methods
172 loadSource(tensor_properties, "prim");
173 }
174 enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
175 std::recursive_mutex mutex;
176 std::vector<std::shared_ptr<CompilationUnit>> modules;
177 std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name_;
178 };
179
getAllBuiltinFunctionsFor(Symbol name)180 const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
181 static BuiltinFunctionRegistry registry;
182 return registry.getAllBuiltinFunctionsFor(name);
183 }
184
185 } // namespace torch::jit
186