xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/dataframe/dataframes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Any, Dict, List, Optional
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
6from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe
7
8
9# TODO(VitalyFedyunin): Add error when two different traces get combined
10
# Public names of this module (what ``from ... import *`` exposes).
__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]
34
35
def disable_capture():
    """Set ``CaptureControl.disabled`` to ``True``.

    While the flag is set, ``CaptureVariable.__init__`` raises ``RuntimeError``
    instead of creating a new capture variable.
    """
    setattr(CaptureControl, "disabled", True)
38
39
class CaptureControl:
    """Namespace holding the class-level flag that switches capturing off."""

    # Flipped to True by ``disable_capture()``; checked when a CaptureVariable is built.
    disabled: bool = False
42
43
class DataFrameTracedOps(DFIterDataPipe):
    """DataPipe yielding ``output_var.apply_ops(item)`` for every item of the source.

    Args:
        source_datapipe: iterable supplying the items fed into the trace.
        output_var: object exposing ``apply_ops``; its result is what gets yielded.
    """

    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        # The generator expression stays lazy: each source item is pulled and
        # transformed only when the consumer requests the next element.
        yield from (self.output_var.apply_ops(frame) for frame in self.source_datapipe)
52
53
# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
# Attribute names that CaptureDataFrameWithDataPipeOps.__getattr__ forwards to the
# datapipe returned by as_datapipe() instead of recording them in the trace.
DATAPIPES_OPS = [
    "_dataframes_as_tuples",
    "groupby",
    "_dataframes_filter",
    "map",
    "to_datapipe",
    "shuffle",
    "concat",
    "batch",
    "_dataframes_per_row",
    "_dataframes_concat",
    "_dataframes_shuffle",
]

