1# mypy: allow-untyped-defs
2from collections import defaultdict
3from copy import deepcopy
4from typing import Any, Dict, Optional, TYPE_CHECKING
5
6import pytorch_lightning as pl  # type: ignore[import]
7
8from ._data_sparstity_utils import (
9    _attach_model_to_data_sparsifier,
10    _get_valid_name,
11    _log_sparsified_level,
12)
13
14
15if TYPE_CHECKING:
16    import torch
17
18
19class PostTrainingDataSparsity(pl.callbacks.Callback):
20    """Lightning callback that enables post-training sparsity.
21
22    This callback aims to sparsify the model inside lightning module after training.
23    **Note that the model is copied and then sparsified, so the existing model is not modified**
24
25    The sparsified model can be used for comparison and can be accessed using
26        <callback_obj>.sparsified
27
28    Args:
29        data_sparsifier_class (some implemented class of BaseDataSparsifier)
30            The data sparsifier object of this class is created when the
31            training starts.
32            Note: Objects should not be passed in here as they are created
33            once the training completes.
34
35        data_sparsifier_args (Dict)
36            Dictionary of args to be passed to the data sparsifier.
37            Note: data_list arg should be ignored
38
39    Hooks implemented:
40        on_fit_end()
41            1. copies the model and attaches it to the sparsifier
42            2. sparsier step() is called
43            3. squashes the mask()
44    """
45
46    def __init__(self, data_sparsifier_class, data_sparsifier_args):
47        super().__init__()
48        self.data_sparsifier_class = data_sparsifier_class
49        self.data_sparsifier_args = data_sparsifier_args
50        self.data_sparsifier: Any = None
51        self.sparsified: Optional[torch.nn.Module] = None
52
53    def on_fit_end(self, trainer, pl_module) -> None:
54        self.sparsified = deepcopy(pl_module.model).eval()
55        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)
56
57        _attach_model_to_data_sparsifier(self.sparsified, self.data_sparsifier)
58
59        self.data_sparsifier.step()
60
61        self.data_sparsifier.squash_mask()  # currently squashes params for all mask
62
63        _log_sparsified_level(self.sparsified, self.data_sparsifier)
64
65
66class TrainingAwareDataSparsity(pl.callbacks.Callback):
67    """Lightning callback that enables in-training sparsity.
68
69    This callback aims to sparsify the model inside lightning module during training.
70    **Note that the model is copied and then sparsified, so the existing model is not modified**
71
72    The sparsified model can be used for comparison and can be accessed using
73        <callback_obj>.sparsified
74
75    Args:
76        data_sparsifier_class (some implemented class of BaseDataSparsifier)
77            The data sparsifier object of this class is created when the
78            training starts.
79            Note: Objects should not be passed in here as they are created
80            when the training starts.
81
82        data_sparsifier_args (Dict)
83            Dictionary of args to be passed to the data sparsifier.
84            Note: data_list arg should be ignored
85
86        data_scheduler_class (some implemented class of BaseDataScheduler)
87            The data scheduler of this class is created when the training starts
88            Note: Objects should not be passed in here as they are created
89            when the training starts.
90
91        data_scheduler_args(Dict)
92            Dictionary of args to be passed to the data scheduler.
93            **Note: data_sparsifier arg should be ignored as the recipe
94            creates and pass sparsifier object into the class**
95
96    Hooks implemented:
97        on_train_start()
98            Data sparsifier and scheduler objects are created.
99            Pytorch model attached to the sparsifier
100
101        on_train_epoch_start()
102            Loads the state_dict of the data sparsifier
103
104        on_train_epoch_end()
105            1. Copies the model and attaches it to the sparsifier
106            2. sparsifier step() and scheduler step()
107            3. Dump state_dict of the current sparsifier
108
109        on_train_end()
110            squash mask
111    """
112
113    def __init__(
114        self,
115        data_sparsifier_class,
116        data_sparsifier_args,
117        data_scheduler_class,
118        data_scheduler_args,
119    ):
120        super().__init__()
121        # data sparsifier objects
122        self.data_sparsifier_class = data_sparsifier_class
123        self.data_sparsifier_args = data_sparsifier_args
124
125        # scheduler objects
126        self.data_scheduler_class = data_scheduler_class
127        self.data_scheduler_args = data_scheduler_args
128
129        # fields
130        self.data_sparsifier: Any = None
131        self.data_scheduler: Any = None
132        self.sparsified: Optional[torch.nn.Module] = None
133
134        self.data_sparsifier_state_dict: Any = None
135
136    def on_train_start(self, trainer, pl_module) -> None:
137        # create sparsifier
138        self.data_sparsifier = self.data_sparsifier_class(**self.data_sparsifier_args)
139        self.sparsified = deepcopy(pl_module.model)
140
141        _attach_model_to_data_sparsifier(
142            self.sparsified, self.data_sparsifier
143        )  # just to populate the base_sl in the scheduler
144
145        # create scheduler
146        args = deepcopy(self.data_scheduler_args)
147        args["data_sparsifier"] = self.data_sparsifier
148        self.data_scheduler = self.data_scheduler_class(**args)
149
150    def on_train_epoch_start(self, trainer, pl_module):
151        if self.data_sparsifier_state_dict is None:
152            return  # probably first epoch
153
154        # load the existing config for each data
155        self.data_sparsifier.load_state_dict(self.data_sparsifier_state_dict)
156
157    def __create_config_based_on_state(self, pl_module):
158        config: Dict = defaultdict()
159        if self.data_sparsifier_state_dict is None:
160            return config
161        for name, _ in pl_module.model.named_parameters():
162            valid_name = _get_valid_name(name)
163            config[valid_name] = self.data_sparsifier.data_groups[valid_name]
164
165        return config
166
167    def on_train_epoch_end(self, trainer, pl_module):
168        self.sparsified = deepcopy(pl_module.model)
169        config = self.__create_config_based_on_state(pl_module)
170
171        # attach model to the data sparsifier
172        _attach_model_to_data_sparsifier(
173            self.sparsified, self.data_sparsifier, config=config
174        )
175        self.data_sparsifier.step()
176        self.data_scheduler.step()
177
178        self.data_sparsifier_state_dict = self.data_sparsifier.state_dict()
179
180    def on_train_end(self, trainer, pl_module):
181        self.data_sparsifier.squash_mask()
182