xref: /aosp_15_r20/external/pytorch/functorch/examples/lennard_jones/lennard_jones.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# This example was adapted from https://github.com/muhrin/milad
# It is licensed under the GPLv3 license. You can find a copy of it
# here: https://www.gnu.org/licenses/gpl-3.0.en.html .
4
5import torch
6from torch import nn
7from torch.func import jacrev, vmap
8from torch.nn.functional import mse_loss
9
10
# Length scale of the potential: lennard_jones() evaluates powers of sigma / r.
sigma = 0.5
# Overall prefactor multiplying the (sigma / r) ** 12 - (sigma / r) ** 6 term.
# NOTE(review): the textbook form is 4 * eps * (...), so 4.0 presumably folds
# in a well depth of eps = 1 -- confirm.
epsilon = 4.0
13
14
def lennard_jones(r):
    """Evaluate the 12-6 Lennard-Jones pair energy at separation ``r``.

    Reads the module-level ``sigma`` and ``epsilon`` and returns
    ``epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6)``. Works element-wise
    on anything that supports division and ``**`` (floats, tensors).
    """
    scaled = sigma / r
    repulsion = scaled**12
    attraction = scaled**6
    return epsilon * (repulsion - attraction)
17
18
def lennard_jones_force(r):
    """Return the radial Lennard-Jones force at separation ``r``.

    This is ``-dU/dr`` for the energy computed by ``lennard_jones``, using the
    module-level ``sigma`` and ``epsilon``. The result is signed: positive
    where the pair repels (short range) and negative where it attracts.
    """
    # d/dr of (sigma / r) ** 12 and of -(sigma / r) ** 6, respectively.
    repulsive_slope = -12 * sigma**12 / r**13
    attractive_slope = 6 * sigma**6 / r**7
    return -epsilon * (repulsive_slope + attractive_slope)
22
23
training_size = 1000
# Sample separations evenly between 0.5 and 2 * sigma.
# requires_grad=True is kept as in the original script; note that everything
# derived from r below (drs, norms, the training targets) therefore carries
# autograd history back to r.
r = torch.linspace(0.5, 2 * sigma, steps=training_size, requires_grad=True)

# Create a bunch of vectors that point along positive-x
drs = torch.outer(r, torch.tensor([1.0, 0, 0]))
norms = torch.norm(drs, dim=1).reshape(-1, 1)
# Create training energies. lennard_jones is element-wise, so one call on the
# (N, 1) norms replaces the per-row map + stack and yields the same (N, 1) shape.
training_energies = lennard_jones(norms)
# Create training force vectors: radial force times the *unit* direction.
# The displacement must be normalised (drs / norms); multiplying by the raw
# displacement would scale every force by r and make the targets inconsistent
# with -dU/dr of the energies and with make_prediction, which uses drs / norms.
training_forces = lennard_jones_force(norms) * drs / norms
36
# Fully connected network mapping one scalar input to one scalar output:
# four Linear + Tanh stages of width 16 followed by a linear read-out.
# Layers are created in the same order as a literal listing would create
# them, so parameter initialisation consumes the RNG identically.
_layers = [nn.Linear(1, 16), nn.Tanh()]
for _ in range(3):
    _layers.extend((nn.Linear(16, 16), nn.Tanh()))
_layers.append(nn.Linear(16, 1))
model = nn.Sequential(*_layers)
48
49
def make_prediction(model, drs):
    """Predict energies and force vectors for a batch of displacements.

    Args:
        model: callable mapping a column of distances to energies; it is
            called on the whole batch and, through ``vmap(jacrev(...))``,
            once per sample to obtain its derivative.
        drs: tensor of displacement vectors, one per row (the norm is taken
            over ``dim=1``).

    Returns:
        A tuple ``(energies, forces)`` where ``energies`` is ``model`` applied
        to the row norms and ``forces`` is minus the model's derivative with
        respect to the norm, times ``drs / norm``.
    """
    distances = torch.norm(drs, dim=1).reshape(-1, 1)

    # Per-sample derivative of the model output w.r.t. its input distance;
    # the trailing singleton axis of each Jacobian is dropped.
    per_sample_jacobian = vmap(jacrev(model))
    dE_dr = per_sample_jacobian(distances).squeeze(-1)

    # Force = -dE/dr along the direction of the displacement.
    predicted_forces = (-dE_dr * drs) / distances
    return model(distances), predicted_forces
57
58
def loss_fn(energies, forces, predicted_energies, predicted_forces):
    """Combine the energy and force mean-squared errors into one scalar loss.

    The energy MSE is taken at full weight; the force MSE is scaled by 0.01
    and divided by 3 before being added.
    """
    energy_error = mse_loss(energies, predicted_energies)
    force_error = mse_loss(forces, predicted_forces)
    return energy_error + 0.01 * force_error / 3
64
65
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

# The inputs and targets are constants of the optimisation, but they may carry
# autograd history from the tensor they were built from. Detach them once so
# that each step backpropagates only through the model: this avoids computing
# (and accumulating) gradients for the training data and removes the need to
# keep the graph alive between steps with retain_graph=True. The gradients of
# the model parameters are unaffected.
train_drs = drs.detach()
target_energies = training_energies.detach()
target_forces = training_forces.detach()

for epoch in range(400):
    optimiser.zero_grad()
    energies, forces = make_prediction(model, train_drs)
    loss = loss_fn(target_energies, target_forces, energies, forces)
    loss.backward()
    optimiser.step()

    # Report the training loss every 20 epochs.
    if epoch % 20 == 0:
        print(loss.cpu().item())
76        print(loss.cpu().item())
77