# Attribute names for which CaptureDataFrameWithDataPipeOps.__getattr__ raises
# AttributeError rather than producing a capture node.
UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]
70
71
class Capture:
    """Base class for the nodes of a recorded (not yet executed) operation trace.

    Every node carries ``ctx``, a dict shared by all nodes of one trace:

    * ``ctx["operations"]`` - recorded statements, replayed in order by
      ``apply_ops_2``.
    * ``ctx["variables"]`` - the capture variables of the trace; the first one
      receives the input when the trace is replayed.
    * ``ctx["schema_df"]`` - sample object used by the ``columns`` property.

    Attribute access, indexing, item assignment, ``+``/``-``/``*`` and calls on
    a node do not run anything; they build further nodes and/or append
    statements to ``ctx["operations"]``.
    """

    # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

    def __init__(self, schema_df=None):
        self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        """Return the recorded operations rendered one per line."""
        return "\n".join(str(op) for op in self.ctx["operations"])

    def __getstate__(self):
        """Return a copy of the instance dict for pickling.

        Side effect: clears ``ctx["schema_df"]`` and the ``calculated_value`` of
        every variable of the trace before the state is taken.
        """
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx["schema_df"] = None
        for var in self.ctx["variables"]:
            var.calculated_value = None
        return dict(self.__dict__)

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        """Record access to an attribute that is not found by normal lookup.

        Returns:
            A ``CaptureGetAttr`` node for ``attrname``.

        Raises:
            AttributeError: for ``ctx`` and ``__deepcopy__``.
            RuntimeError: for ``kwarg`` / ``kwargs``.
        """
        if attrname == "ctx":
            # ``self.ctx`` is read below; on an instance that has no ``ctx`` yet
            # (e.g. built with ``__new__`` only) falling through would re-enter
            # this method without end.
            raise AttributeError(attrname)
        if attrname in ("kwarg", "kwargs"):
            raise RuntimeError("no kwargs!")
        if attrname == "__deepcopy__":
            raise AttributeError
        return CaptureGetAttr(self, attrname, ctx=self.ctx)

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

    def _record_binary_op(self, op_cls, other):
        """Record ``op_cls(self, other)`` as an assignment to a fresh variable and return that variable."""
        res = op_cls(self, other, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __add__(self, add_val):
        return self._record_binary_op(CaptureAdd, add_val)

    def __sub__(self, add_val):
        return self._record_binary_op(CaptureSub, add_val)

    def __mul__(self, add_val):
        return self._record_binary_op(CaptureMul, add_val)

    def _is_context_empty(self):
        """Return True when the trace has neither operations nor variables."""
        return not self.ctx["operations"] and not self.ctx["variables"]

    def apply_ops_2(self, dataframe):
        """Bind ``dataframe`` to the first variable of the trace and replay every operation."""
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()

    @property
    def columns(self):
        """Replay the trace on ``ctx["schema_df"]`` and return ``.columns`` of this node's value.

        Relies on ``execute()``, which this base class does not define itself.
        """
        self.apply_ops_2(self.ctx["schema_df"])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture

    def __call__(self, *args, **kwargs):
        """Record a call of this node and return the variable holding its result.

        If this node's own context is empty, the context of the first argument
        (positional first, then keyword values) that is a ``Capture`` with a
        non-empty context is adopted before recording.
        """
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            # Keyword names are always plain strings, so only the values can
            # carry a context.
            for candidate in (*args, *kwargs.values()):
                if isinstance(candidate, Capture) and not candidate._is_context_empty():
                    self.ctx = candidate.ctx
                    break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx["operations"].append(t)
        return var
180
181
class CaptureF(Capture):
    """Capture node whose payload is the set of keyword arguments it was built with.

    Args:
        ctx: trace context to share; a fresh one with empty ``operations`` and
            ``variables`` lists is created when omitted.
        **kwargs: stored unchanged on ``self.kwargs``.
    """

    def __init__(self, ctx=None, **kwargs):
        self.ctx = {"operations": [], "variables": []} if ctx is None else ctx
        self.kwargs = kwargs
189
190
class CaptureA(CaptureF):
    """Capture node standing in for a real attribute.

    Reads two entries of ``self.kwargs``: ``name`` (used for display) and
    ``real_attribute`` (the object returned on execution).
    """

    def __str__(self):
        return "{}".format(self.kwargs["name"])

    def execute(self):
        """Return the wrapped real attribute."""
        return self.kwargs["real_attribute"]
198
199
class CaptureLikeMock:
    """Context manager that swaps the attribute at dotted path ``name`` for a ``CaptureA``.

    On entry the current attribute value is saved and replaced; on exit the
    saved value is put back. ``__enter__`` returns ``None``.
    """

    def __init__(self, name):
        from unittest import mock

        # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
        self.get_target, self.attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.name = name

    def __enter__(self):
        # Keep the real attribute so __exit__ can restore it.
        self.save = getattr(self.get_target(), self.attribute)
        stand_in = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, stand_in)

    def __exit__(self, *exc_info):
        owner = self.get_target()
        setattr(owner, self.attribute, self.save)
217
218
class CaptureCall(Capture):
    """Recorded call of ``callable``.

    The positional and keyword arguments of the call are expected under
    ``self.kwargs["args"]`` and ``self.kwargs["kwargs"]``.

    Args:
        callable: the object to call when executed (resolved through ``get_val``).
        ctx: trace context to share; a fresh empty one is created when omitted.
        **kwargs: stored unchanged on ``self.kwargs``.
    """

    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(
            callable=self.callable, **self.kwargs
        )

    def execute(self):
        """Execute the recorded call and return its result.

        Capture instances found directly among the positional arguments or the
        keyword argument values are executed first, so the real callable only
        receives concrete values.
        """
        # TODO: VitalyFedyunin execute nested structures
        executed_args = [
            arg.execute() if isinstance(arg, Capture) else arg
            for arg in self.kwargs["args"]
        ]
        # Keyword values get the same treatment as positional ones; passing an
        # unexecuted Capture through would hand a trace node to the real callable.
        executed_kwargs = {
            key: value.execute() if isinstance(value, Capture) else value
            for key, value in self.kwargs["kwargs"].items()
        }
        left = get_val(self.callable)
        return left(*executed_args, **executed_kwargs)
