xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/RowwisePrune.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/Dispatch.h>
6 #include <c10/util/irange.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/_rowwise_prune_native.h>
13 #include <ATen/ops/empty.h>
14 #endif
15 
16 namespace at::native {
17 
18 namespace {
19 
20 template <typename input_t>
_rowwise_prune_helper(const Tensor & weights,const Tensor & mask,ScalarType compressed_indices_dtype)21 std::tuple<Tensor, Tensor> _rowwise_prune_helper(
22       const Tensor& weights, const Tensor& mask,
23       ScalarType compressed_indices_dtype) {
24   int num_non_masked_rows = 0;
25   auto mask_contig = mask.contiguous();
26   auto mask_data = mask_contig.data_ptr<bool>();
27   for (const auto i : c10::irange(mask.numel())) {
28     num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
29   }
30   int num_cols = weights.size(1);
31   auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
32       weights.options());
33   auto compressed_indices_mapping = at::empty({mask.numel()},
34       compressed_indices_dtype);
35   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half,
36                              at::ScalarType::BFloat16,
37                              weights.scalar_type(),
38                             "rowwise_prune_helper", [&]() {
39     auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr<scalar_t>();
40     auto compressed_indices_mapping_data =
41         compressed_indices_mapping.data_ptr<input_t>();
42     auto weights_data = weights.data_ptr<scalar_t>();
43     int last_row_kept = 0;
44     for (const auto i : c10::irange(mask.numel())) {
45       if (mask_data[i]) {
46         memcpy(pruned_2d_tensor_data + last_row_kept * num_cols,
47               weights_data + i * num_cols,
48               num_cols * sizeof (scalar_t));
49         compressed_indices_mapping_data[i] = last_row_kept;
50         last_row_kept++;
51       } else {
52         compressed_indices_mapping_data[i] = -1;
53       }
54     }
55   });
56   return std::tuple<Tensor, Tensor>(pruned_2d_tensor,
57       compressed_indices_mapping);
58 }
59 
60 } // namespace
61 
62 
63 // This operator introduces sparsity to the 'weights' matrix with the help
64 // of the importance indicator 'mask'.
65 //
66 // A row is considered important and not pruned if the mask value for that
67 // particular row is 1(True) and not important otherwise.
68 //
69 // This operator doesn't zero out the pruned rows in-place. Instead, it
70 // returns a tuple that contains a pruned weights tensor as well as a map that
71 // can be used to look up the original row in the pruned weights tensor.
72 // We refer this map as 'compressed indices map' going forward.
73 
74 // The 'compressed indices map' is an 1D tensor that contains one entry per
75 // original row in 'weights'. The array index is the index for the original
76 // non-pruned weight tensor and the value would be the re-mapped index in the
77 // pruned weights tensor. If the value for a index is -1, it means the
78 // corresponding row has been pruned from the original weight tensor.
79 
80 // Arguments:
81 // 'weights' - two dimensional matrix that needs to be prune.
82 // 'mask' - 1D boolean tensor that represents whether a row is important or
83 //    not. A mask value of 1 means the row should be kept and 0 means the row
84 //    should be pruned.
85 //
86 // Returns:
87 // A tuple containing two tensors,
88 // 1. A pruned weight tensor that contains only the weights that are preserved
89 //    post pruning.
90 // 2. An 1D tensor that contains the mapping between original weight row and
91 //    the corresponding row in the pruned weights tensor.
_rowwise_prune(const Tensor & weights,const Tensor & mask,ScalarType compressed_indices_dtype)92 std::tuple<Tensor, Tensor> _rowwise_prune(const Tensor& weights,
93                                           const Tensor& mask,
94                                           ScalarType compressed_indices_dtype) {
95   TORCH_CHECK(weights.ndimension() == 2,
96       "'weights' should have 2 dimensions.");
97   TORCH_CHECK(
98     mask.numel() == weights.size(0),
99     "Number of elements in 'mask' should be equivalent to the "
100     "number of rows in 'weights'."
101   )
102   TORCH_CHECK(
103       compressed_indices_dtype == ScalarType::Int ||
104       compressed_indices_dtype == ScalarType::Long,
105       "compressed_indices_dtype should be either int(int32) or long(int64).");
106 
107   if (compressed_indices_dtype == at::ScalarType::Int) {
108     return _rowwise_prune_helper<int32_t>(weights, mask,
109                                           compressed_indices_dtype);
110   }
111   return _rowwise_prune_helper<int64_t>(weights, mask,
112                                         compressed_indices_dtype);
113 }
114 
115 } // namespace at::native
116