xref: /aosp_15_r20/external/pytorch/benchmarks/static_runtime/deep_wide_pt.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include "deep_wide_pt.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/script.h>
5 
6 namespace {
7 // No ReplaceNaN (this removes the constant in the model)
8 const std::string deep_wide_pt = R"JIT(
9 class DeepAndWide(Module):
10   __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
11   __buffers__ = []
12   _mu : Tensor
13   _sigma : Tensor
14   _fc_w : Tensor
15   _fc_b : Tensor
16   training : bool
17   def forward(self: __torch__.DeepAndWide,
18     ad_emb_packed: Tensor,
19     user_emb: Tensor,
20     wide: Tensor) -> Tuple[Tensor]:
21     _0 = self._fc_b
22     _1 = self._fc_w
23     _2 = self._sigma
24     wide_offset = torch.add(wide, self._mu, alpha=1)
25     wide_normalized = torch.mul(wide_offset, _2)
26     wide_preproc = torch.clamp(wide_normalized, 0., 10.)
27     user_emb_t = torch.transpose(user_emb, 1, 2)
28     dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
29     dp = torch.flatten(dp_unflatten, 1, -1)
30     input = torch.cat([dp, wide_preproc], 1)
31     fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
32     return (torch.sigmoid(fc1),)
33 )JIT";
34 
35 const std::string trivial_model_1 = R"JIT(
36   def forward(self, a, b, c):
37       s = torch.tensor([[3, 3], [3, 3]])
38       return a + b * c + s
39 )JIT";
40 
41 const std::string leaky_relu_model_const = R"JIT(
42   def forward(self, input):
43       x = torch.leaky_relu(input, 0.1)
44       x = torch.leaky_relu(x, 0.1)
45       x = torch.leaky_relu(x, 0.1)
46       x = torch.leaky_relu(x, 0.1)
47       return torch.leaky_relu(x, 0.1)
48 )JIT";
49 
50 const std::string leaky_relu_model = R"JIT(
51   def forward(self, input, neg_slope):
52       x = torch.leaky_relu(input, neg_slope)
53       x = torch.leaky_relu(x, neg_slope)
54       x = torch.leaky_relu(x, neg_slope)
55       x = torch.leaky_relu(x, neg_slope)
56       return torch.leaky_relu(x, neg_slope)
57 )JIT";
58 
import_libs(std::shared_ptr<at::CompilationUnit> cu,const std::string & class_name,const std::shared_ptr<torch::jit::Source> & src,const std::vector<at::IValue> & tensor_table)59 void import_libs(
60     std::shared_ptr<at::CompilationUnit> cu,
61     const std::string& class_name,
62     const std::shared_ptr<torch::jit::Source>& src,
63     const std::vector<at::IValue>& tensor_table) {
64   torch::jit::SourceImporter si(
65       cu,
66       &tensor_table,
67       [&](const std::string& /* unused */)
68           -> std::shared_ptr<torch::jit::Source> { return src; },
69       /*version=*/2);
70   si.loadType(c10::QualifiedName(class_name));
71 }
72 } // namespace
73 
getDeepAndWideSciptModel(int num_features)74 torch::jit::Module getDeepAndWideSciptModel(int num_features) {
75   auto cu = std::make_shared<at::CompilationUnit>();
76   std::vector<at::IValue> constantTable;
77   import_libs(
78       cu,
79       "__torch__.DeepAndWide",
80       std::make_shared<torch::jit::Source>(deep_wide_pt),
81       constantTable);
82   c10::QualifiedName base("__torch__");
83   auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));
84 
85   torch::jit::Module mod(cu, clstype);
86 
87   mod.register_parameter("_mu", torch::randn({1, num_features}), false);
88   mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
89   mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
90   mod.register_parameter("_fc_b", torch::randn({1}), false);
91 
92   // mod.dump(true, true, true);
93   return mod;
94 }
95 
getTrivialScriptModel()96 torch::jit::Module getTrivialScriptModel() {
97   torch::jit::Module module("m");
98   module.define(trivial_model_1);
99   return module;
100 }
101 
getLeakyReLUScriptModel()102 torch::jit::Module getLeakyReLUScriptModel() {
103   torch::jit::Module module("leaky_relu");
104   module.define(leaky_relu_model);
105   return module;
106 }
107 
getLeakyReLUConstScriptModel()108 torch::jit::Module getLeakyReLUConstScriptModel() {
109   torch::jit::Module module("leaky_relu_const");
110   module.define(leaky_relu_model_const);
111   return module;
112 }
113 
// TorchScript source for getLongScriptModel(). forward(a, b, c) runs five
// chained steps, each an elementwise multiply followed by relu, and returns
// the last result.
const std::string long_model{R"JIT(
  def forward(self, a, b, c):
      d = torch.relu(a * b)
      e = torch.relu(a * c)
      f = torch.relu(e * d)
      g = torch.relu(f * f)
      h = torch.relu(g * c)
      return h
)JIT"};
123 
getLongScriptModel()124 torch::jit::Module getLongScriptModel() {
125   torch::jit::Module module("m");
126   module.define(long_model);
127   return module;
128 }
129 
// TorchScript source for getSignedLog1pModel(). forward(a) computes
// sign(a) * log1p(abs(a)).
const std::string signed_log1p_model{R"JIT(
  def forward(self, a):
      b = torch.abs(a)
      c = torch.log1p(b)
      d = torch.sign(a)
      e = d * c
      return e
)JIT"};
138 
getSignedLog1pModel()139 torch::jit::Module getSignedLog1pModel() {
140   torch::jit::Module module("signed_log1p");
141   module.define(signed_log1p_model);
142   return module;
143 }
144