# mypy: allow-untyped-defs
import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf

# This module builds the string representation of tensors (entry point: `_str`
# at the bottom of the file) and holds the global options that control it
# (`set_printoptions`, `get_printoptions` and the `printoptions` context
# manager). All formatting code reads the module-level `PRINT_OPTS`.


@dataclasses.dataclass
class __PrinterOptions:
    # Holder for the global print settings. The single instance `PRINT_OPTS`
    # below is mutated by set_printoptions() and read by the formatting code.

    # Digits after the decimal point for floating point output (used by both
    # the fixed ".{precision}f" and scientific ".{precision}e" formats).
    precision: int = 4
    # Output is summarized when a tensor's numel() exceeds this value.
    threshold: float = 1000
    # Number of leading/trailing entries kept per dimension when summarizing.
    edgeitems: int = 3
    # Target line width used to wrap vector rows and trailing suffixes.
    linewidth: int = 80
    # True/False forces scientific notation on/off; None lets _Formatter decide.
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Options are applied in two steps. First, ``profile`` (if given) resets
    ``precision``, ``threshold``, ``edgeitems`` and ``linewidth`` to that
    profile's values. Then every other argument that is not ``None`` overrides
    the corresponding option. A ``profile`` string other than ``default``,
    ``short`` or ``full`` is silently ignored.

    ``sci_mode`` is always assigned, so a call that omits it resets it to
    ``None`` (automatic selection).

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        # Each profile sets all four numeric options; sci_mode is handled below.
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    # Explicit arguments take precedence over the profile chosen above.
    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    # Unlike the options above, sci_mode is assigned unconditionally, so
    # omitting it resets any previous setting back to None.
    PRINT_OPTS.sci_mode = sci_mode


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().

    The keys are the fields of ``PRINT_OPTS``: ``precision``, ``threshold``,
    ``edgeitems``, ``linewidth`` and ``sci_mode``. ``profile`` is not stored.
    The returned dict is a fresh copy, so mutating it does not change the
    options.
    """
    return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options. Accepted
    arguments are same as :func:`set_printoptions`.

    The options in effect on entry (as returned by :func:`get_printoptions`)
    are restored on exit, including when the body raises.
    """
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)


def tensor_totype(t):
    """Cast ``t`` to a wide floating dtype for the formatter's range checks.

    Uses ``torch.double``, except when ``t`` lives on MPS or on an XPU device
    whose properties report no fp64 support. Those cases fall back to
    ``torch.float``.
    """
    dtype = (
        torch.float
        if (
            t.is_mps
            or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64)
        )
        else torch.double
    )
    return t.to(dtype=dtype)


class _Formatter:
    """Chooses one shared notation and column width for a tensor's elements.

    Built once from the (possibly summarized) data of a tensor, then used to
    format each element via :meth:`format`, so all printed values line up.
    """

    def __init__(self, tensor):
        # floating_dtype: whether values are floats (selects the float paths
        #   in format()).
        # int_mode: for floats, stays True only if every nonzero finite value
        #   is a whole number; such values are printed like "3.".
        # sci_mode: whether to use "e" notation.
        # max_width: widest formatted element seen so far (the column width).
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            # Non-floating values: width is the longest str() of any value.
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            # Only nonzero finite values drive the mode/width decisions;
            # zeros, infs and NaNs are excluded from the statistics.
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                # NOTE(review): this early return also skips the
                # PRINT_OPTS.sci_mode override at the end of __init__, so a
                # forced sci_mode has no effect when every value is zero or
                # non-finite.
                return

            # Convert to double for easy calculation. HalfTensor overflows with
            # 1e8, and there's no div() on CPU. (tensor_totype picks float
            # instead of double on MPS and on XPU devices without fp64.)
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            # Any non-integral value disables int_mode.
            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    # Wide dynamic range or very large values: use e-notation.
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used (wide
                # dynamic range, very large values, or very small values).
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        # An explicit user setting overrides the heuristic above. Note that
        # max_width is not recomputed for the overriding notation.
        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        """Return the column width every formatted element is padded to."""
        return self.max_width

    def format(self, value):
        """Format the Python scalar ``value`` and left-pad it to ``max_width``.

        Floats use scientific, integer-style (``"3."``) or fixed-point notation,
        according to the mode chosen in ``__init__``. Other dtypes use plain
        ``str``-style formatting.
        """
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                # Whole numbers keep a trailing "." to show the dtype is
                # floating; inf and nan do not get one.
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret


def _scalar_str(self, formatter1, formatter2=None):
    """Format a 0-dim tensor ``self``.

    When ``formatter2`` is given, the value is treated as complex. The real part
    is formatted with ``formatter1`` and the imaginary part with ``formatter2``
    plus a trailing ``"j"``. The two are joined using the imaginary part's own
    sign, or ``"+"`` if it has none.
    """
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    """Format the 1-D tensor ``self`` as a bracketed, comma-separated list.

    Each line holds as many elements as fit in
    ``PRINT_OPTS.linewidth - indent`` (at least one). Continuation lines are
    indented by ``indent + 1`` spaces.

    When ``summarize`` is true and the vector is longer than
    ``2 * PRINT_OPTS.edgeitems``, only the first and last ``edgeitems``
    entries are shown, around an ellipsis.
    """
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    # Formats one Python scalar (complex when formatter2 is given). The
    # formatters are bound as default arguments.
    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # With edgeitems == 0, self[-0:] below would be the whole vector
        # (-0 == 0), so show only the ellipsis.
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    # Chunk the formatted elements into rows of elements_per_line.
    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    """Recursively format ``self`` using precomputed formatters.

    0-dim tensors go to ``_scalar_str`` and 1-dim tensors to ``_vector_str``.
    Higher-dim tensors format each slice along dim 0 with ``indent + 1``, then
    join the slices with a comma, ``dim - 1`` newlines and indentation. When
    summarizing, the middle slices are replaced by ``"..."``.
    """
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    # Deeper tensors are separated by more blank lines (dim - 1 newlines).
    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    """Return the bracketed data part of a tensor's repr (no prefix or suffixes).

    First normalizes ``self`` so the formatters can handle it:
    - drops dimension names and clones ZeroTensors;
    - resolves the negative and conjugate bits;
    - upcasts 16-bit and 8-bit float dtypes with ``float()`` and complex32
      with ``cfloat()``.

    Builds the formatter(s) from ``get_summarized_data`` when ``numel()``
    exceeds ``PRINT_OPTS.threshold``, then delegates to
    ``_tensor_str_with_formatter``.
    """
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        # Real and imaginary parts get independent widths/notations.
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    """Append ``", <suffix>"`` entries and the closing ``")"`` to ``tensor_str``.

    A suffix goes on a new line (indented by ``indent``) when adding it would
    push the current last line past ``PRINT_OPTS.linewidth``. ``force_newline``
    forces only the first suffix onto a new line.
    """
    tensor_strs = [tensor_str]
    # NOTE(review): rfind() returns the index of the last newline (or -1), so
    # this is 2 more than the true length of the last line, making the wrap
    # check slightly conservative. Left as-is because repr output depends on it.
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    """Return only the elements a summarized print of ``self`` would display.

    Keeps the first and last ``PRINT_OPTS.edgeitems`` entries along each
    dimension, recursively. The result is used to size the formatter(s) in
    ``_tensor_str``.
    """
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        # NOTE(review): with edgeitems == 0, self[-0:] is the whole vector, so
        # this returns all of `self` (concatenated with an empty slice).
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        # No elements are kept: return an empty tensor of the same rank.
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])


