Difficulty speeding up code: training an MLP using a complicated many-to-one nonlinear function

A quick summary:

My goal is to find out whether a specific, complicated nonlinear function can replace the individual neurons in a neural network. Ideally, I'd like to show that such a network can be trained on MNIST images of digits. I've made an attempt in PyTorch, but it is far too slow, mostly because I haven't figured out how to process batches and neurons in parallel. I'm looking for ideas or approaches that could speed it up dramatically.

What I’m trying to do:

A typical neuron in a neural network computes a dot product and then applies a nonlinear function to the result, f(x dot w).
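For this standard f(x dot w) neuron, PyTorch code usually exploits the independence of neurons and samples by writing a whole layer as one matrix product. The minimal sketch below uses random data (the names `X`, `W`, `loop_out` and `vec_out` are illustrative, and ReLU stands in for f) and checks that a per-sample, per-neuron loop gives the same result as a single `X @ W`:

```python
import torch

torch.manual_seed(0)
batch, n_in, n_hidden = 4, 100, 100
X = torch.randn(batch, n_in, dtype=torch.float64)     # one row per input sample
W = torch.randn(n_in, n_hidden, dtype=torch.float64)  # one column per neuron (same layout as weights1)

# Loop version: one sample and one neuron at a time, f(x . w) = relu(x . w)
loop_out = torch.empty(batch, n_hidden, dtype=torch.float64)
for b in range(batch):
    for j in range(n_hidden):
        loop_out[b, j] = torch.relu(X[b] @ W[:, j])

# Vectorized version: every sample and every neuron in one matrix product
vec_out = torch.relu(X @ W)  # shape (batch, n_hidden)
```

The difficulty is that f(x, w) is not a dot product, so this shortcut does not apply directly.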

Instead of f(x dot w), I am considering a more general many-to-one nonlinear function of x and w, i.e. f(x, w). The function f(x, w) takes a 1D array x and a 1D array w and returns a single output. I have NumPy code that performs this calculation; it models a real physical system and requires a recursive series of integrals. In a previous question, I learned that I can convert my NumPy code to PyTorch functions and that PyTorch should then compute the back-propagation gradients for me automatically.
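A quick way to confirm that autograd really differentiates through this kind of computation is `torch.autograd.gradcheck`, which compares autograd gradients with finite differences. The sketch below uses a hypothetical toy function `f_toy` (not the physical model) that only mimics the structure of f(x, w): two 1D tensors in, an integrand built on a fixed grid, a `torch.trapezoid` integral, and a single scalar out.

```python
import torch

def f_toy(x, w, numpoints=100):
    """Hypothetical stand-in for f(x, w): 1D x and 1D w of equal length -> scalar.
    It only mimics the structure of the real model (an integral on a fixed grid)."""
    z = torch.linspace(0.0, 1.0, numpoints, dtype=x.dtype)
    # integrand[i, k] = x[i] * cos(w[i] * z[k]), built by broadcasting
    integrand = x[:, None] * torch.cos(w[:, None] * z[None, :])
    per_input = torch.trapezoid(integrand, x=z, dim=-1)  # integrate over z -> shape (n,)
    return per_input.sum()                                # many-to-one: a single scalar

torch.manual_seed(0)
x = torch.rand(5, dtype=torch.float64)
w = torch.rand(5, dtype=torch.float64, requires_grad=True)

y = f_toy(x, w)
y.backward()  # autograd differentiates through cos and trapezoid

# Compare the autograd gradient with finite differences (gradcheck needs float64)
ok = torch.autograd.gradcheck(lambda w_: f_toy(x, w_), (w,))
```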

I now have PyTorch code describing the nonlinear function f(x, w). To demonstrate that it can learn to recognize digits, I have reduced the MNIST images to 10x10 pixels and set up an MLP-inspired network with 100 inputs, a hidden size of 100, and 10 outputs.

Explaining this MLP-inspired network in greater detail:

The first layer consists of 100 “neurons”, where the typical neuron is replaced with my nonlinear function f(x, w). Each of the 100 “neurons” receives the same input x but has its own set of weights w. The outputs of these 100 neurons are passed on to the next layer, which consists of 10 neurons (in the code below these also use f(x, w)); the output of each is used to identify one of the 10 digits.

Here is a snippet of the code for the forward pass of the network:

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.weights1 = nn.Parameter(torch.randn(input_size, hidden_size))  # weights of size (input_size, hidden_size)
        self.weights2 = nn.Parameter(torch.randn(hidden_size, num_classes)) # weights of size (hidden_size, num_classes)

    def forward(self, x):
        hidden_outputs = []
        for neuron_weights in self.weights1.T:  # loop over each neuron's weights in the first layer
            output_value = input_output_nonlinearity_torch(x.squeeze(), neuron_weights, C_total = C_total, readoutStrength = readoutStrength)
            output_value = F.relu(output_value)  # apply a relu activation function
            hidden_outputs.append(output_value)

        final_outputs = []
        for neuron_weights in self.weights2.T:  # loop over each neuron's weights in the second layer
            output_value = input_output_nonlinearity_torch(hidden_outputs, neuron_weights, C_total = C_total, readoutStrength = readoutStrength)
            final_outputs.append(output_value)

        final_outputs = [output.unsqueeze(0) for output in final_outputs]
        final_output = torch.stack(final_outputs, dim=0)
        final_output = final_output.t()
        return final_output

The problem is that each iteration of training this network takes about 90 minutes, even though I am using only 1% of the MNIST dataset. So I really need to figure out how to make this faster.

input_output_nonlinearity_torch() is my nonlinear function f(x, w). As the code shows, I compute the output of each “neuron” in the network individually by looping over the weights. In principle, though, the neurons are fully independent and could be computed in parallel.

So one approach would be to vectorize my code further. However, I have not found a simple way to pass a matrix X and a matrix W to f(x, w) so that I get the outputs for all neurons and all input samples at once (the full code is at the end). This seems fairly challenging to implement, although I'm sure it is possible.
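One possible way to pass a matrix X and a matrix W is to write f so that every operation acts on the trailing axes and treats any leading axes as batch axes; PyTorch broadcasting then pairs every sample with every neuron. The sketch below assumes a hypothetical toy function `f_toy_batched` in place of the real f (the recursion over inputs in the real code would need the same treatment at each step). It checks the broadcast result against the explicit double loop:

```python
import torch

def f_toy_batched(x, w, numpoints=100):
    """Hypothetical stand-in for f(x, w) that accepts any leading dimensions.
    x and w have shape (..., n) and broadcast against each other; output has shape (...)."""
    z = torch.linspace(0.0, 1.0, numpoints, dtype=x.dtype)
    # Append a grid axis with [..., None] instead of 1D-only helpers such as torch.ger:
    # the integrand has shape (..., n, numpoints)
    integrand = x[..., :, None] * torch.cos(w[..., :, None] * z)
    per_input = torch.trapezoid(integrand, x=z, dim=-1)   # integrate over z -> (..., n)
    return per_input.sum(dim=-1)                           # many-to-one -> (...)

torch.manual_seed(0)
batch, n_in, n_hidden = 4, 100, 16
X = torch.rand(batch, n_in, dtype=torch.float64)        # one row per sample
W1 = torch.randn(n_in, n_hidden, dtype=torch.float64)   # one column per neuron

# X[:, None, :] is (batch, 1, n_in) and W1.T[None] is (1, n_hidden, n_in); they
# broadcast to (batch, n_hidden, n_in), i.e. every (sample, neuron) pair at once
H = f_toy_batched(X[:, None, :], W1.T[None, :, :])      # (batch, n_hidden)

# Reference: the double loop that the broadcast replaces
H_loop = torch.stack([
    torch.stack([f_toy_batched(X[b], W1[:, j]) for j in range(n_hidden)])
    for b in range(batch)
])
```

The broadcast integrand has shape (batch, neurons, n_in, numpoints), so memory grows with all four sizes; processing the batch in chunks is one way to keep it bounded.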

Alternatively, is there a way to tell PyTorch that these calculations are completely independent, so that it can parallelize them under the hood? Is that possible, or do I have to brute-force my way to a fully vectorized solution?
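PyTorch 2.x does offer a way to declare that a function is applied independently along an axis: `torch.func.vmap`. It runs the function once with batched tensors instead of once per element, so a Python loop inside f executes n_in times per call rather than n_in × batch × neurons times. The sketch below uses a hypothetical toy recursion `f_toy` written for single vectors (like input_output_nonlinearity_torch) and nests two `vmap` calls, one over neurons and one over samples. Whether the real function runs under `vmap` unchanged is an assumption to test: in-place writes such as `BoutMatrix[i, :] = ...` into a tensor created inside the function are a common cause of vmap errors, which is why the toy updates its state out of place.

```python
import torch
from torch.func import vmap  # available in PyTorch 2.x

def f_toy(x, w):
    """Hypothetical many-to-one function: 1D x and 1D w of equal length -> scalar.
    Written for single vectors only, like input_output_nonlinearity_torch."""
    state = torch.zeros((), dtype=x.dtype)
    for i in range(x.shape[0]):                  # sequential recursion over inputs
        state = torch.tanh(state + w[i] * x[i])  # out-of-place update (vmap-friendly)
    return state

torch.manual_seed(0)
batch, n_in, n_hidden = 4, 10, 6
X = torch.rand(batch, n_in, dtype=torch.float64)
W1 = torch.randn(n_in, n_hidden, dtype=torch.float64)  # column j = weights of neuron j

# Inner vmap: one sample against every neuron (maps over the columns of W1)
one_sample_all_neurons = vmap(f_toy, in_dims=(None, 1))
# Outer vmap: repeat for every sample in the batch (maps over the rows of X)
layer = vmap(one_sample_all_neurons, in_dims=(0, None))

H = layer(X, W1)  # shape (batch, n_hidden)

# Reference: one call per (sample, neuron) pair
H_loop = torch.stack([
    torch.stack([f_toy(X[b], W1[:, j]) for j in range(n_hidden)])
    for b in range(batch)
])
```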

Here is the code for the whole thing. My apologies for the length, but I want to give the full code so that any speed inefficiencies can be properly identified.

Here is the code for the training of the network:

from neural_network_pytorch_dot_product import *  # expected to provide input_output_nonlinearity_torch (the f(x, w) code below)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.utils.data as data
import numpy
import torch.nn.functional as F

numpoints_for_greens_integrals = 100
C_total = .1
readoutStrength = 1/C_total

def reduce_dataset(dataloader, fraction):
    num_samples = int(len(dataloader.dataset) * fraction)
    indices = torch.randperm(len(dataloader.dataset))[:num_samples]
    new_dataset = data.Subset(dataloader.dataset, indices)
    new_dataloader = data.DataLoader(new_dataset, batch_size=dataloader.batch_size, shuffle=True)
    return new_dataloader


class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.weights1 = nn.Parameter(torch.randn(input_size, hidden_size))  # weights of size (input_size, hidden_size)
        self.weights2 = nn.Parameter(torch.randn(hidden_size, num_classes)) # weights of size (hidden_size, num_classes)

    def forward(self, x):
        hidden_outputs = []
        for neuron_weights in self.weights1.T:  # loop over each neuron's weights in the first layer
            output_value = input_output_nonlinearity_torch(x.squeeze(), neuron_weights, C_total = C_total, readoutStrength = readoutStrength)
            output_value = F.relu(output_value)  # apply a relu activation function
            hidden_outputs.append(output_value)

        final_outputs = []
        for neuron_weights in self.weights2.T:  # loop over each neuron's weights in the second layer
            output_value = input_output_nonlinearity_torch(hidden_outputs, neuron_weights, C_total = C_total, readoutStrength = readoutStrength)
            final_outputs.append(output_value)

        final_outputs = [output.unsqueeze(0) for output in final_outputs]
        final_output = torch.stack(final_outputs, dim=0)
        final_output = final_output.t()
        return final_output

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# hyperparameters

input_size = 100
hidden_size = 100
num_classes = 10
num_epochs = 10
batch_size = 1
learning_rate = 0.001

pixelX = 10

# MNIST dataset (28x28 images!)
# train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
# test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Compressed images to (10x10 images!)
# Define a new transform to resize the images
resize_transform = transforms.Resize((10, 10))

# MNIST dataset with resize transform
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    resize_transform
]), download=True)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    resize_transform
]))


train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Update train loader and test loader
train_loader = reduce_dataset(train_loader, 0.01) # reduce to 1% of original size
test_loader = reduce_dataset(test_loader, 0.01) # reduce to 1% of original size


# instantiate the MLP
model = MLP(input_size, hidden_size, num_classes).to(device)

# loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# train the model
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, pixelX*pixelX).to(device)
        labels = labels.to(device)
        
        # forward pass
        print('beginning forward pass')
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

# test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, pixelX*pixelX).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on test images: {} %'.format(100 * correct / total))
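With `batch_size = 1` and one call of the nonlinearity per neuron, every optimizer step processes a single image and makes 110 separate calls to input_output_nonlinearity_torch (100 hidden plus 10 output neurons). Once a layer can be evaluated for a whole batch, the training loop itself changes little: the `DataLoader` batch size can be raised, and `forward` can return a (batch, num_classes) tensor directly, which is the shape `nn.CrossEntropyLoss` expects, without stacking or transposing. The sketch below is hypothetical: `f_layer` is a toy broadcast layer standing in for a vectorized f(x, w), `BatchedMLP` is an illustrative name, and random tensors replace the 10x10 MNIST images:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def f_layer(X, W):
    """Hypothetical vectorized stand-in for a layer of f(x, w) neurons.
    X: (batch, n_in), W: (n_in, n_out) -> (batch, n_out)."""
    pairs = X[:, None, :] * W.T[None, :, :]   # (batch, n_out, n_in)
    return torch.tanh(pairs).sum(dim=-1)

class BatchedMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.weights1 = nn.Parameter(torch.randn(input_size, hidden_size))
        self.weights2 = nn.Parameter(torch.randn(hidden_size, num_classes))

    def forward(self, x):                      # x: (batch, input_size)
        hidden = torch.relu(f_layer(x, self.weights1))
        return f_layer(hidden, self.weights2)  # (batch, num_classes)

torch.manual_seed(0)
# Random stand-in data shaped like flattened 10x10 MNIST images
images = torch.rand(256, 100)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

model = BatchedMLP(100, 100, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for batch_images, batch_labels in loader:     # 4 optimizer steps instead of 256
    outputs = model(batch_images)
    loss = criterion(outputs, batch_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```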

And here is the code describing the nonlinear function f(x, w):

import torch
from torch import nn, optim
import numpy as np
from scipy import special

def readinKernel_torch(wdummy, z, Ec, Ep, kval=1, ic = 10**-9*torch.sqrt(3.14/(8*torch.log(torch.tensor([2.]))))*2*3.14*3*10**6, od = 10000, gamma = 2*3.14*18*10**9/(2*3.14*3*10**6), extra = 1, Np = (8*torch.log(torch.tensor([2.]))/(torch.pow(torch.tensor([10.])**-9*(2*3.14*3*10**6),2)*torch.tensor(torch.pi))).pow(0.25)):
    return Ec * kval * torch.special.bessel_j0(2* Ec * kval * torch.sqrt(torch.ger(z, (1 - wdummy))))*torch.exp(-1*1j*Ec**2*extra*repmat_torch(wdummy, len(z))*kval**2*gamma/od)* torch.sqrt(ic)*Ep*Np

def readoutKernel_torch(zdummy, z, B_in, Ec, kval=1):
    return  (Ec * kval *
            steep_sigmoid_torch(torch.sub(z.repeat(zdummy.size(0), 1).T, zdummy), 50) *
            1 / torch.sqrt(torch.clamp(torch.sub(z.repeat(zdummy.size(0), 1).T, zdummy), min=1e-10)) *
            torch.special.bessel_j1(2 * Ec * kval * torch.sqrt(torch.clamp(torch.sub(z.repeat(zdummy.size(0), 1).T, zdummy), min=1e-10))) *
            repmat_torch(B_in, len(z)))

def final_readoutKernel_torch(zdummy, w, Ec, B_in, kval=1):
    #  This is the same Kernel as the readin kernel, but with K(z, w) switched to K(w, z).    
    out = Ec * kval * torch.special.bessel_j0(2* Ec * kval * torch.sqrt(torch.ger(w, (1 - zdummy))))*repmat_torch(B_in, len(w))
    return out

def repmat_torch(arr, num_reps):
    return arr.view(1, -1).repeat(num_reps, 1)

def steep_sigmoid_torch(x, k=50):
    return 1.0 / (1.0 + torch.exp(-k*x))

def complex_trap_torch(z, xvals, axisdim):
    real_values = torch.real(z)
    imaginary_values = torch.imag(z)
    real_integral = torch.trapz(real_values, x=xvals, dim=axisdim)
    imaginary_integral = torch.trapz(imaginary_values, x=xvals, dim=axisdim)
    complexout = real_integral + 1j * imaginary_integral
    return complexout

def spinwave_recursive_calculation_torch(B_in, z_values, w_values, Ec, Ep, c_per_mode = 1):

    readin_values = readinKernel_torch(w_values, z_values, Ec, Ep, kval = c_per_mode)
    readout_values = readoutKernel_torch(z_values, z_values, B_in, Ec, kval = c_per_mode)

    readin_integrals = complex_trap_torch(readin_values, xvals=w_values, axisdim=1)
    readout_integrals = complex_trap_torch(readout_values, xvals=z_values, axisdim=1) 

    spinwave = readin_integrals - readout_integrals + B_in
    return spinwave



def input_output_nonlinearity_torch(x, w, numpoints = 100, C_total = 1, readoutStrength = 1):
    z_values =  torch.linspace(1e-10, 1-1e-10, numpoints)
    w_values =  torch.linspace(1e-10, 1-1e-10, numpoints)

    Bin = torch.zeros(len(z_values), dtype=torch.complex128)
    BoutMatrix = repmat_torch(Bin, len(w))

    c_per_mode = C_total/len(w)

    for i in range(len(w)):
        E_c_val = w[i]
        E_p_val = x[i]
        # print('E_p_val', E_p_val)
        # print('x', x)
        BoutMatrix[i, :] = spinwave_recursive_calculation_torch(Bin, z_values, w_values, E_c_val, E_p_val, c_per_mode)
        Bin = BoutMatrix[i, :]
    Bout = BoutMatrix[-1, :]

    output_Efield_w_z = final_readoutKernel_torch(z_values, w_values, readoutStrength, Bout, kval=1)
    output_Efield_w = torch.trapz(torch.real(output_Efield_w_z), x = z_values, dim =1)
    output_Efield = torch.trapz(torch.real(output_Efield_w), x = w_values, dim = 0)

    return output_Efield
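`readinKernel_torch` and `final_readoutKernel_torch` build their kernels with `torch.ger`, which (like `torch.outer`, of which it is a deprecated alias) only accepts 1D tensors, and `readoutKernel_torch` uses `repmat_torch`, which copies `B_in` into every row. Both get in the way of adding a neuron or batch axis. A possible replacement, sketched below with a hypothetical helper `outer_bc`, is indexing with `None` plus broadcasting, which gives the same values for 1D inputs and also accepts leading axes:

```python
import torch

P = 100
z = torch.linspace(1e-10, 1 - 1e-10, P, dtype=torch.float64)
w = torch.linspace(1e-10, 1 - 1e-10, P, dtype=torch.float64)

# torch.ger is a deprecated alias of torch.outer; both accept 1D inputs only
outer_1d = torch.outer(z, 1 - w)                      # (P, P)

def outer_bc(a, b):
    """Hypothetical helper: outer product over the last axis, any leading axes."""
    return a[..., :, None] * b[..., None, :]          # (..., P, P)

Z = torch.stack([z, 2 * z, 3 * z])                    # (3, P): e.g. one row per neuron
batched = outer_bc(Z, 1 - w)                          # (3, P, P) in a single call

def repmat_torch(arr, num_reps):                      # copied from the code above
    return arr.view(1, -1).repeat(num_reps, 1)

# repmat_torch materializes P copies of B_in; B_in[..., None, :] gives the
# same values through broadcasting, without copying, and keeps batch axes
torch.manual_seed(0)
B_in = torch.randn(3, P, dtype=torch.float64)         # batched state
K = torch.rand(P, P, dtype=torch.float64)             # some kernel on the grid
weighted = K * B_in[..., None, :]                     # (3, P, P): K[j, k] * B_in[b, k]
```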

(Again, apologies for the long code, but the key difficulty of my problem is precisely that such a complicated function is hard to vectorize. If I wrote a simpler example, it would likely be clearer how to vectorize it, and an answer to that simplified question wouldn't help me.)
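The loop over `i` in `input_output_nonlinearity_torch` is a genuine recursion (each step needs the previous `Bin`), so it cannot be removed, but the same loop is shared by every sample and every neuron. One way to vectorize is therefore to keep that loop and give the state extra leading axes of shape (batch, neurons). The sketch below is hypothetical: `toy_step` only imitates the structure of `spinwave_recursive_calculation_torch` with simplified real-valued kernels (cosine and sigmoid instead of the Bessel and complex-exponential terms), and `f_batched` mirrors the outer loop and the final read-out. The check at the end compares the batched result with one call per (sample, neuron) pair:

```python
import torch

def toy_step(B_in, z, Ec, Ep):
    """Hypothetical stand-in for one call of spinwave_recursive_calculation_torch.
    B_in: (..., P) state on the grid z.  Ec, Ep: (...) one weight and one input
    value per (sample, neuron).  Real-valued toy kernels, not the physical ones."""
    Ec_ = Ec[..., None, None]                                            # (..., 1, 1)
    Ep_ = Ep[..., None, None]                                            # (..., 1, 1)
    # "read-in" term: kernel on the (z, w) grid, integrated over w (w reuses grid z)
    readin = Ec_ * Ep_ * torch.cos(Ec_ * z[:, None] * (1 - z[None, :]))  # (..., P, P)
    readin_int = torch.trapezoid(readin, x=z, dim=-1)                    # (..., P)
    # "read-out" term: steep-sigmoid kernel applied to the previous state B_in(z')
    step = torch.sigmoid(50 * (z[:, None] - z[None, :]))                 # (P, P)
    readout = Ec_ * step * B_in[..., None, :]                            # (..., P, P)
    readout_int = torch.trapezoid(readout, x=z, dim=-1)                  # (..., P)
    return readin_int - readout_int + B_in

def f_batched(X, W, numpoints=50):
    """Hypothetical vectorized f: X (batch, n_in), W (n_in, n_out) -> (batch, n_out)."""
    batch, n_in = X.shape
    n_out = W.shape[1]
    z = torch.linspace(0.0, 1.0, numpoints, dtype=X.dtype)
    B = torch.zeros(batch, n_out, numpoints, dtype=X.dtype)  # one state per (sample, neuron)
    for i in range(n_in):                         # the recursion stays sequential ...
        Ec = W[i].expand(batch, n_out)            # weight i of every neuron
        Ep = X[:, i, None].expand(batch, n_out)   # input i of every sample
        B = toy_step(B, z, Ec, Ep)                # ... but each step covers all pairs
    return torch.trapezoid(B, x=z, dim=-1)        # toy final read-out: (batch, n_out)

torch.manual_seed(0)
X = torch.rand(4, 10, dtype=torch.float64)
W = torch.randn(10, 8, dtype=torch.float64)
H = f_batched(X, W)                               # (4, 8) from one call
H_loop = torch.stack([
    torch.stack([f_batched(X[b:b + 1], W[:, j:j + 1])[0, 0] for j in range(8)])
    for b in range(4)
])
```

One caveat on memory: the (batch, neurons, P, P) kernels grow quickly. With a batch of 64, 100 neurons and P = 100 grid points, each such intermediate has 64 × 100 × 100 × 100 = 6.4 × 10^7 elements, roughly 1 GB in complex128, so the batch may need to be processed in chunks.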