# mypy: allow-untyped-defs
"""Lazy tracing of DataFrame operations for torch DataPipes.

Instead of executing DataFrame operations eagerly, the ``Capture`` class
hierarchy records each operation (attribute access, indexing, arithmetic,
calls) as a node appended to a shared trace context (``ctx``).  The recorded
trace can later be replayed against every DataFrame produced by a source
datapipe (see ``DataFrameTracedOps.__iter__`` / ``CaptureVariable.apply_ops``).
``DataFrameTracer`` is the user-facing entry point, registered as the
``trace_as_dataframe`` functional datapipe.
"""
from typing import Any, Dict, List, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe


# TODO(VitalyFedyunin): Add error when two different traces get combined

__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]


def disable_capture():
    """Globally disable trace capture (``CaptureVariable`` creation raises after this)."""
    CaptureControl.disabled = True


class CaptureControl:
    """Process-wide switch checked by ``CaptureVariable.__init__``."""

    # When True, attempting to create a capture variable raises RuntimeError.
    disabled = False


class DataFrameTracedOps(DFIterDataPipe):
    """DataPipe that replays a recorded operation trace on each source item.

    ``output_var`` is the final ``CaptureVariable`` of a trace; for every
    DataFrame yielded by ``source_datapipe`` the whole trace is re-executed
    via ``output_var.apply_ops`` and the resulting value is yielded.
    """

    def __init__(self, source_datapipe, output_var):
        self.source_datapipe = source_datapipe
        self.output_var = output_var

    def __iter__(self):
        for item in self.source_datapipe:
            yield self.output_var.apply_ops(item)


# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
# Operations that must be dispatched to the underlying datapipe rather than
# captured symbolically (see CaptureDataFrameWithDataPipeOps.__getattr__).
DATAPIPES_OPS = [
    "_dataframes_as_tuples",
    "groupby",
    "_dataframes_filter",
    "map",
    "to_datapipe",
    "shuffle",
    "concat",
    "batch",
    "_dataframes_per_row",
    "_dataframes_concat",
    "_dataframes_shuffle",
]

# Attribute names for which capture must fail fast with AttributeError instead
# of producing a bogus CaptureGetAttr node (e.g. copy/pickle protocol probes).
UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]


