xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_flatten.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/tensor_flatten.h>
2 
3 #include <map>
4 #include <unordered_map>
5 
6 namespace torch::utils {
7 
8 using namespace at;
9 
take_tensors(TensorList tensors,size_t size_limit,bool fine_grained)10 std::vector<TensorGroup> take_tensors(
11     TensorList tensors,
12     size_t size_limit,
13     bool fine_grained) {
14   std::vector<TensorGroup> results;
15   // an overapproximation, but at least we won't have to copy stuff around
16   results.reserve(tensors.size());
17   std::map<int64_t, TensorGroup> groups;
18   size_t cur_group_size = 0;
19 
20   for (const auto& tensor : tensors) {
21     size_t tensor_size = 0;
22     if (tensor.is_sparse()) {
23       const auto& indices = tensor._indices();
24       const auto& values = tensor._values();
25       tensor_size = indices.numel() * indices.element_size() +
26           values.numel() * indices.element_size();
27     } else {
28       tensor_size = tensor.numel() * tensor.element_size();
29     }
30 
31     auto& type_group = groups[static_cast<int64_t>(type_id(tensor))];
32     type_group.tensors.push_back(tensor);
33 
34     if (fine_grained) {
35       cur_group_size += tensor_size;
36       // Regardless the type, the current total size exceeds the limit
37       if (cur_group_size >= size_limit) {
38         // Spill all types to separate groups in results
39         for (auto& entry : groups) {
40           auto& group = entry.second;
41           results.emplace_back(std::move(group));
42         }
43         cur_group_size = 0;
44         groups.clear();
45       }
46     } else {
47       type_group.size += tensor_size;
48       if (type_group.size >= size_limit) {
49         results.emplace_back();
50         std::swap(results.back(), type_group);
51       }
52     }
53   }
54   // End case. Look for any remaining groups and return them.
55   for (auto& entry : groups) {
56     auto& group = entry.second;
57     if (group.tensors.empty()) {
58       continue;
59     }
60     results.emplace_back(std::move(group));
61   }
62   return results;
63 }
64 
reorder_tensors_like(std::vector<Tensor> & tensors,TensorList order)65 void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
66   AT_ASSERT(tensors.size() == order.size());
67   std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
68   for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
69     type_id_to_indices[type_id(tensors[i])].push_back(i);
70 
71   std::unordered_map<size_t, size_t> type_id_to_type_used;
72   std::vector<Tensor> ordered_tensors;
73   ordered_tensors.reserve(tensors.size());
74   for (auto& tmpl_tensor : order) {
75     size_t tmpl_type_id = type_id(tmpl_tensor);
76     auto& indices = type_id_to_indices[tmpl_type_id];
77     auto& used = type_id_to_type_used[tmpl_type_id];
78     ordered_tensors.push_back(tensors[indices[used++]]);
79   }
80   std::swap(tensors, ordered_tensors);
81 }
82 
83 namespace {
84 
get_indices(const at::Tensor & t)85 at::Tensor get_indices(const at::Tensor& t) {
86   return t._indices();
87 }
88 
get_values(const at::Tensor & t)89 at::Tensor get_values(const at::Tensor& t) {
90   return t._values();
91 }
92 
93 } // namespace
94 
flatten_sparse_tensors(at::TensorList tensors)95 std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
96     at::TensorList tensors) {
97   auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices));
98   auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values));
99   return std::make_pair(flat_indices, flat_values);
100 }
101 
unflatten_sparse_tensors(const at::Tensor & flat_indices,const at::Tensor & flat_values,at::TensorList tensors)102 std::vector<at::Tensor> unflatten_sparse_tensors(
103     const at::Tensor& flat_indices,
104     const at::Tensor& flat_values,
105     at::TensorList tensors) {
106   if (tensors.empty())
107     return {};
108 
109   auto indices =
110       utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
111   auto values =
112       utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));
113 
114   std::vector<at::Tensor> outputs;
115   outputs.reserve(tensors.size());
116   for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
117     auto& ref_t = tensors[i];
118     auto t =
119         at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
120     outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
121   }
122   return outputs;
123 }
124 
125 } // namespace torch::utils
126