def _str_intern(inp, *, tensor_contents=None):
    """Build the full repr string of the tensor ``inp``.

    The result is ``prefix + contents``, then comma-separated suffixes, then a
    closing parenthesis. The suffixes cover device, size, nnz, dtype, layout,
    quantization parameters, grad_fn / requires_grad, names and tangent.

    The prefix is ``tensor(``, ``nested_tensor(``, ``_to_functional_tensor(``
    or ``<SubclassName>(``. Each kind of tensor has its own contents/suffix
    logic: sparse COO, compressed sparse, quantized, nested, functional,
    meta/fake, and regular.

    Args:
        inp: the tensor to print. functorch-wrapped tensors are delegated to
            ``_functorch_wrapper_str_intern``.
        tensor_contents: if given, used as the contents instead of formatting
            the tensor's data; suffixes are still computed. Functional tensors
            are the exception: their contents are always the repr of the
            unwrapped tensor.
    """
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    # dtypes for which the "dtype=" suffix is omitted wherever has_default_dtype
    # is checked. The meta/fake and empty-tensor paths below compare against
    # torch.get_default_dtype() only.
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        # Sparse COO: contents are the indices and values tensors.
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        # Meta/fake tensors have no data to print: show "..." plus shapes.
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if is_meta or indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if is_meta or values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        # Compressed sparse layouts: contents are the compressed indices,
        # plain indices and values tensors.
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            # Labels become e.g. "crow_indices=tensor(" / "col_indices=tensor(".
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        # Quantized: suffixes carry the quantization parameters; contents are
        # the dequantized values.
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        # Nested: the str() of each component (unbound along dim 0), each
        # line prefixed, separated by ",\n" inside "[...]".
        if not custom_contents_provided:

            # NOTE(review): the `indent` argument is unused; every line gets
            # the same fixed prefix.
            def indented_str(s, indent):
                return "\n".join(f" {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        # Functional wrapper: print the repr of the underlying tensor. This
        # overrides any tensor_contents that were provided.
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                # With edgeitems == 0, summarized output shows no data, so the
                # shape is always printed.
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        # For C++-implemented functions, show the last "::"-separated component
        # of grad_fn.name().
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    # grad_fn is shown in preference to requires_grad=True.
    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    # Sparse COO tensors always start their suffixes on a new line.
    string_repr = _add_suffixes(
        prefix + tensor_str,  # type: ignore[possibly-undefined]
        suffixes,
        indent,
        force_newline=self.is_sparse,
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    """Build the repr of a functorch-wrapped tensor.

    Prints the wrapper kind (BatchedTensor, GradTrackingTensor or
    FunctionalTensor), its level, the batch dim (for BatchedTensor), and the
    repr of the unwrapped value. ``tensor_contents`` is accepted but not used
    here.

    Asserts that the level (and, for batched tensors, bdim) is not -1.

    Raises:
        ValueError: if the wrapper is none of the kinds above.
    """
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    # Batched/grad-tracking wrappers print the inner repr indented 4 spaces.
    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    """Return the string representation of the tensor ``self``.

    Runs ``_str_intern`` with autograd disabled, with the current Python
    dispatch modes disabled, and while a ``torch._C._DisableFuncTorch()``
    object is held.
    """
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        # `guard` is never read; it is kept in a local so the object stays alive
        # until _str_intern returns (presumably it disables functorch while
        # alive - confirm against the C++ binding).
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)