#include <torch/csrc/utils/tensor_flatten.h>

#include <map>
#include <unordered_map>

namespace torch::utils {

using namespace at;

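// Greedily partitions `tensors` into groups of roughly `size_limit` bytes.
// With fine_grained == false, tensors are bucketed per type and a type's
// group is closed once its accumulated size reaches the limit; with
// fine_grained == true, the limit applies to the running total across all
// types, and every open group is spilled once that total reaches it.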
std::vector<TensorGroup> take_tensors(
    TensorList tensors,
    size_t size_limit,
    bool fine_grained) {
  std::vector<TensorGroup> results;
  // an overapproximation, but at least we won't have to copy stuff around
  results.reserve(tensors.size());
  std::map<int64_t, TensorGroup> groups;
  size_t cur_group_size = 0;

  for (const auto& tensor : tensors) {
    size_t tensor_size = 0;
    if (tensor.is_sparse()) {
      const auto& indices = tensor._indices();
      const auto& values = tensor._values();
      // Count the bytes of both the indices and the values storage.
      tensor_size = indices.numel() * indices.element_size() +
          values.numel() * values.element_size();
    } else {
      tensor_size = tensor.numel() * tensor.element_size();
    }

    auto& type_group = groups[static_cast<int64_t>(type_id(tensor))];
    type_group.tensors.push_back(tensor);

    if (fine_grained) {
      cur_group_size += tensor_size;
      // Regardless of type, spill once the running total exceeds the limit.
      if (cur_group_size >= size_limit) {
        // Spill all types to separate groups in results.
        for (auto& entry : groups) {
          auto& group = entry.second;
          results.emplace_back(std::move(group));
        }
        cur_group_size = 0;
        groups.clear();
      }
    } else {
      type_group.size += tensor_size;
      if (type_group.size >= size_limit) {
        results.emplace_back();
        std::swap(results.back(), type_group);
      }
    }
  }
  // End case. Look for any remaining groups and return them.
  for (auto& entry : groups) {
    auto& group = entry.second;
    if (group.tensors.empty()) {
      continue;
    }
    results.emplace_back(std::move(group));
  }
  return results;
}
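
// Usage sketch (illustrative only; `grads` and the 64 MiB bucket size are
// hypothetical):
//
//   std::vector<at::Tensor> grads = /* ... */;
//   auto groups = take_tensors(grads, /*size_limit=*/64 * 1024 * 1024,
//                              /*fine_grained=*/false);
//   for (auto& group : groups) {
//     // All tensors in a group share a type, so they can be flattened into
//     // one contiguous buffer, e.g. with utils::flatten_dense_tensors().
//   }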
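
// Permutes `tensors` in place so that the i-th element has the same type as
// order[i]; `tensors` must hold the same tensors as `order` (typically
// regrouped by type), and the relative order within each type is preserved.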
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
  AT_ASSERT(tensors.size() == order.size());
  std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
    type_id_to_indices[type_id(tensors[i])].push_back(i);

  std::unordered_map<size_t, size_t> type_id_to_type_used;
  std::vector<Tensor> ordered_tensors;
  ordered_tensors.reserve(tensors.size());
  for (auto& tmpl_tensor : order) {
    size_t tmpl_type_id = type_id(tmpl_tensor);
    auto& indices = type_id_to_indices[tmpl_type_id];
    auto& used = type_id_to_type_used[tmpl_type_id];
    ordered_tensors.push_back(tensors[indices[used++]]);
  }
  std::swap(tensors, ordered_tensors);
}
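
// Usage sketch (illustrative): undo the per-type grouping performed by
// take_tensors() and restore the caller's original ordering. `regrouped`
// and `original` are hypothetical:
//
//   std::vector<at::Tensor> regrouped = /* tensors grouped by type */;
//   reorder_tensors_like(regrouped, /*order=*/original);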

namespace {

at::Tensor get_indices(const at::Tensor& t) {
  return t._indices();
}

at::Tensor get_values(const at::Tensor& t) {
  return t._values();
}

} // namespace
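
// Flattens a list of sparse COO tensors into two dense buffers: one holding
// all indices concatenated, one holding all values concatenated.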
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors) {
  auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices));
  auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values));
  return std::make_pair(flat_indices, flat_values);
}
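
// Inverse of flatten_sparse_tensors(): splits the flat indices/values buffers
// back into per-tensor pieces, using `tensors` as templates for the sizes and
// coalesced flags of the outputs.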
std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors) {
  if (tensors.empty())
    return {};

  auto indices =
      utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
  auto values =
      utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));

  std::vector<at::Tensor> outputs;
  outputs.reserve(tensors.size());
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
    auto& ref_t = tensors[i];
    auto t =
        at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
    outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
  }
  return outputs;
}
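
// Roundtrip sketch (illustrative; `sparse_list` is hypothetical):
//
//   auto flat = flatten_sparse_tensors(sparse_list);
//   auto restored =
//       unflatten_sparse_tensors(flat.first, flat.second, sparse_list);
//   // restored[i] has the same sizes and coalesced flag as sparse_list[i].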

} // namespace torch::utils