#include <test/cpp/lazy/test_lazy_ops_util.h>

#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/tensor_impl.h>

#include <iostream>
#include <string>

namespace torch {
namespace lazy {
namespace {

std::unordered_set<std::string>* CreateIgnoredCounters() {
  std::unordered_set<std::string>* icounters =
      new std::unordered_set<std::string>();
  // Add below the counters whose names need to be ignored when doing
  // is-any-counter-changed assertions.
  icounters->insert("aten::rand");
  return icounters;
}

} // namespace

const std::unordered_set<std::string>* GetIgnoredCounters() {
  static const std::unordered_set<std::string>* icounters =
      CreateIgnoredCounters();
  return icounters;
}
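
// Usage sketch (hypothetical helper names, not part of this file): a test
// that asserts no unexpected counter changed could skip the names returned by
// GetIgnoredCounters(), e.g.
//   for (const auto& name : changed_counters) {
//     EXPECT_TRUE(GetIgnoredCounters()->count(name) > 0) << name;
//   }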

at::Tensor ToCpuTensor(const at::Tensor& tensor) {
  // tensor.to() implicitly triggers a sync if tensor.device() is torch::kLazy.
  return tensor.to(torch::kCPU);
}

torch::Tensor CopyToDevice(
    const torch::Tensor& tensor,
    const torch::Device& device) {
  return tensor.clone().to(device, /*non_blocking=*/false, /*copy=*/true);
}
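
// Note: clone() already detaches the result from `tensor`'s storage, and
// /*copy=*/true forces to() to allocate a fresh tensor even when `device`
// already matches, so the returned tensor never aliases the input.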

bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (torch::isnan(tensor1).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  at::ScalarType type1 = tensor1.scalar_type();
  at::ScalarType type2 = tensor2.scalar_type();
  if (type1 != type2) {
    tensor1 = tensor1.toType(type2);
  }
  bool equal = tensor1.equal(tensor2);
  return equal;
}
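
// Example of the NaN handling above: for tensor1 = {1, NaN} and
// tensor2 = {1, NaN}, the isnan() masks {false, true} are compared first, the
// NaNs are then replaced by nan_to_num_() (NaN -> 0 by default), and the
// remaining finite values are compared with equal().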

bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (tensor1.sizes() != tensor2.sizes()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  at::ScalarType type1 = tensor1.scalar_type();
  at::ScalarType type2 = tensor2.scalar_type();
  if (type1 != type2) {
    tensor1 = tensor1.toType(type2);
  }
  bool equal = tensor1.equal(tensor2);
  return equal;
}

void ForEachDevice(const std::function<void(const torch::Device&)>& devfn) {
  // Currently the TorchScript backend only supports one hardware type per
  // process, selected via an environment variable, and the ordinal is always 0
  // since distributed training / multi-device is not supported yet.
  auto device = torch::lazy::BackendDevice();
  torch::Device torch_device = torch::lazy::backendDeviceToAtenDevice(device);
  devfn(torch_device);
}
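
// Typical usage in these tests (sketch): run the same check on the configured
// backend device, e.g.
//   ForEachDevice([&](const torch::Device& device) {
//     torch::Tensor lazy_x = CopyToDevice(x, device);
//     AllClose(x, lazy_x);
//   });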

bool CloseValues(
    at::Tensor tensor1,
    at::Tensor tensor2,
    double rtol,
    double atol) {
  tensor1 = ToCpuTensor(tensor1);
  tensor2 = ToCpuTensor(tensor2);
  if (torch::isnan(tensor1).any().item<bool>()) {
    EXPECT_TRUE(EqualValues(torch::isnan(tensor1), torch::isnan(tensor2)));
    tensor1.nan_to_num_();
    tensor2.nan_to_num_();
  }
  if (tensor1.sizes() != tensor2.sizes() ||
      tensor1.dtype() != tensor2.dtype()) {
    std::cerr << "Different shape:\n"
              << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
              << tensor2.dtype() << " " << tensor2.sizes() << "\n";
    return false;
  }
  bool equal = tensor1.allclose(tensor2, rtol, atol);
  return equal;
}
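
// allclose() treats elements as close when
//   |tensor1 - tensor2| <= atol + rtol * |tensor2|
// holds elementwise, so rtol scales the tolerance with the magnitude of the
// reference values while atol bounds the absolute error near zero.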

std::string GetTensorTextGraph(at::Tensor tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return torch::lazy::DumpUtil::ToText({lazy_tensor->GetIrValue().node.get()});
}

std::string GetTensorDotGraph(at::Tensor tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  return torch::lazy::DumpUtil::ToDot({lazy_tensor->GetIrValue().node.get()});
}
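
// Both dump helpers render the IR graph rooted at the node backing the lazy
// tensor. A test can, for example, assert that the text dump of a traced
// result mentions an expected op (sketch, exact op names vary per test):
//   EXPECT_NE(GetTensorTextGraph(result).find("aten::add"), std::string::npos);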
126 
TestBackward(const std::vector<torch::Tensor> & inputs,const torch::Device & device,const std::function<torch::Tensor (const std::vector<torch::Tensor> &)> & testfn,double rtol,double atol,int derivative_level)127 void TestBackward(
128     const std::vector<torch::Tensor>& inputs,
129     const torch::Device& device,
130     const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
131         testfn,
132     double rtol,
133     double atol,
134     int derivative_level) {
135   std::vector<torch::Tensor> input_vars;
136   std::vector<torch::Tensor> xinput_vars;
137   std::vector<torch::Tensor> inputs_w_grad;
138   std::vector<torch::Tensor> xinputs_w_grad;
139   for (size_t i = 0; i < inputs.size(); ++i) {
140     const torch::Tensor& input = inputs[i];
141     if (input.defined()) {
142       torch::Tensor oinput =
143           input.clone().detach().set_requires_grad(input.requires_grad());
144       input_vars.push_back(oinput);
145 
146       torch::Tensor xinput = CopyToDevice(input, device)
147                                  .detach()
148                                  .set_requires_grad(input.requires_grad());
149       xinput_vars.push_back(xinput);
150       if (input.requires_grad()) {
151         inputs_w_grad.push_back(oinput);
152         xinputs_w_grad.push_back(xinput);
153       }
154     } else {
155       input_vars.emplace_back();
156       xinput_vars.emplace_back();
157     }
158   }
159 
160   torch::Tensor output = testfn(input_vars);
161   torch::Tensor xoutput = testfn(xinput_vars);
162   torch::lazy::AllClose(output, xoutput, rtol, atol);
163 
164   std::vector<torch::Tensor> outs = {output};
165   std::vector<torch::Tensor> xouts = {xoutput};
166   for (int d = 1; d <= derivative_level; ++d) {
167     // Check grad of sum(outs) w.r.t inputs_w_grad.
168     torch::Tensor sum = torch::zeros_like(outs[0]).sum();
169     torch::Tensor xsum = torch::zeros_like(xouts[0]).sum();
170     for (size_t i = 0; i < outs.size(); ++i) {
171       if (outs[i].requires_grad()) {
172         sum += outs[i].sum();
173         xsum += xouts[i].sum();
174       }
175     }
176     // Calculating higher order derivative requires create_graph=true
177     bool create_graph = d != derivative_level;
178     outs = torch::autograd::grad(
179         {sum},
180         inputs_w_grad,
181         /*grad_outputs=*/{},
182         /*retain_graph=*/std::nullopt,
183         /*create_graph=*/create_graph,
184         /*allow_unused=*/true);
185     xouts = torch::autograd::grad(
186         {xsum},
187         xinputs_w_grad,
188         /*grad_outputs=*/{},
189         /*retain_graph=*/std::nullopt,
190         /*create_graph=*/create_graph,
191         /*allow_unused=*/true);
192     for (size_t i = 0; i < outs.size(); ++i) {
193       ASSERT_EQ(outs[i].defined(), xouts[i].defined());
194       if (outs[i].defined()) {
195         AllClose(outs[i], xouts[i], rtol, atol);
196       }
197     }
198   }
199 }
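
// Usage sketch (hypothetical test body; rtol/atol/derivative_level are assumed
// to have defaults declared in test_lazy_ops_util.h): compare eager vs. lazy
// gradients of a simple op, e.g.
//   torch::Tensor input = torch::rand({2, 2}, torch::requires_grad());
//   ForEachDevice([&](const torch::Device& device) {
//     TestBackward({input}, device, [](const std::vector<torch::Tensor>& t) {
//       return torch::tanh(t[0]);
//     });
//   });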

} // namespace lazy
} // namespace torch