1 #include "deep_wide_pt.h"
2
3 #include <torch/csrc/jit/serialization/import_source.h>
4 #include <torch/script.h>
5
6 namespace {
7 // No ReplaceNaN (this removes the constant in the model)
8 const std::string deep_wide_pt = R"JIT(
9 class DeepAndWide(Module):
10 __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ]
11 __buffers__ = []
12 _mu : Tensor
13 _sigma : Tensor
14 _fc_w : Tensor
15 _fc_b : Tensor
16 training : bool
17 def forward(self: __torch__.DeepAndWide,
18 ad_emb_packed: Tensor,
19 user_emb: Tensor,
20 wide: Tensor) -> Tuple[Tensor]:
21 _0 = self._fc_b
22 _1 = self._fc_w
23 _2 = self._sigma
24 wide_offset = torch.add(wide, self._mu, alpha=1)
25 wide_normalized = torch.mul(wide_offset, _2)
26 wide_preproc = torch.clamp(wide_normalized, 0., 10.)
27 user_emb_t = torch.transpose(user_emb, 1, 2)
28 dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t)
29 dp = torch.flatten(dp_unflatten, 1, -1)
30 input = torch.cat([dp, wide_preproc], 1)
31 fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1)
32 return (torch.sigmoid(fc1),)
33 )JIT";
34
35 const std::string trivial_model_1 = R"JIT(
36 def forward(self, a, b, c):
37 s = torch.tensor([[3, 3], [3, 3]])
38 return a + b * c + s
39 )JIT";
40
41 const std::string leaky_relu_model_const = R"JIT(
42 def forward(self, input):
43 x = torch.leaky_relu(input, 0.1)
44 x = torch.leaky_relu(x, 0.1)
45 x = torch.leaky_relu(x, 0.1)
46 x = torch.leaky_relu(x, 0.1)
47 return torch.leaky_relu(x, 0.1)
48 )JIT";
49
50 const std::string leaky_relu_model = R"JIT(
51 def forward(self, input, neg_slope):
52 x = torch.leaky_relu(input, neg_slope)
53 x = torch.leaky_relu(x, neg_slope)
54 x = torch.leaky_relu(x, neg_slope)
55 x = torch.leaky_relu(x, neg_slope)
56 return torch.leaky_relu(x, neg_slope)
57 )JIT";
58
import_libs(std::shared_ptr<at::CompilationUnit> cu,const std::string & class_name,const std::shared_ptr<torch::jit::Source> & src,const std::vector<at::IValue> & tensor_table)59 void import_libs(
60 std::shared_ptr<at::CompilationUnit> cu,
61 const std::string& class_name,
62 const std::shared_ptr<torch::jit::Source>& src,
63 const std::vector<at::IValue>& tensor_table) {
64 torch::jit::SourceImporter si(
65 cu,
66 &tensor_table,
67 [&](const std::string& /* unused */)
68 -> std::shared_ptr<torch::jit::Source> { return src; },
69 /*version=*/2);
70 si.loadType(c10::QualifiedName(class_name));
71 }
72 } // namespace
73
getDeepAndWideSciptModel(int num_features)74 torch::jit::Module getDeepAndWideSciptModel(int num_features) {
75 auto cu = std::make_shared<at::CompilationUnit>();
76 std::vector<at::IValue> constantTable;
77 import_libs(
78 cu,
79 "__torch__.DeepAndWide",
80 std::make_shared<torch::jit::Source>(deep_wide_pt),
81 constantTable);
82 c10::QualifiedName base("__torch__");
83 auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide"));
84
85 torch::jit::Module mod(cu, clstype);
86
87 mod.register_parameter("_mu", torch::randn({1, num_features}), false);
88 mod.register_parameter("_sigma", torch::randn({1, num_features}), false);
89 mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false);
90 mod.register_parameter("_fc_b", torch::randn({1}), false);
91
92 // mod.dump(true, true, true);
93 return mod;
94 }
95
getTrivialScriptModel()96 torch::jit::Module getTrivialScriptModel() {
97 torch::jit::Module module("m");
98 module.define(trivial_model_1);
99 return module;
100 }
101
getLeakyReLUScriptModel()102 torch::jit::Module getLeakyReLUScriptModel() {
103 torch::jit::Module module("leaky_relu");
104 module.define(leaky_relu_model);
105 return module;
106 }
107
getLeakyReLUConstScriptModel()108 torch::jit::Module getLeakyReLUConstScriptModel() {
109 torch::jit::Module module("leaky_relu_const");
110 module.define(leaky_relu_model_const);
111 return module;
112 }
113
// TorchScript source for a chain of five multiply-then-ReLU steps over the
// inputs a, b and c; compiled into a module by getLongScriptModel().
const std::string long_model = R"JIT(
  def forward(self, a, b, c):
      d = torch.relu(a * b)
      e = torch.relu(a * c)
      f = torch.relu(e * d)
      g = torch.relu(f * f)
      h = torch.relu(g * c)
      return h
)JIT";
123
getLongScriptModel()124 torch::jit::Module getLongScriptModel() {
125 torch::jit::Module module("m");
126 module.define(long_model);
127 return module;
128 }
129
// TorchScript source computing sign(a) * log1p(|a|) elementwise; compiled
// into a module by getSignedLog1pModel().
const std::string signed_log1p_model = R"JIT(
  def forward(self, a):
      b = torch.abs(a)
      c = torch.log1p(b)
      d = torch.sign(a)
      e = d * c
      return e
)JIT";
138
getSignedLog1pModel()139 torch::jit::Module getSignedLog1pModel() {
140 torch::jit::Module module("signed_log1p");
141 module.define(signed_log1p_model);
142 return module;
143 }
144