1 #include <test/cpp/lazy/test_lazy_ops_util.h>
2
3 #include <torch/csrc/lazy/backend/lowering_context.h>
4 #include <torch/csrc/lazy/core/ir_builder.h>
5 #include <torch/csrc/lazy/core/ir_dump_util.h>
6 #include <torch/csrc/lazy/core/tensor_impl.h>
7
8 #include <iostream>
9 #include <string>
10
11 namespace torch {
12 namespace lazy {
namespace {

// Builds the set of counter names that is-any-counter-changed assertions
// should skip. Add new counter names here to ignore them as well.
std::unordered_set<std::string>* CreateIgnoredCounters() {
  auto* ignored = new std::unordered_set<std::string>{"aten::rand"};
  return ignored;
}

} // namespace

const std::unordered_set<std::string>* GetIgnoredCounters() {
  // Built once on first use and intentionally leaked so the set remains
  // valid for the entire process lifetime (no destruction-order issues).
  static const std::unordered_set<std::string>* const ignored =
      CreateIgnoredCounters();
  return ignored;
}
31
// Returns a CPU copy of `tensor`.
// Note: tensor.to() implicitly triggers a sync if t.device=torch::kLazy.
at::Tensor ToCpuTensor(const at::Tensor& tensor) {
  at::Tensor cpu_tensor = tensor.to(torch::kCPU);
  return cpu_tensor;
}
36
// Returns a deep copy of `tensor` placed on `device`.
// copy=true forces a fresh tensor even when the source already resides on
// the target device.
torch::Tensor CopyToDevice(
    const torch::Tensor& tensor,
    const torch::Device& device) {
  torch::Tensor cloned = tensor.clone();
  return cloned.to(device, /*non_blocking=*/false, /*copy=*/true);
}
42
EqualValues(at::Tensor tensor1,at::Tensor tensor2)43 bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
44 tensor1 = ToCpuTensor(tensor1);
45 tensor2 = ToCpuTensor(tensor2);
46 if (torch::isnan(tensor1).any().item<bool>()) {
47 EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
48 tensor1.nan_to_num_();
49 tensor2.nan_to_num_();
50 }
51 if (tensor1.sizes() != tensor2.sizes() ||
52 tensor1.dtype() != tensor2.dtype()) {
53 std::cerr << "Different shape:\n"
54 << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
55 << tensor2.dtype() << " " << tensor2.sizes() << "\n";
56 return false;
57 }
58 at::ScalarType type1 = tensor1.scalar_type();
59 at::ScalarType type2 = tensor2.scalar_type();
60 if (type1 != type2) {
61 tensor1 = tensor1.toType(type2);
62 }
63 bool equal = tensor1.equal(tensor2);
64 return equal;
65 }
66
EqualValuesNoElementTypeCheck(at::Tensor tensor1,at::Tensor tensor2)67 bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
68 tensor1 = ToCpuTensor(tensor1);
69 tensor2 = ToCpuTensor(tensor2);
70 if (tensor1.sizes() != tensor2.sizes()) {
71 std::cerr << "Different shape:\n"
72 << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
73 << tensor2.dtype() << " " << tensor2.sizes() << "\n";
74 return false;
75 }
76 at::ScalarType type1 = tensor1.scalar_type();
77 at::ScalarType type2 = tensor2.scalar_type();
78 if (type1 != type2) {
79 tensor1 = tensor1.toType(type2);
80 }
81 bool equal = tensor1.equal(tensor2);
82 return equal;
83 }
84
ForEachDevice(const std::function<void (const torch::Device &)> & devfn)85 void ForEachDevice(const std::function<void(const torch::Device&)>& devfn) {
86 // Currently TorchScript backend only supports one type of hardware per
87 // process, which is set by env. And the ordinal is always 0 given distributed
88 // training/ multi-device is not supported yet.
89 auto device = torch::lazy::BackendDevice();
90 torch::Device torch_device = torch::lazy::backendDeviceToAtenDevice(device);
91 devfn(torch_device);
92 }
93
CloseValues(at::Tensor tensor1,at::Tensor tensor2,double rtol,double atol)94 bool CloseValues(
95 at::Tensor tensor1,
96 at::Tensor tensor2,
97 double rtol,
98 double atol) {
99 tensor1 = ToCpuTensor(tensor1);
100 tensor2 = ToCpuTensor(tensor2);
101 if (torch::isnan(tensor1).any().item<bool>()) {
102 EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
103 tensor1.nan_to_num_();
104 tensor2.nan_to_num_();
105 }
106 if (tensor1.sizes() != tensor2.sizes() ||
107 tensor1.dtype() != tensor2.dtype()) {
108 std::cerr << "Different shape:\n"
109 << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
110 << tensor2.dtype() << " " << tensor2.sizes() << "\n";
111 return false;
112 }
113 bool equal = tensor1.allclose(tensor2, rtol, atol);
114 return equal;
115 }
116
GetTensorTextGraph(at::Tensor tensor)117 std::string GetTensorTextGraph(at::Tensor tensor) {
118 torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
119 return torch::lazy::DumpUtil::ToText({lazy_tensor->GetIrValue().node.get()});
120 }
121
GetTensorDotGraph(at::Tensor tensor)122 std::string GetTensorDotGraph(at::Tensor tensor) {
123 torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
124 return torch::lazy::DumpUtil::ToDot({lazy_tensor->GetIrValue().node.get()});
125 }
126
// Runs `testfn` on a reference copy of `inputs` and on copies moved to
// `device`, and asserts that the forward outputs and the gradients (up to
// `derivative_level`-th order) agree within rtol/atol.
//
// Undefined tensors in `inputs` are forwarded as undefined placeholders so
// positional arguments to `testfn` stay aligned. Only inputs with
// requires_grad=true participate in the gradient checks.
void TestBackward(
    const std::vector<torch::Tensor>& inputs,
    const torch::Device& device,
    const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
        testfn,
    double rtol,
    double atol,
    int derivative_level) {
  std::vector<torch::Tensor> input_vars;
  std::vector<torch::Tensor> xinput_vars;
  std::vector<torch::Tensor> inputs_w_grad;
  std::vector<torch::Tensor> xinputs_w_grad;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const torch::Tensor& input = inputs[i];
    if (input.defined()) {
      // Reference leaf: detached clone so gradients accumulate on a fresh
      // graph, preserving the original requires_grad flag.
      torch::Tensor oinput =
          input.clone().detach().set_requires_grad(input.requires_grad());
      input_vars.push_back(oinput);

      // Device-under-test leaf: same treatment after copying to `device`.
      torch::Tensor xinput = CopyToDevice(input, device)
                                 .detach()
                                 .set_requires_grad(input.requires_grad());
      xinput_vars.push_back(xinput);
      if (input.requires_grad()) {
        inputs_w_grad.push_back(oinput);
        xinputs_w_grad.push_back(xinput);
      }
    } else {
      // Keep positions aligned for testfn by inserting undefined tensors.
      input_vars.emplace_back();
      xinput_vars.emplace_back();
    }
  }

  // Forward pass on both paths must already agree.
  torch::Tensor output = testfn(input_vars);
  torch::Tensor xoutput = testfn(xinput_vars);
  torch::lazy::AllClose(output, xoutput, rtol, atol);

  // outs/xouts hold the current derivative level's tensors; level 0 is the
  // forward output, level d holds the d-th order gradients.
  std::vector<torch::Tensor> outs = {output};
  std::vector<torch::Tensor> xouts = {xoutput};
  for (int d = 1; d <= derivative_level; ++d) {
    // Check grad of sum(outs) w.r.t inputs_w_grad.
    // Seed with zeros_like(...).sum() to get a scalar of the right
    // dtype/device, then accumulate only grad-tracking outputs.
    torch::Tensor sum = torch::zeros_like(outs[0]).sum();
    torch::Tensor xsum = torch::zeros_like(xouts[0]).sum();
    for (size_t i = 0; i < outs.size(); ++i) {
      if (outs[i].requires_grad()) {
        sum += outs[i].sum();
        xsum += xouts[i].sum();
      }
    }
    // Calculating higher order derivative requires create_graph=true
    bool create_graph = d != derivative_level;
    outs = torch::autograd::grad(
        {sum},
        inputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/std::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    xouts = torch::autograd::grad(
        {xsum},
        xinputs_w_grad,
        /*grad_outputs=*/{},
        /*retain_graph=*/std::nullopt,
        /*create_graph=*/create_graph,
        /*allow_unused=*/true);
    // allow_unused=true may yield undefined grads; definedness must match
    // between the two paths before comparing values.
    for (size_t i = 0; i < outs.size(); ++i) {
      ASSERT_EQ(outs[i].defined(), xouts[i].defined());
      if (outs[i].defined()) {
        AllClose(outs[i], xouts[i], rtol, atol);
      }
    }
  }
}
200
201 } // namespace lazy
202 } // namespace torch
203