#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>

#include <ATen/core/stack.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <string>
#include <unordered_map>

namespace torch::jit {
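// Maps each upgrader name to the TorchScript source that reproduces the old
// operator's behavior. The numeric suffix is the range of operator versions
// the upgrader applies to (e.g. div_Tensor_0_3 upgrades calls serialized at
// operator versions 0 through 3).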
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
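    // linspace/logspace: `steps` once defaulted to 100 and later became a
    // required argument, so substitute 100 when an old model leaves it None.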
17 {"logspace_0_8", R"SCRIPT(
18 def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
19 device: Optional[Device], pin_memory: Optional[bool]):
20 if (steps is None):
21 return torch.logspace(start=start, end=end, steps=100, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
22 return torch.logspace(start=start, end=end, steps=steps, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
23 )SCRIPT"},
24 {"logspace_out_0_8", R"SCRIPT(
25 def logspace_out_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, out: Tensor):
26 if (steps is None):
27 return torch.logspace(start=start, end=end, steps=100, base=base, out=out)
28 return torch.logspace(start=start, end=end, steps=steps, base=base, out=out)
29 )SCRIPT"},
30 {"linspace_0_7", R"SCRIPT(
31 def linspace_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, dtype: Optional[int], layout: Optional[int],
32 device: Optional[Device], pin_memory: Optional[bool]):
33 if (steps is None):
34 return torch.linspace(start=start, end=end, steps=100, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
35 return torch.linspace(start=start, end=end, steps=steps, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
36 )SCRIPT"},
37 {"linspace_out_0_7", R"SCRIPT(
38 def linspace_out_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, out: Tensor):
39 if (steps is None):
40 return torch.linspace(start=start, end=end, steps=100, out=out)
41 return torch.linspace(start=start, end=end, steps=steps, out=out)
42 )SCRIPT"},
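    // div and variants: before the version bump, dividing integral tensors
    // truncated toward zero while floating-point inputs used true division;
    // the `_mode` overloads already took an explicit rounding_mode and just
    // forward it.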
43 {"div_Tensor_0_3", R"SCRIPT(
44 def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
45 if (self.is_floating_point() or other.is_floating_point()):
46 return self.true_divide(other)
47 return self.divide(other, rounding_mode='trunc')
48 )SCRIPT"},
49 {"div_Tensor_mode_0_3", R"SCRIPT(
50 def div_Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
51 return self.divide(other, rounding_mode=rounding_mode)
52 )SCRIPT"},
53 {"div_Scalar_0_3", R"SCRIPT(
54 def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
55 if (self.is_floating_point() or isinstance(other, float)):
56 return self.true_divide(other)
57 return self.divide(other, rounding_mode='trunc')
58 )SCRIPT"},
59 {"div_Scalar_mode_0_3", R"SCRIPT(
60 def div_Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
61 return self.divide(other, rounding_mode=rounding_mode)
62 )SCRIPT"},
63 {"div_out_0_3", R"SCRIPT(
64 def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
65 if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
66 return self.true_divide(other, out=out)
67 return self.divide(other, rounding_mode='trunc', out=out)
68 )SCRIPT"},
69 {"div_out_mode_0_3", R"SCRIPT(
70 def div_out_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None, out: Tensor) -> Tensor:
71 return self.divide(other, rounding_mode=rounding_mode, out=out)
72 )SCRIPT"},
73 {"div__Tensor_0_3", R"SCRIPT(
74 def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
75 if (self.is_floating_point() or other.is_floating_point()):
76 return self.true_divide_(other)
77 return self.divide_(other, rounding_mode='trunc')
78 )SCRIPT"},
79 {"div__Tensor_mode_0_3", R"SCRIPT(
80 def div__Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
81 return self.divide_(other, rounding_mode=rounding_mode)
82 )SCRIPT"},
83 {"div__Scalar_0_3", R"SCRIPT(
84 def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
85 if (self.is_floating_point() or isinstance(other, float)):
86 return self.true_divide_(other)
87 return self.divide_(other, rounding_mode='trunc')
88 )SCRIPT"},
89 {"div__Scalar_mode_0_3", R"SCRIPT(
90 def div__Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
91 return self.divide_(other, rounding_mode=rounding_mode)
92 )SCRIPT"},
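    // full: without an explicit dtype, old versions always produced a
    // floating-point tensor, so cast fill_value to float to keep that
    // behavior.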
93 {"full_names_0_4", R"SCRIPT(
94 def full_names_0_4(size:List[int], fill_value:number, *, names:Optional[List[str]]=None,
95 dtype:Optional[int]=None, layout:Optional[int]=None, device:Optional[Device]=None,
96 pin_memory:Optional[bool]=None) -> Tensor:
97 return torch.full(size, fill_value, names=names, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
98 )SCRIPT"},
99 {"full_0_4", R"SCRIPT(
100 def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
101 layout:Optional[int]=None, device:Optional[Device]=None,
102 pin_memory:Optional[bool]=None) -> Tensor:
103 if dtype is None:
104 fill_value = float(fill_value)
105 return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
106 )SCRIPT"},
107 {"full_out_0_4", R"SCRIPT(
108 def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
109 return torch.full(size, fill_value, out=out)
110 )SCRIPT"},
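    // gelu: the `approximate` keyword was added later; calls from old models
    // correspond to approximate='none'.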
111 {"gelu_0_9", R"SCRIPT(
112 def gelu_0_9(self: Tensor) -> Tensor:
113 return torch.gelu(self, approximate='none')
114 )SCRIPT"},
115 {"gelu_out_0_9", R"SCRIPT(
116 def gelu_out_0_9(self: Tensor, *, out: Tensor) -> Tensor:
117 return torch.gelu(self, approximate='none', out=out)
118 )SCRIPT"},
});

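// Compiles a single upgrader's TorchScript body in a fresh CompilationUnit
// and returns the graph of the resulting function.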
std::shared_ptr<Graph> create_upgrader_graph(
    const std::string& upgrader_name,
    const std::string& upgrader_body) {
  auto cu = std::make_shared<CompilationUnit>();
  cu->define(std::nullopt, upgrader_body, nativeResolver(), nullptr);
  Function& jitFunc = cu->get_function(upgrader_name);
  GraphFunction& graphFunction = toGraphFunction(jitFunc);
  return graphFunction.graph();
}

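// Compiles every entry in kUpgradersEntryMap into a name -> Graph map.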
std::unordered_map<std::string, std::shared_ptr<Graph>>
generate_upgraders_graph() {
  std::unordered_map<std::string, std::shared_ptr<Graph>> populate_content;
  for (const auto& entry : kUpgradersEntryMap) {
    auto upgrader_graph = create_upgrader_graph(entry.first, entry.second);
    populate_content.insert(std::make_pair(entry.first, upgrader_graph));
  }
  return populate_content;
}

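// Populates the global upgraders map if it has not been populated yet;
// subsequent calls are no-ops.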
void populate_upgraders_graph_map() {
  if (!is_upgraders_map_populated()) {
    auto graphs = generate_upgraders_graph();
    populate_upgraders_map(std::move(graphs));
  }
}

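// Returns a copy of the raw name -> TorchScript source map, mainly useful
// for tests and inspection.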
std::unordered_map<std::string, std::string> get_upgraders_entry_map() {
  return kUpgradersEntryMap;
}

} // namespace torch::jit