class Capture:
    """Base class for symbolic trace nodes.

    Every node shares a mutable ``ctx`` dict with keys:

    * ``"operations"`` -- ordered list of side-effecting nodes to replay,
    * ``"variables"`` -- all ``CaptureVariable`` nodes created in this trace
      (index 0 is the trace input; see ``apply_ops_2``),
    * ``"schema_df"`` -- sample DataFrame used to resolve ``columns`` eagerly
      (present on contexts created via ``__init__``/``CaptureInitial``).

    Dunder overloads (``__getattr__``, ``__getitem__``, arithmetic, calls)
    each append to / read from this shared context instead of computing.
    """

    # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

    def __init__(self, schema_df=None):
        self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        # Render the recorded operations, one per line, in replay order.
        res = ""
        for op in self.ctx["operations"]:
            if len(res) > 0:
                res += "\n"
            res += str(op)
        return res

    def __getstate__(self):
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        # NOTE(review): this mutates the live object (drops schema_df and all
        # cached values) as a side effect of pickling, not just the copy.
        self.ctx["schema_df"] = None
        for var in self.ctx["variables"]:
            var.calculated_value = None
        state = {}
        for item in self.__dict__:
            state[item] = getattr(self, item)
        return state

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        # Guard against silently capturing accesses to the internal kwargs
        # storage (would recurse / corrupt the trace).
        if attrname == "kwarg" or attrname == "kwargs":
            raise RuntimeError("no kwargs!")
        if attrname in ["__deepcopy__"]:
            raise AttributeError
        # Any other attribute access becomes a symbolic getattr node.
        result = CaptureGetAttr(self, attrname, ctx=self.ctx)
        return result

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        # Item assignment is a side effect, so it is recorded as an operation.
        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

    def __add__(self, add_val):
        # Arithmetic results are bound to a fresh variable so later nodes can
        # reference the computed value during replay.
        res = CaptureAdd(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __sub__(self, add_val):
        res = CaptureSub(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __mul__(self, add_val):
        res = CaptureMul(self, add_val, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        self.ctx["operations"].append(t)
        return var

    def _is_context_empty(self):
        return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0

    def apply_ops_2(self, dataframe):
        # Replay the trace with `dataframe` as the input (variables[0]).
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()

    @property
    def columns(self):
        # Resolved eagerly by replaying the trace against the schema DataFrame.
        self.apply_ops_2(self.ctx["schema_df"])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join contexts if one of them is empty because we used capture

    def __call__(self, *args, **kwargs):
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            # Adopt the context of the first non-empty Capture argument so the
            # call is recorded in the caller's trace.
            for arg in args:
                if isinstance(arg, Capture) and not arg._is_context_empty():
                    self.ctx = arg.ctx
                    break
            if self._is_context_empty():
                # NOTE(review): checking kwargs *keys* for Capture looks
                # intentional-but-dead (keys are normally str); values check
                # is the effective path.
                for k, v in kwargs.items():
                    if isinstance(k, Capture) and not k._is_context_empty():
                        self.ctx = k.ctx
                        break
                    if isinstance(v, Capture) and not v._is_context_empty():
                        self.ctx = v.ctx
                        break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx["operations"].append(t)
        return var


class CaptureF(Capture):
    """Capture node configured purely through keyword arguments (``self.kwargs``)."""

    def __init__(self, ctx=None, **kwargs):
        if ctx is None:
            # Fresh context; note it has no "schema_df" key, unlike Capture.__init__.
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs


class CaptureA(CaptureF):
    """Leaf node standing in for a mocked attribute (see ``CaptureLikeMock``).

    kwargs: ``name`` (display name), ``real_attribute`` (value returned on execute).
    """

    def __str__(self):
        return f"{self.kwargs['name']}"

    def execute(self):
        value = self.kwargs["real_attribute"]
        return value


class CaptureLikeMock:
    """Context manager that patches ``name`` with a ``CaptureA`` node, mock-style.

    On exit the original attribute is restored.
    """

    def __init__(self, name):
        import unittest.mock as mock

        # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
        get_target, attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.get_target = get_target
        self.attribute = attribute
        self.name = name

    def __enter__(self):
        self.save = getattr(self.get_target(), self.attribute)
        capt = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, capt)

    def __exit__(self, *exc_info):
        setattr(self.get_target(), self.attribute, self.save)


class CaptureCall(Capture):
    """Symbolic call of ``callable`` with recorded ``args``/``kwargs``."""

    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(
            callable=self.callable, **self.kwargs
        )

    def execute(self):
        # TODO: VitalyFedyunin execute kwargs and maybe nested structures
        # Positional args that are Capture nodes are resolved first; kwargs
        # are passed through unresolved (see TODO above).
        executed_args = []
        for arg in self.kwargs["args"]:
            if isinstance(arg, Capture):
                executed_args.append(arg.execute())
            else:
                executed_args.append(arg)
        left = get_val(self.callable)
        return left(*executed_args, **self.kwargs["kwargs"])


class CaptureVariableAssign(CaptureF):
    """Replay step that stores ``value.execute()`` into ``variable``."""

    def __str__(self):
        variable = self.kwargs["variable"]
        value = self.kwargs["value"]
        return f"{variable} = {value}"

    def execute(self):
        self.kwargs["variable"].calculated_value = self.kwargs["value"].execute()


class CaptureVariable(Capture):
    """Named slot in the trace; holds the concrete value during replay."""

    # TODO(VitalyFedyunin): This should be atomic and thread safe
    # Module-wide counter used to generate unique variable names.
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise RuntimeError("Attempting to create capture variable with capture off")
        self.ctx = ctx
        self.value = value
        self.name = f"var_{CaptureVariable.names_idx}"
        CaptureVariable.names_idx += 1
        self.ctx["variables"].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        return self.calculated_value

    def apply_ops(self, dataframe):
        """Replay the whole trace on ``dataframe`` and return this variable's value."""
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()
        return self.calculated_value


class CaptureGetItem(Capture):
    """Symbolic ``left[key]``."""

    def __init__(self, left, key, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}]"

    def execute(self):
        left = self.left.execute()
        return left[self.key]


class CaptureSetItem(Capture):
    """Symbolic ``left[key] = value`` (both sides resolved at replay time)."""

    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}] = {self.value}"

    def execute(self):
        left = self.left.execute()
        value = self.value.execute()
        left[self.key] = value


