xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/ParamCommsUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
7 
8 namespace torch {
9 
ParamCommsDebugInfo(std::tuple<std::string,std::string> pgName,int rank,std::string && collName,int64_t inNelems,int64_t outNelems,at::ScalarType dType,std::vector<int64_t> inSplitSizes,std::vector<int64_t> outSplitSizes,int globalRankStart,int globalRankStride,int worldSize)10 ParamCommsDebugInfo::ParamCommsDebugInfo(
11     std::tuple<std::string, std::string> pgName,
12     int rank,
13     std::string&& collName,
14     int64_t inNelems,
15     int64_t outNelems,
16     at::ScalarType dType,
17     std::vector<int64_t> inSplitSizes,
18     std::vector<int64_t> outSplitSizes,
19     int globalRankStart,
20     int globalRankStride,
21     int worldSize)
22     : pgName_(std::move(pgName)),
23       rank_(rank),
24       worldSize_(worldSize),
25       collectiveName_(std::move(collName)),
26       inMessageNelems_(inNelems),
27       outMessageNelems_(outNelems),
28       dType_(dType),
29       inputSplitSizes_(std::move(inSplitSizes)),
30       outputSplitSizes_(std::move(outSplitSizes)),
31       globalRankStart_(globalRankStart),
32       globalRankStride_(globalRankStride) {
33   if (globalRankStride > 0) {
34     for (int i = 0; i < worldSize; i++) {
35       groupRanks_.push_back(globalRankStart + i * globalRankStride);
36     }
37   }
38 }
39 
40 } // namespace torch
41