# This example was adapted from https://github.com/muhrin/milad
# It is licensed under the GPLv3 license. You can find a copy of it
# here: https://www.gnu.org/licenses/gpl-3.0.en.html .

import torch
from torch import nn
from torch.func import jacrev, vmap
from torch.nn.functional import mse_loss


sigma = 0.5
# NOTE: ``epsilon`` multiplies the bracket in ``lennard_jones`` directly;
# there is no separate factor of 4 in the energy expression below.
epsilon = 4.0


def lennard_jones(r):
    """Return the Lennard-Jones pair energy at separation ``r`` (elementwise)."""
    return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)


def lennard_jones_force(r):
    """Return the radial LJ force ``-dE/dr`` at separation ``r`` (elementwise).

    The value is signed rather than a magnitude: positive means the force
    pushes the particles apart (repulsive), negative means it pulls them
    together (attractive).
    """
    return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7))


training_size = 1000
# The training targets are fixed data, so ``r`` must NOT require grad.
# Otherwise every ``loss.backward()`` would also traverse (and need to
# retain) the graph that built the targets, and accumulate into ``r.grad``.
r = torch.linspace(0.5, 2 * sigma, steps=training_size)

# Separation vectors that all point along +x, with lengths ``r``.
drs = torch.outer(r, torch.tensor([1.0, 0, 0]))
norms = torch.norm(drs, dim=1).reshape(-1, 1)
# Training energies, shape (training_size, 1); ``lennard_jones`` is elementwise.
training_energies = lennard_jones(norms)
# Training forces, shape (training_size, 3): the radial force times the *unit*
# separation vector, i.e. the same convention ``make_prediction`` uses.
training_forces = lennard_jones_force(norms) * drs / norms

model = nn.Sequential(
    nn.Linear(1, 16),
    nn.Tanh(),
    nn.Linear(16, 16),
    nn.Tanh(),
    nn.Linear(16, 16),
    nn.Tanh(),
    nn.Linear(16, 16),
    nn.Tanh(),
    nn.Linear(16, 1),
)


def make_prediction(model, drs):
    """Predict energies and forces for the separation vectors ``drs``.

    ``model`` is called on an ``(N, 1)`` tensor of distances ``|drs|``. Forces
    are ``-dE/dr`` times the unit separation vector ``drs / |drs|``, where
    ``dE/dr`` is computed per sample with ``vmap(jacrev(model))``. Rows of
    ``drs`` must be non-zero (the norm is used as a divisor).

    Returns:
        ``(energies, forces)``: the model output on the norms, and forces
        with the same shape as ``drs``.
    """
    norms = torch.norm(drs, dim=1).reshape(-1, 1)
    energies = model(norms)

    # For a scalar-output model, each per-sample Jacobian is (1, 1); vmap
    # stacks them to (N, 1, 1), and squeezing the last axis gives (N, 1).
    network_derivs = vmap(jacrev(model))(norms).squeeze(-1)
    forces = -network_derivs * drs / norms
    return energies, forces


def loss_fn(energies, forces, predicted_energies, predicted_forces):
    """Return the energy MSE plus a down-weighted force MSE.

    ``mse_loss`` already averages over every force component, so the extra
    ``/ 3`` only acts as an additional factor on the 0.01 force weight.
    """
    return (
        mse_loss(energies, predicted_energies)
        + 0.01 * mse_loss(forces, predicted_forces) / 3
    )


def train(model, drs, energies, forces, *, epochs=400, lr=1e-3, log_every=20):
    """Fit ``model`` to target ``energies``/``forces`` with Adam.

    Prints the loss every ``log_every`` epochs, starting at epoch 0.

    Returns:
        The loss of every epoch, as Python floats.
    """
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(epochs):
        optimiser.zero_grad()
        predicted_energies, predicted_forces = make_prediction(model, drs)
        loss = loss_fn(energies, forces, predicted_energies, predicted_forces)
        # The graph is rebuilt every epoch and the targets are constants,
        # so no ``retain_graph`` is needed.
        loss.backward()
        optimiser.step()

        losses.append(loss.item())
        if epoch % log_every == 0:
            print(losses[-1])
    return losses


def main():
    """Train the module-level ``model`` on the generated LJ data."""
    train(model, drs, training_energies, training_forces)


if __name__ == "__main__":
    main()