Backprop through a fully connected network and a sparse matrix


(Rpfeynman) #1

Hello everyone! I'm having some trouble understanding how to perform backpropagation when sparse matrices are involved.
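For reference, in the working setup below the loss is the mean squared error (nn.MSELoss with its default mean reduction) between A y and the data x. Assuming x has the same shape as A y, i.e. n × 1 with n = 649·128 rows, the gradient that backprop has to deliver to y is a product with the transposed sparse matrix:

$$
\mathcal{L}(y) = \frac{1}{n}\,\lVert A y - x \rVert_2^2,
\qquad
\nabla_y \mathcal{L} = \frac{2}{n}\,A^{\top}\,(A y - x).
$$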

I have the following code, which works well and lets me optimise the variable y:

...
A = torch.sparse.FloatTensor(X, a, torch.Size([649*128,128**2])).cuda()

y =  torch.randn(torch.Size([128**2,1]), requires_grad=True, device=device)
torch.nn.init.normal_(y,mean=0,std=1)

optimizer = torch.optim.SGD([y], lr=0.001)
costfun = nn.MSELoss()

for t in range(1000000):
   
    # Forward pass: compute the predicted x by multiplying the sparse matrix A with y
    x_pred = torch.matmul(A,y)
    
    # Get loss
    loss = costfun(x_pred, x_data[0])

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    print(y.grad)
    optimizer.step()
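A minimal, self-contained sketch of the same working pattern, assuming small made-up sizes on CPU and a synthetic target (the real X, a and x_data are not shown in this post). It builds A with torch.sparse_coo_tensor, the current replacement for the legacy torch.sparse.FloatTensor constructor, uses torch.sparse.mm for the sparse-dense product, and calls optimizer.zero_grad() each step so gradients do not accumulate across iterations:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes standing in for 649*128 rows and 128**2 columns.
n_rows, n_cols, nnz = 60, 16, 120

# Random COO indices (2 x nnz) and values; duplicate entries are summed by coalesce().
indices = torch.stack([
    torch.randint(0, n_rows, (nnz,)),
    torch.randint(0, n_cols, (nnz,)),
])
values = torch.randn(nnz)
A = torch.sparse_coo_tensor(indices, values, (n_rows, n_cols)).coalesce()

# Synthetic target generated from a random ground-truth y (illustration only).
y_true = torch.randn(n_cols, 1)
x_target = torch.sparse.mm(A, y_true)

# y is a leaf tensor, so autograd stores its gradient in y.grad.
y = torch.randn(n_cols, 1, requires_grad=True)
optimizer = torch.optim.SGD([y], lr=0.1)
costfun = nn.MSELoss()

losses = []
for t in range(200):
    optimizer.zero_grad()               # clear gradients from the previous step
    x_pred = torch.sparse.mm(A, y)      # sparse (n_rows x n_cols) @ dense (n_cols x 1)
    loss = costfun(x_pred, x_target)
    loss.backward()                     # fills y.grad
    optimizer.step()
    losses.append(loss.item())

print(y.grad.shape)  # torch.Size([16, 1])
print(losses[0], losses[-1])
```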

However, if I put a fully connected network before the matrix multiplication, as follows:

...
A = torch.sparse.FloatTensor(X, a, torch.Size([649*128,128**2])).cuda()

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(649*128, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128*128)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net1 = Net().cuda()

optimizer = torch.optim.SGD(net1.parameters(), lr=0.001)
costfun = nn.MSELoss()

for t in range(1000000):
   
    # Forward pass: Compute predicted y by passing x to the model
    y = net1(torch.reshape(x_data[0], (1,-1)))

    x_pred = torch.matmul(A,y)
    
    # Get loss
    loss = costfun(x_pred, x_data[0])

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    print(y.grad)
    optimizer.step()

I get the following error:

RuntimeError: Expected object of backend CUDA but got backend SparseCUDA for argument #2 'mat2'

Any idea what I may be doing wrong?
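A hedged sketch of one possible fix, not a verified diagnosis of the message above: net1 returns a row vector of shape (1, 128**2), whereas A, of size (649*128, 128**2), needs a column vector of shape (128**2, 1) on its right-hand side. The sketch reshapes the network output to a column and uses torch.sparse.mm for the sparse-dense product. Because y is produced by the network it is not a leaf tensor, so its .grad stays None unless retain_grad() is called; the sketch does that on the reshaped tensor. The sizes, the hidden width and the synthetic x_data are assumptions for illustration, on CPU:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small sizes: A maps an n_cols-vector to n_rows measurements.
n_rows, n_cols, hidden, nnz = 60, 16, 32, 120

indices = torch.stack([
    torch.randint(0, n_rows, (nnz,)),
    torch.randint(0, n_cols, (nnz,)),
])
A = torch.sparse_coo_tensor(indices, torch.randn(nnz), (n_rows, n_cols)).coalesce()

# Synthetic measurement vector (hypothetical), shaped (n_rows, 1) like A @ y.
x_data = torch.sparse.mm(A, torch.randn(n_cols, 1))

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(n_rows, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, n_cols)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net1 = Net()
optimizer = torch.optim.SGD(net1.parameters(), lr=0.01)
costfun = nn.MSELoss()

losses = []
for t in range(200):
    optimizer.zero_grad()
    y = net1(x_data.reshape(1, -1))     # network output is a row vector, shape (1, n_cols)
    y_col = y.reshape(-1, 1)            # column vector, shape (n_cols, 1), as A expects
    y_col.retain_grad()                 # keep .grad on this non-leaf tensor for inspection
    x_pred = torch.sparse.mm(A, y_col)  # sparse @ dense -> dense (n_rows, 1)
    loss = costfun(x_pred, x_data)
    loss.backward()                     # gradients flow through A into the network weights
    optimizer.step()
    losses.append(loss.item())

print(y_col.grad.shape)  # torch.Size([16, 1])
```

In the original setting, n_rows, n_cols and hidden would correspond to 649*128, 128**2 and 128.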