"""
This script will generate input-output plots for all of the activation
functions. These are for use in the documentation, and potentially in
online tutorials.

Images are written to ``activation_images/`` next to this script. An image
that already exists is left untouched; delete it to force regeneration.
"""

from pathlib import Path

import matplotlib
import torch

# Select the non-interactive backend *before* pyplot is imported: older
# matplotlib releases warn and may ignore ``use()`` once pyplot has already
# chosen a backend.
matplotlib.use("Agg")

from matplotlib import pyplot as plt  # noqa: E402


# Half-width of the plotted square: sampled inputs and both axis limits span
# [-PLOT_LIMIT, PLOT_LIMIT]. We need to go beyond 6 for ReLU6.
PLOT_LIMIT = 7.0
# Spacing between consecutive sampled inputs.
PLOT_STEP = 0.01

# Create a directory for the images, if it doesn't exist. ``exist_ok`` avoids
# the race between a separate ``exists()`` check and ``mkdir()``.
ACTIVATION_IMAGE_PATH = Path(__file__).parent / "activation_images"
ACTIVATION_IMAGE_PATH.mkdir(exist_ok=True)

# In a refactor, these ought to go into their own module or entry
# points so we can generate this list programmatically
# NOTE(review): modules are left in their default (training) mode. For RReLU
# this presumably means randomly sampled negative slopes, so its image may
# differ between runs -- confirm whether that is intended.
functions = [
    torch.nn.ELU(),
    torch.nn.Hardshrink(),
    torch.nn.Hardtanh(),
    torch.nn.Hardsigmoid(),
    torch.nn.Hardswish(),
    torch.nn.LeakyReLU(negative_slope=0.1),
    torch.nn.LogSigmoid(),
    torch.nn.PReLU(),
    torch.nn.ReLU(),
    torch.nn.ReLU6(),
    torch.nn.RReLU(),
    torch.nn.SELU(),
    torch.nn.SiLU(),
    torch.nn.Mish(),
    torch.nn.CELU(),
    torch.nn.GELU(),
    torch.nn.Sigmoid(),
    torch.nn.Softplus(),
    torch.nn.Softshrink(),
    torch.nn.Softsign(),
    torch.nn.Tanh(),
    torch.nn.Tanhshrink(),
]


def plot_function(function, **args):
    """
    Plot a function on the current plot.

    ``function`` is called on inputs spaced ``PLOT_STEP`` apart in
    ``[-PLOT_LIMIT, PLOT_LIMIT)``, and the result is drawn with ``plt.plot``.
    The additional keyword arguments are forwarded to ``plt.plot`` and may
    be used to specify color, alpha, etc.
    """
    inputs = torch.arange(-PLOT_LIMIT, PLOT_LIMIT, PLOT_STEP)
    # Plotting needs no autograd graph. Without this, modules with learnable
    # parameters (e.g. PReLU) return a tensor that requires grad, and
    # ``.numpy()`` cannot be called on such a tensor.
    with torch.no_grad():
        outputs = function(inputs)
    plt.plot(inputs.numpy(), outputs.numpy(), **args)


# Step through all the functions, skipping any whose image already exists.
for function in functions:
    function_name = type(function).__name__
    plot_path = ACTIVATION_IMAGE_PATH / f"{function_name}.png"
    if plot_path.exists():
        continue

    # Start a new plot
    plt.clf()
    plt.grid(color="k", alpha=0.2, linestyle="--")

    # Plot the current function
    plot_function(function)

    # Title with the module's string form (``str(function)``).
    plt.title(str(function))
    plt.xlabel("Input")
    plt.ylabel("Output")
    plt.xlim([-PLOT_LIMIT, PLOT_LIMIT])
    plt.ylim([-PLOT_LIMIT, PLOT_LIMIT])

    # And save it
    plt.savefig(plot_path)
    print(f"Saved activation image for {function_name} at {plot_path}")