xref: /aosp_15_r20/external/pytorch/torch/_tensor_str.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import dataclasses
4import math
5import textwrap
6from typing import Any, Dict, Optional
7
8import torch
9from torch import inf
10
11
@dataclasses.dataclass
class __PrinterOptions:
    """Mutable container for the global tensor print settings.

    The default values are the same ones the ``"default"`` profile of
    ``set_printoptions`` restores.
    """

    # Digits after the decimal point for floating-point values.
    precision: int = 4
    # Element count above which the printed data is summarized with "...".
    threshold: float = 1_000
    # Items shown at each end of a dimension when summarizing.
    edgeitems: int = 3
    # Target number of characters per line when wrapping output.
    linewidth: int = 80
    # True / False force scientific notation on / off; None picks it automatically.
    sci_mode: Optional[bool] = None


# Process-wide print options read by the formatting helpers in this module.
PRINT_OPTS = __PrinterOptions()
22
23
# Explicit keyword parameters (rather than **kwargs) give better generated docs.
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set the global options used when printing tensors (modeled on NumPy).

    Args:
        precision: Number of digits after the decimal point for floating
            point output (default = 4).
        threshold: Total number of tensor elements above which the output is
            summarized instead of printed in full (default = 1000).
        edgeitems: Number of items shown at the beginning and end of each
            dimension in a summary (default = 3).
        linewidth: Number of characters per line used when inserting line
            breaks (default = 80).
        profile: Preset applied first; any of the individual options above
            that are given still override it. One of ``"default"``,
            ``"short"`` or ``"full"`` (``"full"`` disables summarization).
            Any other value is ignored.
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the notation is chosen automatically
            by `torch._tensor_str._Formatter`. Every call assigns this option,
            so omitting it re-enables automatic selection.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    # (name, (precision, threshold, edgeitems, linewidth)) for each preset.
    # Matched with == in order, so unrecognized values are simply ignored.
    presets = (
        ("default", (4, 1000, 3, 80)),
        ("short", (2, 1000, 2, 80)),
        ("full", (4, inf, 3, 80)),
    )
    if profile is not None:
        for preset_name, (prec, thresh, edge, width) in presets:
            if profile == preset_name:
                PRINT_OPTS.precision = prec
                PRINT_OPTS.threshold = thresh
                PRINT_OPTS.edgeitems = edge
                PRINT_OPTS.linewidth = width
                break

    # Explicit arguments override the preset; None means "leave unchanged".
    overrides = (
        ("precision", precision),
        ("threshold", threshold),
        ("edgeitems", edgeitems),
        ("linewidth", linewidth),
    )
    for option_name, option_value in overrides:
        if option_value is not None:
            setattr(PRINT_OPTS, option_name, option_value)

    # Unlike the options above, sci_mode is always written: None means "auto".
    PRINT_OPTS.sci_mode = sci_mode
96
97
def get_printoptions() -> Dict[str, Any]:
    r"""Return a snapshot of the current print options.

    The result is a plain ``dict`` keyed by option name (``precision``,
    ``threshold``, ``edgeitems``, ``linewidth``, ``sci_mode``) that can be
    passed back as ``**kwargs`` to :func:`set_printoptions`.
    """
    snapshot = dataclasses.asdict(PRINT_OPTS)
    return snapshot
