xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/events/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import json
11from dataclasses import asdict, dataclass, field
12from enum import Enum
13from typing import Dict, Optional, Union
14
15
# Public names exported by ``from ... import *``.
__all__ = [
    "EventSource",
    "Event",
    "NodeState",
    "RdzvEvent",
]

# Value types permitted for entries of ``Event.metadata``:
# a str, int, float or bool scalar, or None.
EventMetadataValue = Optional[Union[str, int, float, bool]]
19
20
class EventSource(str, Enum):
    """
    Identifiers for the component that produced an event.

    Because the enum mixes in ``str``, every member is also a plain string
    equal to its value, e.g. ``EventSource.AGENT == "AGENT"``.
    """

    AGENT = "AGENT"  # event produced by the agent
    WORKER = "WORKER"  # event produced by a worker
26
27
@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.

    The event can be any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        # The string form of an event is its JSON serialization.
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event", Dict]) -> "Event":
        """
        Build an ``Event`` from its serialized form.

        Args:
            data: an ``Event`` (returned unchanged), a JSON string as produced
                by :meth:`serialize`, or an already-decoded dict with the same
                keys as the dataclass fields.

        Returns:
            The corresponding ``Event``.

        Raises:
            TypeError: if ``data`` is not a ``str``, ``dict`` or ``Event``.
            KeyError: if the ``source`` entry is missing or is not the name of
                an ``EventSource`` member.
        """
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
        elif isinstance(data, dict):
            # Copy so that converting "source" below does not mutate the caller's dict.
            data_dict = dict(data)
        else:
            # Previously this fell through with ``data_dict`` unbound and
            # surfaced as an opaque UnboundLocalError.
            raise TypeError(
                f"Cannot deserialize Event from {type(data).__name__}; "
                "expected str, dict or Event"
            )
        # Members are looked up by name; for EventSource the names equal the
        # string values that serialize() writes, so round-tripping works.
        data_dict["source"] = EventSource[data_dict["source"]]
        return Event(**data_dict)

    def serialize(self) -> str:
        """Return the event as a JSON string of its dataclass fields."""
        return json.dumps(asdict(self))
61
62
class NodeState(str, Enum):
    """
    Lifecycle states a node can report during rendezvous.

    Each member is also a ``str`` equal to its value
    (e.g. ``NodeState.RUNNING == "RUNNING"``).
    """

    INIT = "INIT"  # initial state
    RUNNING = "RUNNING"  # node is running
    SUCCEEDED = "SUCCEEDED"  # node finished successfully
    FAILED = "FAILED"  # node failed
70
71
@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.

    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        # The string form of an event is its JSON serialization.
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent", Dict]) -> "RdzvEvent":
        """
        Build a ``RdzvEvent`` from its serialized form.

        Args:
            data: a ``RdzvEvent`` (returned unchanged), a JSON string as
                produced by :meth:`serialize`, or an already-decoded dict with
                the same keys as the dataclass fields.

        Returns:
            The corresponding ``RdzvEvent``.

        Raises:
            TypeError: if ``data`` is not a ``str``, ``dict`` or ``RdzvEvent``.
            KeyError: if the ``node_state`` entry is missing or is not the
                name of a ``NodeState`` member.
        """
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
        elif isinstance(data, dict):
            # Copy so that converting "node_state" below does not mutate the caller's dict.
            data_dict = dict(data)
        else:
            # Previously this fell through with ``data_dict`` unbound and
            # surfaced as an opaque UnboundLocalError.
            raise TypeError(
                f"Cannot deserialize RdzvEvent from {type(data).__name__}; "
                "expected str, dict or RdzvEvent"
            )
        # Members are looked up by name; for NodeState the names equal the
        # string values that serialize() writes, so round-tripping works.
        data_dict["node_state"] = NodeState[data_dict["node_state"]]
        return RdzvEvent(**data_dict)

    def serialize(self) -> str:
        """Return the event as a JSON string of its dataclass fields."""
        return json.dumps(asdict(self))
115