243
244
class CaptureVariableAssign(CaptureF):
    """Recorded statement ``variable = value``.

    Reads ``self.kwargs["variable"]`` (the assignment target) and
    ``self.kwargs["value"]`` (a node exposing ``execute()``).
    """

    def __str__(self):
        return f"{self.kwargs['variable']} = {self.kwargs['value']}"

    def execute(self):
        """Execute the value node and store the result on the variable's ``calculated_value``."""
        outcome = self.kwargs["value"].execute()
        self.kwargs["variable"].calculated_value = outcome
253
254
class CaptureVariable(Capture):
    """Named slot of a trace whose value is filled in while the trace is replayed.

    Each instance gets the name ``var_<n>`` from a class-wide counter and
    registers itself in ``ctx["variables"]``.

    Raises:
        RuntimeError: from ``__init__`` when ``CaptureControl.disabled`` is set.
    """

    # TODO(VitalyFedyunin): This should be atomic and thread safe
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise RuntimeError("Attempting to create capture variable with capture off")
        self.ctx = ctx
        self.value = value
        # Take the next number of the shared counter for the display name.
        self.name = f"var_{CaptureVariable.names_idx}"
        CaptureVariable.names_idx += 1
        ctx["variables"].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        """Return the value most recently stored on ``calculated_value``."""
        return self.calculated_value

    def apply_ops(self, dataframe):
        """Replay the whole trace on ``dataframe`` and return this variable's value."""
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        trace_input = self.ctx["variables"][0]
        trace_input.calculated_value = dataframe
        for recorded in self.ctx["operations"]:
            recorded.execute()
        return self.calculated_value
280
281
class CaptureGetItem(Capture):
    """Recorded subscript ``left[key]``."""

    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        shown_key = get_val(self.key)
        return f"{self.left}[{shown_key}]"

    def execute(self):
        """Execute ``left`` and index the result with the stored key as-is."""
        return self.left.execute()[self.key]
294
295
class CaptureSetItem(Capture):
    """Recorded item assignment ``left[key] = value``."""

    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}] = {self.value}"

    def execute(self):
        """Execute ``left`` and store the value under ``key`` on the result.

        ``value`` is executed when it is a Capture node and used unchanged
        otherwise, so assigning a plain constant works too.
        """
        left = self.left.execute()
        # Only trace nodes have execute(); constants are assigned directly.
        value = self.value.execute() if isinstance(self.value, Capture) else self.value
        left[self.key] = value
310
311
class CaptureAdd(Capture):
    """Recorded addition ``left + right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} + {self.right}"

    def execute(self):
        """Return the sum of both operands, executing those that are Capture nodes."""
        # Operands are resolved here rather than through get_val(): its string
        # branch wraps text in double quotes for display, which would corrupt
        # a plain str operand.
        lhs = self.left.execute() if isinstance(self.left, Capture) else self.left
        rhs = self.right.execute() if isinstance(self.right, Capture) else self.right
        return lhs + rhs
323
324
class CaptureMul(Capture):
    """Recorded multiplication ``left * right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} * {self.right}"

    def execute(self):
        """Return the product of both operands, executing those that are Capture nodes."""
        # Operands are resolved here rather than through get_val(): its string
        # branch wraps text in double quotes for display, which would corrupt
        # a plain str operand.
        lhs = self.left.execute() if isinstance(self.left, Capture) else self.left
        rhs = self.right.execute() if isinstance(self.right, Capture) else self.right
        return lhs * rhs
