# Owner(s): ["module: distributions"]

import io
from numbers import Number

import pytest

import torch
from torch.autograd import grad
from torch.autograd.functional import jacobian
from torch.distributions import (
    constraints,
    Dirichlet,
    Independent,
    Normal,
    TransformedDistribution,
)
from torch.distributions.transforms import (
    _InverseTransform,
    AbsTransform,
    AffineTransform,
    ComposeTransform,
    CorrCholeskyTransform,
    CumulativeDistributionTransform,
    ExpTransform,
    identity_transform,
    IndependentTransform,
    LowerCholeskyTransform,
    PositiveDefiniteTransform,
    PowerTransform,
    ReshapeTransform,
    SigmoidTransform,
    SoftmaxTransform,
    SoftplusTransform,
    StickBreakingTransform,
    TanhTransform,
    Transform,
)
from torch.distributions.utils import tril_matrix_to_vec, vec_to_tril_matrix
from torch.testing._internal.common_utils import run_tests


def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=-2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.0).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        SoftplusTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        PositiveDefiniteTransform(cache_size=cache_size),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                ExpTransform(cache_size=cache_size),
            ]
        ),
        ComposeTransform(
            [
                AffineTransform(0, 1, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
                AffineTransform(1, -2, cache_size=cache_size),
                AffineTransform(
                    torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size
                ),
            ]
        ),
        ReshapeTransform((4, 5), (2, 5, 2)),
        IndependentTransform(
            AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size), 1
        ),
        CumulativeDistributionTransform(Normal(0, 1)),
    ]
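    # Also include the inverse of every transform so both directions are exercised.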
    transforms += [t.inv for t in transforms]
    return transforms


def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
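    # AffineTransform parameters are expanded to `shape` (or reshaped as a fallback)
    # so that the flattened-batch inputs used in test_jacobian broadcast correctly.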
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(
                transform.loc.expand(shape),
                transform.scale.expand(shape),
                cache_size=transform._cache_size,
            )
        except RuntimeError:
            return AffineTransform(
                transform.loc.reshape(shape),
                transform.scale.reshape(shape),
                cache_size=transform._cache_size,
            )
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform


# Generate pytest ids
def transform_id(x):
    assert isinstance(x, Transform)
    name = (
        f"Inv({type(x._inv).__name__})"
        if isinstance(x, _InverseTransform)
        else f"{type(x).__name__}"
    )
    return f"{name}(cache_size={x._cache_size})"


def generate_data(transform):
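    # Draw a sample that lies in `transform.domain` by dispatching on the domain
    # (or codomain) constraint; shapes match the batched transforms built above.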
    torch.manual_seed(1)
    while isinstance(transform, IndependentTransform):
        transform = transform.base_transform
    if isinstance(transform, ReshapeTransform):
        return torch.randn(transform.in_shape)
    if isinstance(transform.inv, ReshapeTransform):
        return torch.randn(transform.inv.out_shape)
    domain = transform.domain
    while (
        isinstance(domain, constraints.independent)
        and domain is not constraints.real_vector
    ):
        domain = domain.base_constraint
    codomain = transform.codomain
    x = torch.empty(4, 5)
    positive_definite_constraints = [
        constraints.lower_cholesky,
        constraints.positive_definite,
    ]
    if domain in positive_definite_constraints:
        x = torch.randn(6, 6)
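        # Lower-triangular with a strictly positive diagonal, i.e. a valid Cholesky factor.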
        x = x.tril(-1) + x.diag().exp().diag_embed()
        if domain is constraints.positive_definite:
            return x @ x.T
        return x
    elif codomain in positive_definite_constraints:
        return torch.randn(6, 6)
    elif domain is constraints.real:
        return x.normal_()
    elif domain is constraints.real_vector:
        # For corr_cholesky the last dim of the vector
        # must be of size dim * (dim - 1) // 2 (here dim = 4, so 6)
        x = torch.empty(3, 6)
        x = x.normal_()
        return x
    elif domain is constraints.positive:
        return x.normal_().exp()
    elif domain is constraints.unit_interval:
        return x.uniform_()
    elif isinstance(domain, constraints.interval):
        x = x.uniform_()
        x = x.mul_(domain.upper_bound - domain.lower_bound).add_(domain.lower_bound)
        return x
    elif domain is constraints.simplex:
        x = x.normal_().exp()
        x /= x.sum(-1, True)
        return x
    elif domain is constraints.corr_cholesky:
        x = torch.empty(4, 5, 5)
        x = x.normal_().tril()
        x /= x.norm(dim=-1, keepdim=True)
        x.diagonal(dim1=-1).copy_(x.diagonal(dim1=-1).abs())
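        # The rows now have unit norm and the diagonal is positive, i.e. a valid
        # correlation Cholesky factor.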
        return x
    raise ValueError(f"Unsupported domain: {domain}")


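# cache_size=1 memoizes the most recent (x, y) pair on each transform;
# cache_size=0 disables caching entirely.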
TRANSFORMS_CACHE_ACTIVE = get_transforms(cache_size=1)
TRANSFORMS_CACHE_INACTIVE = get_transforms(cache_size=0)
ALL_TRANSFORMS = (
    TRANSFORMS_CACHE_ACTIVE + TRANSFORMS_CACHE_INACTIVE + [identity_transform]
)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_inv_inv(transform):
    assert transform.inv.inv is transform


@pytest.mark.parametrize("x", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
@pytest.mark.parametrize("y", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_equality(x, y):
    if x is y:
        assert x == y
    else:
        assert x != y
    assert identity_transform == identity_transform.inv


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_with_cache(transform):
    if transform._cache_size == 0:
        transform = transform.with_cache(1)
    assert transform._cache_size == 1
    x = generate_data(transform).requires_grad_()
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    y2 = transform(x)
    assert y2 is y


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
@pytest.mark.parametrize("test_cached", [True, False])
def test_forward_inverse(transform, test_cached):
    x = generate_data(transform).requires_grad_()
    assert transform.domain.check(x).all()  # verify that the input data are valid
    try:
        y = transform(x)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    assert y.shape == transform.forward_shape(x.shape)
    if test_cached:
        x2 = transform.inv(y)  # should be implemented at least by caching
    else:
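        # The cache keys on tensor identity, so passing a clone forces a real
        # inverse computation rather than a cache hit.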
        try:
            x2 = transform.inv(y.clone())  # bypass cache
        except NotImplementedError:
            pytest.skip("Not implemented.")
    assert x2.shape == transform.inverse_shape(y.shape)
    y2 = transform(x2)
    if transform.bijective:
        # verify function inverse
        assert torch.allclose(x2, x, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t.inv(t(-)) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
            ]
        )
    else:
        # verify weaker function pseudo-inverse
        assert torch.allclose(y2, y, atol=1e-4, equal_nan=True), "\n".join(
            [
                f"{transform} t(t.inv(t(-))) error",
                f"x = {x}",
                f"y = t(x) = {y}",
                f"x2 = t.inv(y) = {x2}",
                f"y2 = t(x2) = {y2}",
            ]
        )


def test_compose_transform_shapes():
    transform0 = ExpTransform()
    transform1 = SoftmaxTransform()
    transform2 = LowerCholeskyTransform()

    assert transform0.event_dim == 0
    assert transform1.event_dim == 1
    assert transform2.event_dim == 2
    assert ComposeTransform([transform0, transform1]).event_dim == 1
    assert ComposeTransform([transform0, transform2]).event_dim == 2
    assert ComposeTransform([transform1, transform2]).event_dim == 2


transform0 = ExpTransform()
transform1 = SoftmaxTransform()
transform2 = LowerCholeskyTransform()
base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
base_dist1 = Dirichlet(torch.ones(4, 4))
base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))


@pytest.mark.parametrize(
    ("batch_shape", "event_shape", "dist"),
    [
        ((4, 4), (), base_dist0),
        ((4,), (4,), base_dist1),
        ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
        ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
        ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
        ((3, 4, 4), (), base_dist2),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
        ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
    ],
)
def test_transformed_distribution_shapes(batch_shape, event_shape, dist):
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == event_shape
    x = dist.rsample()
    try:
        dist.log_prob(x)  # this should not crash
    except NotImplementedError:
        pytest.skip("Not implemented.")


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_fwd(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        return transform(x)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_inv(transform):
    y = generate_data(transform.inv).requires_grad_()

    def f(y):
        return transform.inv(y)

    try:
        traced_f = torch.jit.trace(f, (y,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    y = generate_data(transform.inv).requires_grad_()
    assert torch.allclose(f(y), traced_f(y), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", TRANSFORMS_CACHE_INACTIVE, ids=transform_id)
def test_jit_jacobian(transform):
    x = generate_data(transform).requires_grad_()

    def f(x):
        y = transform(x)
        return transform.log_abs_det_jacobian(x, y)

    try:
        traced_f = torch.jit.trace(f, (x,))
    except NotImplementedError:
        pytest.skip("Not implemented.")

    # check on different inputs
    x = generate_data(transform).requires_grad_()
    assert torch.allclose(f(x), traced_f(x), atol=1e-5, equal_nan=True)


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_jacobian(transform):
    x = generate_data(transform)
    try:
        y = transform(x)
        actual = transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        pytest.skip("Not implemented.")
    # Test shape
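    # log_abs_det_jacobian reduces over the transform's event dims, leaving the batch shape.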
    target_shape = x.shape[: x.dim() - transform.domain.event_dim]
    assert actual.shape == target_shape

    # Expand if required
    transform = reshape_transform(transform, x.shape)
    ndims = len(x.shape)
    event_dim = ndims - transform.domain.event_dim
    x_ = x.view((-1,) + x.shape[event_dim:])
    n = x_.shape[0]
    # Reshape to squash batch dims to a single batch dim
    transform = reshape_transform(transform, x_.shape)

    # 1. Transforms with unit jacobian
    if isinstance(transform, ReshapeTransform) or isinstance(
        transform.inv, ReshapeTransform
    ):
        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
    # 2. Transforms with 0 off-diagonal elements
    elif transform.domain.event_dim == 0:
        jac = jacobian(transform, x_)
        # assert off-diagonal elements are zero
        assert torch.allclose(jac, jac.diagonal().diag_embed())
        expected = jac.diagonal().abs().log().reshape(x.shape)
    # 3. Transforms with non-0 off-diagonal elements
    else:
        if isinstance(transform, CorrCholeskyTransform):
            jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
        elif isinstance(transform.inv, CorrCholeskyTransform):
            jac = jacobian(
                lambda x: transform(vec_to_tril_matrix(x, diag=-1)),
                tril_matrix_to_vec(x_, diag=-1),
            )
        elif isinstance(transform, StickBreakingTransform):
            jac = jacobian(lambda x: transform(x)[..., :-1], x_)
        else:
            jac = jacobian(transform, x_)

        # jacobian returns a tensor of shape (batch_dims, y_event_dims, batch_dims, x_event_dims).
        # Batches are independent, so after reshaping the event dims (see above) this can be
        # reduced to a (batch_dims, event_dims, event_dims) batch of square matrices whose
        # determinant can be computed.
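        # For example (illustrative only): with x_ of shape (n, d), jac has shape
        # (n, d, n, d); because batches are independent, only the blocks jac[b, :, b, :]
        # are nonzero, and the gather below extracts them into an (n, d, d) batch of
        # square Jacobians whose slogdet gives the expected log-abs-det.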
        gather_idx_shape = list(jac.shape)
        gather_idx_shape[-2] = 1
        gather_idxs = (
            torch.arange(n)
            .reshape((n,) + (1,) * (len(jac.shape) - 1))
            .expand(gather_idx_shape)
        )
        jac = jac.gather(-2, gather_idxs).squeeze(-2)
        out_ndims = jac.shape[-2]
        jac = jac[
            ..., :out_ndims
        ]  # Remove extra zero-valued dims (for inverse stick-breaking).
        expected = torch.slogdet(jac).logabsdet

    assert torch.allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
    "event_dims", [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], ids=str
)
def test_compose_affine(event_dims):
    transforms = [
        AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, *event_dims)


@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
def test_compose_reshape(batch_shape):
    transforms = [
        ReshapeTransform((), ()),
        ReshapeTransform((2,), (1, 2)),
        ReshapeTransform((3, 1, 2), (6,)),
        ReshapeTransform((6,), (2, 3)),
    ]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == 2
    assert transform.domain.event_dim == 2
    data = torch.randn(batch_shape + (3, 2))
    assert transform(data).shape == batch_shape + (2, 3)

    dist = TransformedDistribution(Normal(data, 1), transforms)
    assert dist.batch_shape == batch_shape
    assert dist.event_shape == (2, 3)
    assert dist.support.event_dim == 2


@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
@pytest.mark.parametrize("transform_dim", [0, 1, 2])
@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
def test_transformed_distribution(
    base_batch_dim, base_event_dim, transform_dim, num_transforms, sample_shape
):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim :])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim :]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10)),
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
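    # A sample that merely broadcast (rather than expanded) would repeat values,
    # so nearly all entries should be distinct.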
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape


def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))


@pytest.mark.parametrize("transform", ALL_TRANSFORMS, ids=transform_id)
def test_transform_sign(transform: Transform):
    try:
        sign = transform.sign
    except NotImplementedError:
        pytest.skip("Not implemented.")

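    # `sign` is +1 for monotone increasing and -1 for monotone decreasing transforms,
    # so every pointwise derivative multiplied by the sign must be strictly positive.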
    x = generate_data(transform).requires_grad_()
    y = transform(x).sum()
    (derivatives,) = grad(y, [x])
    assert torch.less(torch.as_tensor(0.0), derivatives * sign).all()


if __name__ == "__main__":
    run_tests()