xref: /aosp_15_r20/external/pytorch/docs/source/scripts/build_activation_images.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
This script generates input-output plots for all of the activation
functions. These are for use in the documentation, and potentially in
online tutorials.
5"""
6
7from pathlib import Path
8
9import matplotlib
10from matplotlib import pyplot as plt
11
12import torch
13
14
# Render with the non-interactive "Agg" backend: this script only writes PNG
# files via ``savefig`` and never opens a window, so no display is required.
matplotlib.use(backend="Agg")
16
17
# Images are written to an "activation_images" directory next to this script.
# ``exist_ok=True`` makes creation idempotent and removes the race between a
# separate ``exists()`` check and ``mkdir()`` (e.g. two builds at once). If the
# path exists but is not a directory, ``mkdir`` raises ``FileExistsError`` here
# instead of failing later when the first image is saved.
ACTIVATION_IMAGE_PATH = Path(__file__).parent / "activation_images"
ACTIVATION_IMAGE_PATH.mkdir(exist_ok=True)
23
24# In a refactor, these ought to go into their own module or entry
25# points so we can generate this list programmatically
26functions = [
27    torch.nn.ELU(),
28    torch.nn.Hardshrink(),
29    torch.nn.Hardtanh(),
30    torch.nn.Hardsigmoid(),
31    torch.nn.Hardswish(),
32    torch.nn.LeakyReLU(negative_slope=0.1),
33    torch.nn.LogSigmoid(),
34    torch.nn.PReLU(),
35    torch.nn.ReLU(),
36    torch.nn.ReLU6(),
37    torch.nn.RReLU(),
38    torch.nn.SELU(),
39    torch.nn.SiLU(),
40    torch.nn.Mish(),
41    torch.nn.CELU(),
42    torch.nn.GELU(),
43    torch.nn.Sigmoid(),
44    torch.nn.Softplus(),
45    torch.nn.Softshrink(),
46    torch.nn.Softsign(),
47    torch.nn.Tanh(),
48    torch.nn.Tanhshrink(),
49]
50
51
52def plot_function(function, **args):
53    """
54    Plot a function on the current plot. The additional arguments may
55    be used to specify color, alpha, etc.
56    """
57    xrange = torch.arange(-7.0, 7.0, 0.01)  # We need to go beyond 6 for ReLU6
58    plt.plot(xrange.numpy(), function(xrange).detach().numpy(), **args)
59
60
# Render one PNG per activation module. Images that already exist are left
# untouched, so only missing plots are produced.
for function in functions:
    function_name = type(function).__name__
    plot_path = ACTIVATION_IMAGE_PATH / f"{function_name}.png"
    if plot_path.exists():
        continue

    # Reuse the current figure: clear it, then draw a light dashed grid.
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")

    plot_function(function)

    # The title is the module object itself; both axes are pinned to
    # [-7, 7] so every image shares the same scale.
    plt.title(function)
    plt.xlabel("Input")
    plt.ylabel("Output")
    plt.xlim([-7, 7])
    plt.ylim([-7, 7])

    plt.savefig(plot_path)
    print(f"Saved activation image for {function_name} at {plot_path}")
82