103
104
@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options.

    Accepted arguments are the same as :func:`set_printoptions`. The options
    in effect before entering are restored on exit, including when the body
    raises.
    """
    previous_options = get_printoptions()
    set_printoptions(**kwargs)
    # Register the restore only after the new options were applied, so a
    # failing set_printoptions call leaves nothing to undo.
    with contextlib.ExitStack() as restore_stack:
        restore_stack.callback(set_printoptions, **previous_options)
        yield
115
116
def tensor_totype(t):
    """Return ``t`` converted to a wide floating dtype for numeric checks.

    The target is ``torch.double``, except on MPS tensors and on XPU devices
    whose properties report no fp64 support, where ``torch.float`` is used.
    """
    single_precision_only = t.is_mps
    if not single_precision_only and t.is_xpu:
        single_precision_only = not torch.xpu.get_device_properties(
            t.device
        ).has_fp64
    target_dtype = torch.float if single_precision_only else torch.double
    return t.to(dtype=target_dtype)
127
128
class _Formatter:
    """Choose a textual format for a tensor's elements and a common width.

    The constructor inspects ``tensor`` once (the callers in this module pass
    only the summarized subset for large tensors) and records:

    Attributes:
        floating_dtype: True when ``tensor`` has a floating-point dtype.
        int_mode: for floating tensors, True when every nonzero finite value
            is integral; such values print as ``"3."``.
        sci_mode: True when values print in scientific notation. Chosen from
            the range of the nonzero finite values, unless
            ``PRINT_OPTS.sci_mode`` is not None, in which case that wins.
        max_width: length of the widest formatted nonzero finite value (or of
            the widest element for non-floating dtypes). ``format`` left-pads
            every element to this width so columns line up.
    """

    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        # Floating values whose width is measured once the notation is final.
        width_samples = None

        if not self.floating_dtype:
            # Non-floating values print verbatim; one .tolist() call avoids a
            # device round-trip per element.
            for value in tensor_view.tolist():
                self.max_width = max(self.max_width, len(f"{value}"))
        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )
            # With no nonzero finite values there is nothing to measure and
            # the defaults (int_mode=True, sci_mode=False) stand.
            if nonzero_finite_vals.numel() > 0:
                self._choose_float_mode(nonzero_finite_vals)
                width_samples = nonzero_finite_vals

        # An explicit user setting overrides the automatic choice. It must be
        # applied before widths are measured so that max_width reflects the
        # notation format() will actually use; otherwise forced sci_mode
        # misaligns columns and breaks the elements-per-line computation.
        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

        if width_samples is not None:
            for value in width_samples.tolist():
                self.max_width = max(self.max_width, len(self._render(value)))

    def _choose_float_mode(self, nonzero_finite_vals):
        """Set ``int_mode`` and the automatic ``sci_mode`` from the values."""
        # Convert to double for easy calculation. HalfTensor overflows with 1e8,
        # and there's no div() on CPU.
        nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
        nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
        nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

        # All-integral values are printed without a fractional part ("3.").
        self.int_mode = torch.equal(
            nonzero_finite_vals, torch.ceil(nonzero_finite_vals)
        )

        # Use scientific notation for a wide dynamic range or very large /
        # very small magnitudes. A nonzero integral value has magnitude >= 1,
        # so the small-magnitude test can only trigger outside int_mode.
        self.sci_mode = bool(
            nonzero_finite_max / nonzero_finite_min > 1000.0
            or nonzero_finite_max > 1.0e8
            or nonzero_finite_min < 1.0e-4
        )

    def _render(self, value):
        """Format ``value`` in the chosen notation, without alignment padding."""
        if not self.floating_dtype:
            return f"{value}"
        if self.sci_mode:
            return f"{{:.{PRINT_OPTS.precision}e}}".format(value)
        if self.int_mode:
            ret = f"{value:.0f}"
            # The trailing "." marks the floating dtype; inf/nan print bare.
            if not (math.isinf(value) or math.isnan(value)):
                ret += "."
            return ret
        return f"{{:.{PRINT_OPTS.precision}f}}".format(value)

    def width(self):
        return self.max_width

    def format(self, value):
        """Return ``value`` formatted and left-padded with spaces to max_width."""
        ret = self._render(value)
        return (self.max_width - len(ret)) * " " + ret
213
214
def _scalar_str(self, formatter1, formatter2=None):
    """Format a 0-d tensor.

    For real tensors only ``formatter1`` is used. For complex tensors
    ``formatter1`` formats the real part and ``formatter2`` the imaginary
    part, and the result has the form ``<real>[+]<imag>j``.
    """
    if formatter2 is None:
        return formatter1.format(self.item())

    real_text = _scalar_str(self.real, formatter1)
    imag_text = (_scalar_str(self.imag, formatter2) + "j").lstrip()
    # A leading sign (negative numbers, +0.0, -0.0) already separates the
    # two parts; otherwise an explicit "+" is inserted.
    if imag_text.startswith(("+", "-")):
        return real_text + imag_text
    return real_text + "+" + imag_text
226
227
def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    """Render a 1-D tensor as a bracketed, line-wrapped list of elements.

    Args:
        self: the 1-D tensor to print.
        indent: column of the opening bracket; wrapped lines are indented by
            ``indent + 1`` spaces.
        summarize: when true and the tensor is longer than
            ``2 * PRINT_OPTS.edgeitems``, only the leading and trailing
            ``edgeitems`` elements are shown around an ellipsis.
        formatter1: formatter for real values (or the real part).
        formatter2: formatter for the imaginary part; None for real tensors.
    """

    def render(val):
        if formatter2 is None:
            return formatter1.format(val)
        real_part = formatter1.format(val.real)
        imag_part = (formatter2.format(val.imag) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0: a leading sign already
        # separates the two parts, otherwise insert "+".
        joiner = "" if imag_part[0] in ("+", "-") else "+"
        return real_part + joiner + imag_part

    # Each element occupies its formatted width plus the ", " separator;
    # complex elements also hold the imaginary part and a trailing "j".
    cell_width = formatter1.width() + 2
    if formatter2 is not None:
        cell_width += formatter2.width() + 1
    per_line = max(1, int(math.floor((PRINT_OPTS.linewidth - indent) / cell_width)))

    edge = PRINT_OPTS.edgeitems
    if summarize and not edge:
        # With zero edge items, self[-0:] would select the whole tensor, so
        # only the ellipsis is emitted.
        cells = ["..."]
    elif summarize and self.size(0) > 2 * edge:
        head = [render(v) for v in self[:edge].tolist()]
        tail = [render(v) for v in self[-edge:].tolist()]
        cells = head + [" ..."] + tail
    else:
        cells = [render(v) for v in self.tolist()]

    rows = [
        ", ".join(cells[start : start + per_line])
        for start in range(0, len(cells), per_line)
    ]
    row_break = ",\n" + " " * (indent + 1)
    return "[" + row_break.join(rows) + "]"
268
269
# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively.
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    """Recursively render ``self`` one leading dimension at a time.

    0-d and 1-d tensors are delegated to ``_scalar_str`` / ``_vector_str``.
    Higher-rank tensors render each ``self[i]`` with one more column of
    indent; when summarizing and ``size(0) > 2 * PRINT_OPTS.edgeitems`` only
    the first and last ``edgeitems`` sub-tensors are shown around ``...``.
    """
    ndim = self.dim()
    if ndim == 0:
        return _scalar_str(self, formatter1, formatter2)
    if ndim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    def render_slice(i):
        return _tensor_str_with_formatter(
            self[i], indent + 1, summarize, formatter1, formatter2
        )

    n_rows = self.size(0)
    edge = PRINT_OPTS.edgeitems
    if summarize and n_rows > 2 * edge:
        pieces = [render_slice(i) for i in range(edge)]
        pieces.append("...")
        pieces.extend(render_slice(i) for i in range(n_rows - edge, n_rows))
    else:
        pieces = [render_slice(i) for i in range(n_rows)]

    # Sub-tensors are separated by a comma, (ndim - 1) newlines and the
    # indent, so higher-rank slices are set apart by blank lines.
    separator = "," + "\n" * (ndim - 1) + " " * (indent + 1)
    return "[" + separator.join(pieces) + "]"
308
309
def _tensor_str(self, indent):
    """Format the element data of a tensor as nested brackets.

    Args:
        self: tensor whose values are printed.
        indent: column of the opening bracket; nested rows align to it.

    Returns:
        ``"[]"`` for an empty tensor; otherwise the bracketed values,
        summarized with ``...`` when ``numel()`` exceeds
        ``PRINT_OPTS.threshold``.
    """
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing
        # goes through: data that fits comfortably on screen, and data that
        # needs to be summarized. Some of them don't fully support named
        # tensors, so the values are formatted with the names dropped.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    # Zero tensors are materialized with clone() before their values are read.
    if self._is_zerotensor():
        self = self.clone()

    # Resolve the negative bit so values are read with their true sign.
    if self.is_neg():
        self = self.resolve_neg()

    # Reduced-precision floating types are upcast to float32 for formatting.
    reduced_precision_floats = (
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    )
    if self.dtype in reduced_precision_floats:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    def formatter_input(t):
        # When summarizing, only the elements that will be shown are scanned.
        return get_summarized_data(t) if summarize else t

    if not self.dtype.is_complex:
        formatter = _Formatter(formatter_input(self))
        return _tensor_str_with_formatter(self, indent, summarize, formatter)

    # Resolve the conjugate bit, then give the real and imaginary parts their
    # own formatters (independent widths and notation).
    self = self.resolve_conj()
    real_formatter = _Formatter(formatter_input(self.real))
    imag_formatter = _Formatter(formatter_input(self.imag))
    return _tensor_str_with_formatter(
        self, indent, summarize, real_formatter, imag_formatter
    )
359
360
def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    """Append ``", "``-separated ``suffixes`` and the closing ``")"``.

    A suffix starts a new line, indented by ``indent`` spaces, when
    ``force_newline`` is set (this applies to the first suffix only) or when
    appending it would push the current line past ``PRINT_OPTS.linewidth``.
    """
    pieces = [tensor_str]
    # NOTE(review): this starting estimate is 2 more than the number of
    # characters after the last newline (len(tensor_str) + 2 when there is
    # none). Kept as-is because it decides where existing reprs wrap.
    line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    must_wrap = force_newline
    for suffix in suffixes:
        suffix_width = len(suffix)
        if must_wrap or line_len + suffix_width + 2 > PRINT_OPTS.linewidth:
            pieces.append(",\n" + " " * indent + suffix)
            line_len = indent + suffix_width
            must_wrap = False
            continue
        pieces.append(", " + suffix)
        line_len += suffix_width + 2
    return "".join(pieces) + ")"
375
376
def get_summarized_data(self):
    """Return only the elements a summarized print will display.

    Along every dimension longer than ``2 * PRINT_OPTS.edgeitems`` only the
    first and last ``edgeitems`` entries are kept. When ``edgeitems`` is 0
    nothing but ``...`` is printed, so an empty tensor of the same rank is
    returned. The result is used to size the element formatter, so keeping it
    small avoids scanning data that is never shown.
    """
    dim = self.dim()
    if dim == 0:
        return self

    edgeitems = PRINT_OPTS.edgeitems
    if not edgeitems:
        # Handled for every rank, including 1-D: there, self[-0:] would
        # select the entire tensor and force a full scan of huge inputs.
        return self.new_empty([0] * dim)

    if self.size(0) <= 2 * edgeitems:
        # Short enough along this dimension: keep every slice.
        if dim == 1:
            return self
        return torch.stack([get_summarized_data(x) for x in self])

    if dim == 1:
        return torch.cat((self[:edgeitems], self[-edgeitems:]))

    head = [self[i] for i in range(0, edgeitems)]
    tail = [self[i] for i in range(len(self) - edgeitems, len(self))]
    return torch.stack([get_summarized_data(x) for x in head + tail])
396
397
def _str_intern(inp, *, tensor_contents=None):
    """Build the complete repr string for ``inp``.

    The result has the form ``<prefix><contents>, <suffix>, ...)``. The prefix
    is ``tensor(``, ``nested_tensor(`` or ``<SubclassName>(``. The suffixes
    carry metadata such as device, size, nnz, dtype, quantization parameters,
    layout, grad_fn / requires_grad, names and the forward-AD tangent.

    Args:
        inp: the tensor to describe.
        tensor_contents: optional precomputed string to use in place of the
            formatted element data. When given, the data-formatting steps are
            skipped but the metadata suffixes are still computed (except in
            the functional-tensor branch, which always sets its own contents).

    Returns:
        The repr string.
    """
    # functorch-wrapped tensors are printed by a dedicated helper.
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    # Only exact Tensor / Parameter instances get the plain "tensor(" prefix;
    # subclasses are printed under their own class name.
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    # Continuation lines are aligned with the first column after the prefix.
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy/mtia tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    # The "dtype=" suffix is omitted for these dtypes in the branches below
    # that check has_default_dtype.
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        # COO sparse: size/nnz/dtype go into suffixes and the contents show the
        # indices and values tensors. Meta and FakeTensor inputs are printed as
        # "..." plus an explicit size, and no nnz suffix is added for them.
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            # An empty "[]" does not convey the shape, so add it explicitly.
            if is_meta or indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if is_meta or values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        # Compressed sparse layouts: the contents show the compressed indices,
        # plain indices and values, with the same meta/empty handling as COO.
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            # Produces "crow_indices=tensor(" or "ccol_indices=tensor(".
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            # Produces "col_indices=tensor(" or "row_indices=tensor(".
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        # Quantized: print the dequantized values; the quantization scheme
        # and its parameters go into the suffixes.
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        # Nested: unbind along dim 0 and print each component's full str(),
        # one per block, inside "[ ... ]".
        if not custom_contents_provided:

            # NOTE(review): the `indent` argument is unused; every line is
            # indented by a fixed two spaces.
            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        # Functional tensors are printed as the repr of the unwrapped tensor
        # under a "_to_functional_tensor(" prefix. NOTE(review): `indent` is
        # not recomputed for the new prefix, and tensor_contents is ignored.
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            # No data to show. Unlike the has_default_dtype checks, only the
            # default float dtype is exempt from the "dtype=" suffix here.
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                # With edgeitems == 0 the data prints as "[...]", so report
                # the shape explicitly.
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    # Non-strided layouts reaching this point are densified
                    # before their values are formatted.
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    # grad_fn is only read when the try block above succeeded (grad_fn_name
    # is still None in that case).
    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        # For C++-defined functions, show the name after the last "::".
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    # COO sparse tensors always put their first suffix on a new line.
    string_repr = _add_suffixes(
        prefix + tensor_str,  # type: ignore[possibly-undefined]
        suffixes,
        indent,
        force_newline=self.is_sparse,
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr
672
673
def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    """Render a functorch wrapper tensor by unwrapping one level.

    Produces ``BatchedTensor(...)``, ``GradTrackingTensor(...)`` or
    ``FunctionalTensor(...)`` around the repr of the unwrapped value.
    ``tensor_contents`` is accepted for signature parity with ``_str_intern``
    but is not used. Raises ``ValueError`` for any other wrapper kind.
    """
    ft = torch._C._functorch
    level = ft.maybe_get_level(tensor)
    assert level != -1

    if ft.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make
        # sure that it's up to date first.
        torch._sync(tensor)

    value_repr = repr(ft.get_unwrapped(tensor))
    indented_value_repr = textwrap.indent(value_repr, "    ")

    if ft.is_batchedtensor(tensor):
        bdim = ft.maybe_get_bdim(tensor)
        assert bdim != -1
        return f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n{indented_value_repr}\n)"
    if ft.is_gradtrackingtensor(tensor):
        return f"GradTrackingTensor(lvl={level}, value=\n{indented_value_repr}\n)"
    if ft.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")
702    raise ValueError("We don't know how to print this, please file us an issue")
703
704
def _str(self, *, tensor_contents=None):
    """Return the printed representation of ``self``.

    Formatting runs under ``torch.no_grad()``, with the current Python
    dispatch modes disabled and with a ``torch._C._DisableFuncTorch`` guard
    held, presumably so that the slicing/indexing done while printing is not
    recorded or intercepted.
    """
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        # Keep the guard bound to a local so it stays alive for the whole call.
        functorch_guard = torch._C._DisableFuncTorch()
        rendered = _str_intern(self, tensor_contents=tensor_contents)
    return rendered
709