1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
7
8 namespace torch {
9
ParamCommsDebugInfo(std::tuple<std::string,std::string> pgName,int rank,std::string && collName,int64_t inNelems,int64_t outNelems,at::ScalarType dType,std::vector<int64_t> inSplitSizes,std::vector<int64_t> outSplitSizes,int globalRankStart,int globalRankStride,int worldSize)10 ParamCommsDebugInfo::ParamCommsDebugInfo(
11 std::tuple<std::string, std::string> pgName,
12 int rank,
13 std::string&& collName,
14 int64_t inNelems,
15 int64_t outNelems,
16 at::ScalarType dType,
17 std::vector<int64_t> inSplitSizes,
18 std::vector<int64_t> outSplitSizes,
19 int globalRankStart,
20 int globalRankStride,
21 int worldSize)
22 : pgName_(std::move(pgName)),
23 rank_(rank),
24 worldSize_(worldSize),
25 collectiveName_(std::move(collName)),
26 inMessageNelems_(inNelems),
27 outMessageNelems_(outNelems),
28 dType_(dType),
29 inputSplitSizes_(std::move(inSplitSizes)),
30 outputSplitSizes_(std::move(outSplitSizes)),
31 globalRankStart_(globalRankStart),
32 globalRankStride_(globalRankStride) {
33 if (globalRankStride > 0) {
34 for (int i = 0; i < worldSize; i++) {
35 groupRanks_.push_back(globalRankStart + i * globalRankStride);
36 }
37 }
38 }
39
40 } // namespace torch
41