PyTorch layer 10 times slower than pure Python implementation

I am trying to implement the Stillinger-Weber potential as a PyTorch layer. I understand that such a layer might not be “ideal” for an ML framework like PyTorch, but in any case I would expect it to perform no worse than a pure Python or Python/NumPy implementation of the same thing.

When I run the layer as a PyTorch nn.Module subclass, a single run takes about 27.977 ± 4.8341 s (average over 10 runs), whereas a simple Python class executes the exact same code in 2.4446 ± 0.052723 s (average over 10 runs).
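
A minimal harness along these lines (a sketch; `sw_layer` and its inputs stand in for the objects built in the linked scripts) reproduces the kind of comparison reported above:

import time
import statistics

def time_model(model, args, n_runs=10):
    """Call model(*args) n_runs times and report mean and stdev of wall-clock time."""
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        model(*args)
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.stdev(timings)

# mean, std = time_model(sw_layer, (elements, coords, nl, padding))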

Complete scripts can be found at GitHub - ipcamit/temp_4_pytorch: Temporary git repo for pytorch forum scripts

What am I doing wrong?

Below are the functions as implemented for PyTorch:

# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines (PyTorch gets fussy if functions are part of a class)
@torch.jit.script
def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    if rij < cutoff:
        sig_r = sigma / rij
        one_by_delta_r = 1.0 / (rij - cutoff)
        Bpq = (B * sig_r ** p - sig_r ** q)
        exp_sigma = torch.exp(sigma * one_by_delta_r )
        E2 = A * Bpq * exp_sigma 
        F = (q * sig_r ** (q + 1)) - p * B * sig_r ** (p + 1) - Bpq * (sigma * one_by_delta_r) ** 2
        F = F * (1./sigma) * A * exp_sigma
    else:
        return torch.tensor(0.0), torch.tensor(0.0) 
    return E2, F


@torch.jit.script
def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
                 cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    if ((rij > cutoff_ij) or 
        (rik > cutoff_ik) or 
        (rjk > cutoff_jk)):
        dE3_dr[0] = 0.0; dE3_dr[1] = 0.0; dE3_dr[2] = 0.0
        return torch.tensor(0.0)
    else: 
        cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
        cos_diff = cos_beta_ikj - cos_beta0

        exp_ij_ik = torch.exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))

        dij = - gamma_ij/(rij - cutoff_ij)**2
        dik = - gamma_ik/(rik - cutoff_ik)**2

        E3 = lam * exp_ij_ik * cos_diff ** 2

        dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
        dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
        dcos_drjk = (- rjk) / (rij * rik)

        dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
        dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
        dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3


@torch.jit.script
def energy_and_forces(
    nl: List[List[int]],
    elements_nl: List[List[int]],
    coords_all,
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    p: List[torch.Tensor],
    q: List[torch.Tensor],
    sigma: List[torch.Tensor],
    gamma: List[torch.Tensor],
    cutoff: List[torch.Tensor],
    lam: List[torch.Tensor],
    cos_beta0: List[torch.Tensor],
    cutoff_jk: List[torch.Tensor]
    ):
    """
    Calculates the energy for a given list of coordinates, assuming the first
    coordinate is that of the query atom i and the remaining ones are its neighbours.
    """
    energy = torch.tensor(0.0)
    F2 = torch.tensor(0.0)
    F3 = torch.zeros(3)
    E2 = torch.tensor(0.0)
    E3 = torch.tensor(0.0)
    gamma_ij = torch.tensor(0.0)
    gamma_ik = torch.tensor(0.0)
    cutoff_ij = torch.tensor(0.0)
    cutoff_ik = torch.tensor(0.0)
    xyz_i = torch.zeros(3)
    xyz_j = torch.zeros(3)
    xyz_k = torch.zeros(3)
    rij = torch.zeros(3)
    rik = torch.zeros(3)
    rjk = torch.zeros(3)
    F = torch.zeros_like(coords_all)
    F_comp = torch.zeros(3)
    for i, (nli, elements) in enumerate(zip(nl,elements_nl)):
        num_elem = len(nli)
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = torch.norm(rij)
            # if elem_i == elem_j:
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum], sigma[ij_sum], cutoff[ij_sum], norm_rij)
            energy = 0.5 * E2
            F_comp =  0.5 * F2/norm_rij * rij
            F[i,:] = F[i,:] + F_comp
            F[nli[j], :] = F[nli[j],:] - F_comp
            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]

            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                if (elem_i != elem_j) and \
                   (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = torch.norm(rik)
                    rjk = xyz_k - xyz_j
                    norm_rjk = torch.norm(rjk)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]
                    E3 =  calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum], norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3
                    F_comp[:] = F3[0]/norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    F_comp[:] = F3[1]/norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    F_comp[:] = F3[2]/norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F


# =============================================================================
class StillingerWeberLayer(nn.Module):
    """
    Stillinger-Weber single-species layer for Mo and S atoms, for use in a PyTorch model
    """
    def __init__(self):
        super().__init__()
        self.elements = elements
...
    def forward(self, 
                elements: List[List[int]],
                coords: torch.Tensor,
                nl: List[List[int]],
                padding: List[int]
                ):
        total_conf_energy = torch.tensor(0.0)
        n_atom = len(nl)
        F = torch.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A, 
                                            self.B, self.p, self.q, self.sigma, self.gamma,
                                            self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]

        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)

            if n_atom < n_padding:
                for i in range(n_atom):
                    indices = torch.where(padding == i)
                    F[i] = F[i] + torch.sum(pad_forces[indices], 0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================

The pure Python version implements exactly the same logic:

# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines

def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    if rij < cutoff:
        sig_r = sigma / rij
        one_by_delta_r = 1.0 / (rij - cutoff)
        Bpq = (B * sig_r ** p - sig_r ** q)
        exp_sigma = np.exp(sigma * one_by_delta_r )
        E2 = A * Bpq * exp_sigma 
        F = (q * sig_r ** (q + 1)) - p * B * sig_r ** (p + 1) - Bpq * (sigma * one_by_delta_r) ** 2
        F = F * (1./sigma) * A * exp_sigma
    else:
        return 0.0, 0.0 
    return E2, F


def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
                 cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    if ((rij > cutoff_ij) or 
        (rik > cutoff_ik) or 
        (rjk > cutoff_jk)):
        dE3_dr[0] = 0.0; dE3_dr[1] = 0.0; dE3_dr[2] = 0.0
        return 0.0
    else: 
        cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
        cos_diff = cos_beta_ikj - cos_beta0

        exp_ij_ik = np.exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))

        dij = - gamma_ij/(rij - cutoff_ij)**2
        dik = - gamma_ik/(rik - cutoff_ik)**2

        E3 = lam * exp_ij_ik * cos_diff ** 2

        dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
        dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
        dcos_drjk = (- rjk) / (rij * rik)

        dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
        dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
        dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3


def energy_and_forces(nl, elements_nl, coords_all, A, B, p, q, sigma, gamma, 
                        cutoff, lam, cos_beta0, cutoff_jk):
    """
    Calculates the energy for a given list of coordinates, assuming the first
    coordinate is that of the query atom i and the remaining ones are its neighbours.
    """
    energy = 0.0
    F2 = 0.0
    F3 = np.zeros(3)
    E2 = 0.0
    E3 = 0.0
    gamma_ij = 0.0
    gamma_ik = 0.0
    cutoff_ij = 0.0
    cutoff_ik = 0.0
    xyz_i = np.zeros(3)
    xyz_j = np.zeros(3)
    xyz_k = np.zeros(3)
    rij = np.zeros(3)
    rik = np.zeros(3)
    rjk = np.zeros(3)
    F = np.zeros_like(coords_all)
    F_comp = np.zeros(3)
    for i, (nli, elements) in enumerate(zip(nl,elements_nl)):
        num_elem = len(nli)
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = np.sqrt(rij[0]**2 + rij[1]**2 + rij[2]**2)
            # if elem_i == elem_j:
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum], sigma[ij_sum], cutoff[ij_sum], norm_rij)
            energy = 0.5 * E2
            F_comp =  0.5 * F2/norm_rij * rij
            F[i,:] = F[i,:] + F_comp
            F[nli[j], :] = F[nli[j],:] - F_comp
            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]

            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                if (elem_i != elem_j) and \
                   (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = np.sqrt(rik[0]**2 + rik[1]**2 + rik[2]**2)
                    rjk = xyz_k - xyz_j
                    norm_rjk = np.sqrt(rjk[0]**2 + rjk[1]**2 + rjk[2]**2)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]
                    E3 =  calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum], norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3
                    F_comp[:] = F3[0]/norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    F_comp[:] = F3[1]/norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    F_comp[:] = F3[2]/norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F


# =============================================================================
class StillingerWeberLayer():
    """
    Stillinger-Weber single-species layer for Mo and S atoms, for use in a PyTorch model
    """
    def __init__(self):
        super().__init__()
        self.elements = elements
...
    def __call__(self, elements, coords, nl, padding):
        n_atom = len(nl)
        F = np.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A, 
                                            self.B, self.p, self.q, self.sigma, self.gamma,
                                            self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]

        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)

            if n_atom < n_padding:
                for i in range(n_atom):
                    indices = np.where(padding == i)
                    F[i] = F[i] + np.sum(pad_forces[indices], axis=0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================

So the Thomas rule of thumb™ is that you need your tensor sizes to be at least three figures for PyTorch to work well.
The reason for this is that creating tensors carries an overhead, and this is likely killing performance in your code when you have lots of tensors with <= 3 elements. (This is also why torch.jit.script doesn’t really help you here.) It is not entirely clear to me whether you are comparing autograd-enabled PyTorch to NumPy; if so, that would add further administrative overhead to each operation. Things will get better if and when you can leverage the JIT fusers, as these operate on “bare metal” internally instead of creating one tensor after another, but they don’t cover all PyTorch operations and one would need to analyse a bit how to make the most of them. The two scripted functions look like they might very much be in scope for them at first sight.
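
To make the rule of thumb concrete: the two-body term could, in principle, be evaluated for all pair distances of one parameter set at once, so that each PyTorch operation works on one reasonably large tensor instead of many scalars. A sketch of one possible restructuring (an assumption, not code from the repo):

import torch

def sw2_vectorized(A, B, p, q, sigma, cutoff, r):
    """Two-body SW term for a 1-D tensor r of all pair distances sharing one parameter set."""
    r = r[r < cutoff]                      # keep only pairs inside the cutoff
    sig_r = sigma / r
    one_by_delta_r = 1.0 / (r - cutoff)
    Bpq = B * sig_r ** p - sig_r ** q
    exp_sigma = torch.exp(sigma * one_by_delta_r)
    E2 = A * Bpq * exp_sigma               # per-pair energies
    F = (q * sig_r ** (q + 1) - p * B * sig_r ** (p + 1)
         - Bpq * (sigma * one_by_delta_r) ** 2) * (1.0 / sigma) * A * exp_sigma
    return E2, F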

If you can make it into something one can copy and paste to run (so add a self-contained call), it might be interesting to profile it and see if the fusers can help you.
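
For instance, once there is a self-contained call, something like this (a sketch; `sw_layer` and its inputs are placeholders) would show where the time goes:

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU]) as prof:
    energy, forces = sw_layer(elements, coords, nl, padding)

# Sorting by total CPU time shows whether the cost is many tiny ops
# (tensor creation, dispatch) rather than the arithmetic itself.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))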

Best regards

Thomas

P.S.: This isn’t the place where you might expect it to be recommended, but I hear that cases similar to this are said to be a nice fit for Julia, which compiles to “bare metal” after differentiation.

Well, I assumed there would be some performance loss, but I wasn’t expecting this much. The same disadvantages of small arrays should plague NumPy as well. Are there any benchmarks or systematic comparisons of NumPy vs PyTorch as a function of array size (something like the quick sweep sketched below)?
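
Something like this is what I have in mind (a rough sketch, not a rigorous benchmark):

import timeit
import numpy as np
import torch

# Compare the same elementwise op across sizes; for tiny arrays the fixed
# per-call overhead dominates, for large ones it is amortised.
for n in (3, 30, 300, 3_000, 30_000, 300_000):
    a_np = np.random.rand(n)
    a_pt = torch.rand(n)
    t_np = timeit.timeit(lambda: np.exp(a_np), number=10_000)
    t_pt = timeit.timeit(lambda: torch.exp(a_pt), number=10_000)
    print(f"n={n:>7}  numpy {t_np:.4f} s  torch {t_pt:.4f} s  ratio {t_pt / t_np:.2f}")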

I am in the process of porting it to TF and Julia/Flux for comparison. If the slowdown is due to small tensor sizes, would a similar loss of performance be observed in TF as well?

Well, tensors in inference mode (in particular without gradients) should not be that much slower than NumPy. If you find a factor of 10 in a reproducible comparison, it is probably worth filing a bug.
I don’t use TF much, so I don’t know whether it creates intermediate tensors or can avoid them. If anything, I would look into fusion or write a custom CUDA kernel, but that is just me.
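
A quick way to check the autograd part (a sketch, with `sw_layer` and its inputs as placeholders) is to time the forward pass with gradient tracking disabled:

import torch

# If autograd bookkeeping were the main cost, disabling it should close
# much of the gap; if not, the overhead is elsewhere (e.g. tensor creation).
with torch.no_grad():              # or torch.inference_mode() on recent versions
    energy, forces = sw_layer(elements, coords, nl, padding)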

Best regards

Thomas

OK, I will try to reproduce it in simpler examples and probably post it on Stack Overflow before filing a bug report.
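
A stripped-down reproducer would probably only need to mimic the scalar-heavy structure of the inner loops, something like this (a sketch, not the actual SW terms):

import math
import time
import torch

def loop_torch(n):
    # accumulate a scalar with 0-d tensors, as the SW loops effectively do
    acc = torch.tensor(0.0)
    for i in range(n):
        r = torch.tensor(1.0 + 0.001 * i)
        acc = acc + torch.exp(-r) / r
    return acc

def loop_python(n):
    # same arithmetic with plain Python floats
    acc = 0.0
    for i in range(n):
        r = 1.0 + 0.001 * i
        acc = acc + math.exp(-r) / r
    return acc

for fn in (loop_torch, loop_python):
    start = time.perf_counter()
    fn(100_000)
    print(fn.__name__, f"{time.perf_counter() - start:.3f} s")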