class CaptureAdd(Capture):
    """Symbolic ``left + right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} + {self.right}"

    def execute(self):
        return get_val(self.left) + get_val(self.right)


class CaptureMul(Capture):
    """Symbolic ``left * right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} * {self.right}"

    def execute(self):
        return get_val(self.left) * get_val(self.right)


class CaptureSub(Capture):
    """Symbolic ``left - right``."""

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} - {self.right}"

    def execute(self):
        return get_val(self.left) - get_val(self.right)


class CaptureGetAttr(Capture):
    """Symbolic ``src.name`` attribute access."""

    def __init__(self, src, name, ctx):
        self.ctx = ctx
        self.src = src
        self.name = name

    def __str__(self):
        return f"{self.src}.{self.name}"

    def execute(self):
        val = get_val(self.src)
        return getattr(val, self.name)


def get_val(capture):
    """Resolve ``capture``: execute Capture nodes, quote strings, pass through the rest.

    Strings are wrapped in double quotes; that is only meaningful for the
    ``__str__`` rendering paths that call this on keys.
    """
    if isinstance(capture, Capture):
        return capture.execute()
    elif isinstance(capture, str):
        return f'"{capture}"'
    else:
        return capture


class CaptureInitial(CaptureVariable):
    """Root variable of a fresh trace; its context seeds ``variables[0]``."""

    def __init__(self, schema_df=None):
        new_ctx: Dict[str, List[Any]] = {
            "operations": [],
            "variables": [],
            "schema_df": schema_df,
        }
        super().__init__(None, new_ctx)
        # Distinguish the trace input from derived variables in the rendering.
        self.name = f"input_{self.name}"


class CaptureDataFrame(CaptureInitial):
    pass


class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    """Adds datapipe-style operations on top of a captured DataFrame trace."""

    def as_datapipe(self):
        # variables[0] is the trace input; its source_datapipe is set by
        # DataFrameTracer.__init__.
        return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)

    def raw_iterator(self):
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        dp = self._dataframes_per_row()._dataframes_concat(batch_size)
        dp = dp.as_datapipe().batch(1, drop_last=drop_last, wrapper_class=wrapper_class)
        # Mark downstream pipe as carrying DataFrames rather than plain batches.
        dp._dp_contains_dataframe = True
        return dp

    def groupby(
        self,
        group_key_fn,
        *,
        buffer_size=10000,
        group_size=None,
        guaranteed_group_size=None,
        drop_remaining=False,
    ):
        dp = self._dataframes_per_row()
        dp = dp.as_datapipe().groupby(
            group_key_fn,
            buffer_size=buffer_size,
            group_size=group_size,
            guaranteed_group_size=guaranteed_group_size,
            drop_remaining=drop_remaining,
        )
        return dp

    def shuffle(self, *args, **kwargs):
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        raise RuntimeError("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):
        # Dispatch known datapipe operations to the real datapipe; everything
        # else falls back to symbolic capture via Capture.__getattr__.
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError("Attempting to get ", attrname)
        if attrname in DATAPIPES_OPS:
            return (self.as_datapipe()).__getattr__(attrname)
        return super().__getattr__(attrname)


@functional_datapipe("trace_as_dataframe")
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
    """Entry-point datapipe: starts a capture trace over ``source_datapipe``.

    If no ``schema_df`` is given, the first item of the source datapipe is
    consumed and used as the schema sample.
    """

    source_datapipe: Optional[Any] = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes

    def set_shuffle_settings(self, *args, **kwargs):
        pass

    def is_shardable(self):
        return False

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        if schema_df is None:
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)