xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/Utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/c10d/Utils.hpp>
2 
3 #include <cstring>
4 
5 namespace c10d {
6 
getTensorShapes(const std::vector<at::Tensor> & tensors)7 std::vector<at::Tensor> getTensorShapes(
8     const std::vector<at::Tensor>& tensors) {
9   std::vector<at::Tensor> shapeTensors;
10   shapeTensors.reserve(tensors.size());
11   for (const auto& tensor : tensors) {
12     // Use `at::tensor()` to copy the data underlying `sizes()` since it may be
13     // released elsewhere.
14     at::Tensor shapesTensor =
15         at::tensor(tensor.sizes(), at::TensorOptions().dtype(at::kLong));
16     shapeTensors.emplace_back(std::move(shapesTensor));
17   }
18   return shapeTensors;
19 }
20 
getTensorsNumel(const std::vector<at::Tensor> & tensors)21 size_t getTensorsNumel(const std::vector<at::Tensor>& tensors) {
22   size_t numel = 0;
23   for (auto& tensor : tensors) {
24     numel += tensor.numel();
25   }
26   return numel;
27 }
28 
29 } // namespace c10d
30