xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <cstdint>
5 
6 namespace torch {
7 namespace distributed {
8 namespace autograd {
9 
10 // Autograd metadata that must be passed between nodes whenever an RPC call
11 // needs autograd computation. It carries the two identifiers declared below.
12 struct TORCH_API AutogradMetadata {
13   AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);
14 
15   // autogradContextId is a globally unique integer that identifies a
16   // particular distributed autograd pass.
17   int64_t autogradContextId;
18   // autogradMessageId is a globally unique integer that identifies a pair
19   // of send/recv autograd functions.
20   int64_t autogradMessageId;
21 };
22 
23 } // namespace autograd
24 } // namespace distributed
25 } // namespace torch
26