#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/cpu/WeightNormKernel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_weight_norm_differentiable_backward_native.h>
#include <ATen/ops/_weight_norm_interface.h>
#include <ATen/ops/_weight_norm_interface_backward_native.h>
#include <ATen/ops/_weight_norm_interface_native.h>
#include <ATen/ops/_weight_norm_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/norm_except_dim.h>
#include <ATen/ops/norm_except_dim_native.h>
#endif

#include <vector>

namespace at::native {

DEFINE_DISPATCH(weight_norm_stub);
DEFINE_DISPATCH(weight_norm_backward_stub);

// Staying faithful to the Python for now for clarity; look for optimizations later
// (e.g., single return statement for RVO)
Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim)
{
  // I assume tensor.contiguous(), view(), norm(), etc. here will dispatch through VariableType.
  if (dim == -1) {
    return v.norm(pow);
  } else if (dim == 0) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[0] = v.size(0);
    return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
  } else if (dim == v.dim() - 1) {
    std::vector<int64_t> output_size(v.dim(), 1);
    output_size[v.dim() - 1] = v.size(v.dim() - 1);
    return v.contiguous().view({-1, v.size(v.dim() - 1)}).norm(pow, 0).view(output_size);
  } else {
    // To consider: at::native::norm_except_dim is probably fine as well,
    // and would avoid an additional dynamic dispatch.
    return at::norm_except_dim(v.transpose(0, dim), pow, 0).transpose(0, dim); // optimize?
  }
}
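
// Usage sketch (illustrative only, not part of the library): norm_except_dim
// reduces over every dimension except `dim`, keeping `dim` for broadcasting,
// and dim == -1 means the norm of the whole tensor, e.g.
//
//   auto v = at::randn({64, 3, 3, 3});
//   at::norm_except_dim(v, 2, 0).sizes();   // -> [64, 1, 1, 1]
//   at::norm_except_dim(v, 2, 3).sizes();   // -> [1, 1, 1, 3]
//   at::norm_except_dim(v, 2, -1).dim();    // -> 0 (scalar tensor)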

std::tuple<Tensor,Tensor> weight_norm_cpu(
    const Tensor& v,
    const Tensor& g,
    int64_t dim) {
  auto w = at::empty_like(v, at::MemoryFormat::Contiguous);

  // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'
  const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
      at::ScalarType::Float : g.scalar_type();
  auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
  weight_norm_stub(kCPU, w, norm, v, g, dim);

  return std::tuple<Tensor, Tensor>{w, norm};
}
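
// Sketch of what the fused kernel computes (illustrative; the actual loop lives
// behind weight_norm_stub in the CPU kernel):
//
//   norm = norm_except_dim(v, 2, dim)      // kept in Float when g is BFloat16
//   w    = v * (g / norm)                  // broadcast over every dim except `dim`
//
// The returned `norm` is saved by autograd and handed back to
// weight_norm_backward_cpu / _weight_norm_differentiable_backward.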

std::tuple<Tensor, Tensor> weight_norm_backward_cpu(
    const Tensor& grad_w,
    const Tensor& saved_v,
    const Tensor& saved_g,
    const Tensor& saved_norm,
    int64_t dim) {
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norm.is_contiguous(), "saved_norm must be contiguous");

  auto grad_v = at::empty_like(saved_v, at::MemoryFormat::Contiguous);
  auto grad_g = at::empty_like(saved_g, at::MemoryFormat::Contiguous);
  weight_norm_backward_stub(kCPU, grad_v, grad_g, grad_w, saved_v, saved_g, saved_norm, dim);

  return std::tuple<Tensor, Tensor>{grad_v, grad_g};
}

Tensor _weight_norm
  (const Tensor & v_in,
   const Tensor & g_in,
   int64_t dim)
{

  TORCH_CHECK(
      v_in.device() == g_in.device(),
      "weight_norm: expected v_in and g_in to be on the same device, but v_in is "
      "on ", v_in.device(), " and g_in is on ", g_in.device());

  auto v = v_in.contiguous();
  auto g = g_in.contiguous();

  auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
    || g.scalar_type() == at::ScalarType::Half;

  bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));

  if (can_use_fused) {
    // weight_norm does not have a derivative defined for it, so this will route back through
    // VariableType.cpp, and construct a WeightNormFusedBackward object in the autograd graph.
    return std::get<0>(at::_weight_norm_interface(v, g, dim));
  } else {
    // Double-differentiable primitive ops
    // at::native::norm_except_dim would probably be fine as well.
    return v*(g/at::norm_except_dim(v, 2, dim));
  }
}
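
// Illustrative note (not library code): both branches above compute
// w = v * g / norm_except_dim(v, 2, dim). The fused path is used for non-Half
// inputs when `dim` is the first or last dimension; anything else falls back to
// the primitive-op expression so the result stays double-differentiable, e.g.
//
//   auto v = at::randn({8, 4, 4});
//   auto g = at::ones({1, 4, 1});
//   auto w = at::_weight_norm(v, g, /*dim=*/1);   // dim == 1 -> primitive-op path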

// Differentiable backward path, an alternative to weight_norm_backward, to be used
// when backward is itself creating a graph.
// The GradMode::is_enabled() check must be performed within Functions.cpp; that's why we
// define a separate function here, instead of inlining it in weight_norm_cuda_backward.
std::tuple<Tensor, Tensor> _weight_norm_differentiable_backward
  (const Tensor & grad_w,
   const Tensor & saved_v,
   const Tensor & saved_g,
   const Tensor & saved_norms,
   int64_t dim)
{
  // In Functions.cpp, the generated backward node for _weight_norm_interface supplies
  // "grad.contiguous()" as the first argument, so grad_w should be contiguous here.
  // All these checks should succeed:
  TORCH_CHECK(grad_w.is_contiguous(), "grad_w must be contiguous");
  TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous");
  TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous");
  TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous");

  int64_t last_dim = saved_v.dim() - 1;
  int64_t last_size = saved_v.size(last_dim);

  // Like weight_norm_fused_backward, weight_norm_differentiable_backward should only ever be called
  // through a WeightNormFusedBackward object, so we expect that dim == 0 || dim == saved_v.dim() - 1
  TORCH_CHECK(dim == 0 || dim == last_dim, "Expected dim to be the first or last dimension");

  // saved_g and saved_norms are already shaped to broadcast over the correct dimensions

  // ...but saved_norms might be Float when saved_g and saved_v are half.
  // To consider: saved_norms.to(..., true /*non_blocking*/);
  auto norms = saved_norms.to(saved_g.scalar_type());

  std::vector<int64_t> bcast_size(saved_v.dim(), 1);

  // Analytic backward path using differentiable primitive ops
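  //
  // Derivation sketch (illustrative, not part of the original source): with
  //   w = v * (g / n),   n = norm_except_dim(v, 2, dim),
  // the chain rule over each slice that shares a single entry of n gives
  //   dL/dg = sum_over_slice(grad_w * v) / n              (= per_dim_sums / n below)
  //   dL/dv = (g / n) * (grad_w - v * sum_over_slice(grad_w * v) / n^2)
  // The two branches below compute exactly these expressions; they differ only in
  // whether the slice reduction runs over trailing dims (dim == 0) or leading dims.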
  if (dim == 0) {
    bcast_size[0] = saved_v.size(0);
    auto per_dim_sums = (grad_w*saved_v).view({saved_v.size(0), -1}).sum(1).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  } else { // dim == last_dim
    bcast_size[last_dim] = last_size;
    auto per_dim_sums = (grad_w*saved_v).view({-1, last_size}).sum(0).view(bcast_size);
    auto grad_v = (saved_g/norms)*(grad_w - saved_v*(per_dim_sums/(norms*norms)));
    auto grad_g = per_dim_sums/norms;
    return std::tuple<Tensor, Tensor>{grad_v, grad_g};
  }
}

} // namespace at::native