#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/cpu/WeightNormKernel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
#include <ATen/ops/_weight_norm_interface.h>
#include <ATen/ops/_weight_norm_interface_backward_native.h>
#include <ATen/ops/_weight_norm_interface_native.h>
#include <ATen/ops/_weight_norm_native.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/norm_except_dim.h>
#include <ATen/ops/norm_except_dim_native.h>
#endif

#include <vector>

namespace at::native {

DEFINE_DISPATCH(weight_norm_stub);
DEFINE_DISPATCH(weight_norm_backward_stub);

// Staying faithful to the Python for now for clarity, look for optimizations later
// (e.g., single return statement for RVO)
Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
  // I assume tensor.contiguous(), view(), norm(), etc. here will dispatch through VariableType.
  if (dim == -1) {
    return v.norm(pow);
  } else if (dim == 0) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[0] = v.size(0);
    return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
  } else if (dim == v.dim() - 1) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[v.dim() - 1] = v.size(v.dim() - 1);
    return v.contiguous().view({-1, v.size(v.dim() - 1)}).norm(pow, 0).view(output_size);
  } else {
    // To consider: at::native::norm_except_dim is probably fine as well,
    // and would avoid an additional dynamic dispatch.
    return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim); // optimize?
  }
}
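
// Shape intuition (an illustrative sketch, not part of the operator's contract;
// the tensor shape below is hypothetical): for a v of shape [out, in, k],
//   norm_except_dim(v, 2, 0)  -> per-slice 2-norms of shape [out, 1, 1]
//   norm_except_dim(v, 2, 2)  -> last-dim 2-norms of shape [1, 1, k]
//   norm_except_dim(v, 2, -1) -> a single scalar norm of the whole tensor
// The singleton-padded shapes are what allow g and the saved norms to broadcast
// back over v in the weight_norm expressions below.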

std::tuple<Tensor,Tensor> weight_norm_cpu(
    const Tensor& v,
    const Tensor& g,
    int64_t dim) {
  auto w = at::empty_like(v, at::MemoryFormat::Contiguous);

  // Align with CUDA behavior: keep the norm in Float when g is BFloat16.
  const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
      at::ScalarType::Float : g.scalar_type();
  auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
  weight_norm_stub(kCPU, w, norm, v, g, dim);

  return std::tuple<Tensor, Tensor>{w, norm};
}
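
// A sketch of what the fused kernel is expected to produce (mirroring the
// unfused path in _weight_norm below, not a separate specification):
//   w = v * (g / norm),  with  norm = norm_except_dim(v, 2, dim),
// and norm is returned alongside w so weight_norm_backward_cpu can reuse it.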

std::tuple<Tensor, Tensor> weight_norm_backward_cpu(
    const Tensor& grad_w,
    const Tensor& saved_v,
    const Tensor& saved_g,
    const Tensor& saved_norm,
    int64_t dim) {
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norm.is_contiguous(), "saved_norm must be contiguous");

  auto grad_v = at::empty_like(saved_v, at::MemoryFormat::Contiguous);
  auto grad_g = at::empty_like(saved_g, at::MemoryFormat::Contiguous);
  weight_norm_backward_stub(kCPU, grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, dim);

  return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}

Tensor _weight_norm
  (const Tensor & v_in,
   const Tensor & g_in,
   int64_t dim)
{

  TORCH_CHECK(
    v_in.device() == g_in.device(),
    "weight_norm: expected v_in and g_in to be on the same device, but v_in is "
    "on ", v_in.device(), " and g_in is on ", g_in.device());

  auto v = v_in.contiguous();
  auto g = g_in.contiguous();

  auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
    || g.scalar_type() == at::ScalarType::Half;

  bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));

  if (can_use_fused) {
    // weight_norm does not have a derivative defined for it, so this will route back through
    // VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
    return std::get<0>(at::_weight_norm_interface(v, g, dim));
  } else {
    // Double-differentiable primitive ops
    // at::native::norm_except_dim would probably be fine as well.
    return v*(g/at::norm_except_dim(v, 2, dim));
  }
}
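
// Usage sketch (illustrative only): callers normally reach this through the
// dispatcher, e.g.
//   auto w = at::_weight_norm(v, g, /*dim=*/0);
// For non-Half dtypes with dim at either end of v, this takes the fused
// _weight_norm_interface path above; otherwise it falls back to the
// mathematically equivalent v * (g / at::norm_except_dim(v, 2, dim)).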

// Differentiable backward path, an alternative to weight_norm_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
  (const Tensor & grad_w,
   const Tensor & saved_v,
   const Tensor & saved_g,
   const Tensor & saved_norms,
   int64_t dim)
{
  // In Functions.cpp, the HardshrinkBackward object supplies "grad.contiguous()"
  // as the first argument, so grad_w should be contiguous here.
  // All these checks should succeed:
  TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");

  int64_t last_dim = saved_v.dim() - 1;
  int64_t last_size = saved_v.size(last_dim);

  // Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
  // through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.dim() - 1
  TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");

  // saved_g and saved_norms are already shaped to broadcast over the correct dimensions

  // ...but saved_norms might be Float when saved_g and saved_v are half.
  // To consider:  saved_norms.to(..., true /*non_blocking*/);
  auto norms = saved_norms.to(saved_g.scalar_type());

  std::vector<int64_t> bcast_size(saved_v.dim(), 1);

  // Analytic backward path using differentiable primitive ops
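  // A sketch of the math being implemented: the forward pass computes
  //   w = g * v / norm,
  // where norm reduces over every dimension except `dim` and is broadcast back
  // over v (as are saved_g and saved_norms). Contracting the differentials with
  // grad_w gives, per slice along `dim`,
  //   per_dim_sums = sum(grad_w * v)                       (over the reduced dims)
  //   grad_g       = per_dim_sums / norm
  //   grad_v       = (g / norm) * (grad_w - v * per_dim_sums / norm^2)
  // which is what both branches below compute, for dim == 0 and dim == last_dim.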
  if (dim == 0) {
    bcast_size[0] = saved_v.size(0);
    auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  } else { // dim == last_dim
    bcast_size[last_dim] = last_size;
    auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  }
}

} // namespace at::native