336
337
class CaptureSub(Capture):
    """Recorded subtraction ``left - right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} - {self.right}"

    def execute(self):
        """Return the difference of both operands, executing those that are Capture nodes."""
        # Operands are resolved here rather than through get_val(): its string
        # branch wraps text in double quotes for display, which would alter
        # a plain str operand.
        lhs = self.left.execute() if isinstance(self.left, Capture) else self.left
        rhs = self.right.execute() if isinstance(self.right, Capture) else self.right
        return lhs - rhs
349
350
class CaptureGetAttr(Capture):
    """Recorded attribute access ``src.name``."""

    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return f"{self.src}.{self.name}"

    def execute(self):
        """Resolve ``src`` through ``get_val`` and fetch the named attribute from the result."""
        return getattr(get_val(self.src), self.name)
363
364
def get_val(capture):
    """Resolve ``capture`` to a value.

    * ``Capture`` instance: the result of its ``execute()``.
    * ``str``: the text wrapped in double quotes.
    * anything else: returned unchanged.
    """
    if isinstance(capture, Capture):
        return capture.execute()
    if isinstance(capture, str):
        return f'"{capture}"'
    return capture
372
373
class CaptureInitial(CaptureVariable):
    """First variable of a brand-new trace.

    Creates its own context (empty ``operations`` / ``variables`` plus the
    given ``schema_df``) and prefixes the generated variable name with
    ``input_``.
    """

    def __init__(self, schema_df=None):
        new_ctx: Dict[str, List[Any]] = dict(
            operations=[],
            variables=[],
            schema_df=schema_df,
        )
        super().__init__(None, new_ctx)
        # The parent constructor assigned "var_<n>"; mark it as the trace input.
        self.name = f"input_{self.name}"
383
384
class CaptureDataFrame(CaptureInitial):
    """Subclass of ``CaptureInitial`` that adds no behavior of its own."""

    pass
387
388
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    """Trace root that can also be handed over to datapipe operations.

    Attribute names listed in ``DATAPIPES_OPS`` are looked up on the datapipe
    returned by ``as_datapipe()``; names in ``UNIMPLEMENTED_ATTR`` raise
    ``AttributeError``; every other unknown attribute is recorded in the trace
    by the parent class.
    """

    def as_datapipe(self):
        """Wrap this trace in a ``DataFrameTracedOps`` fed by the first variable's ``source_datapipe``."""
        return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)

    def raw_iterator(self):
        """Return the iterator of ``as_datapipe()``."""
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        """Chain ``_dataframes_per_row`` and ``_dataframes_concat(batch_size)``, then batch by 1.

        The resulting datapipe gets ``_dp_contains_dataframe = True`` before it
        is returned.
        """
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        dp._dp_contains_dataframe = True
        return dp

    def groupby(
        self,
        group_key_fn,
        *,
        buffer_size=10000,
        group_size=None,
        guaranteed_group_size=None,
        drop_remaining=False,
    ):
        """Apply ``_dataframes_per_row`` and forward all arguments to the datapipe's ``groupby``."""
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(
            group_key_fn,
            buffer_size=buffer_size,
            group_size=group_size,
            guaranteed_group_size=guaranteed_group_size,
            drop_remaining=drop_remaining,
        )
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        """Always raise ``RuntimeError``."""
        raise RuntimeError("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):
        """Route lookups of attributes that normal lookup did not find.

        Raises:
            AttributeError: when ``attrname`` is listed in ``UNIMPLEMENTED_ATTR``.
        """
        if attrname in UNIMPLEMENTED_ATTR:
            # One formatted message; passing the name as a second positional
            # argument would make the exception text render as a tuple.
            raise AttributeError(f"Attempting to get {attrname}")
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)
439
440
@functional_datapipe("trace_as_dataframe")
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
    """Trace root bound to a source datapipe, registered as ``trace_as_dataframe``.

    Args:
        source_datapipe: datapipe whose items the trace is applied to.
        schema_df: sample object for the trace; when omitted, the first item
            produced by iterating ``source_datapipe`` is used.
    """

    source_datapipe: Optional[Any] = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        # NOTE: next() has no default here, so a source that yields nothing
        # lets StopIteration escape from the constructor.
        schema = (
            schema_df if schema_df is not None else next(iter(self.source_datapipe))
        )
        super().__init__(schema_df=schema)

    def is_shardable(self):
        return False

    def set_shuffle_settings(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None
458