Adding new parameters for training

Hi, recently I was trying to reimplement in PyTorch some paper where they implement new way of using kernels: https://www.sciencedirect.com/science/article/abs/pii/S1051200419301873 and before this project, I have only been doing with the kernels and layers supported by Pytorch. In this case I am supposed to add new parameters that will be trained along with the rest of the regular layers, but I feel that I am not doing something properly since parameters are not being updated.

class PointPlaneResnet(nn.Module):
    ''' PointPlaneNet-based encoder network with ResNet blocks. \n
    Args:
        c_dim (int): dimension of latent code c,  defined by config's model.c_dim = 512
        dim (int): input points dimension, in our case 
        hidden_dim (int): hidden dimension of the network
        k (int) : number of neighbours in grouping layer
        channels (int) : number of planes/channels
    '''

    def __init__(self, 
                c_dim=128, 
                dim=3, 
                hidden_dim=128, 
                k = 40,
                channels = 3):

        super().__init__()
        self.c_dim = c_dim
        self.k = k #grouping layer
        self.channels = channels

        hidden_dim = 512
        ##Parameters are Tensor subclasses, that have a very special property when used 
        # with Module s - when they’re 
        # assigned as Module attributes they are automatically added 
        # to the list of its parameters, and will appear e.g. in parameters() iterator. 
        self.plane_weights = torch.nn.Parameter(torch.randn(channels, 4).cuda())

        torch.nn.init.xavier_normal_(self.plane_weights)
        

        # self.fc_pos = nn.Linear(dim, 2*hidden_dim)
        self.plane_conv = convolution
        
        self.mlp = MLP(channels = channels)
        self.block_0 = ResnetBlockFC(2*hidden_dim, hidden_dim)
        self.block_1 = ResnetBlockFC(2*hidden_dim, hidden_dim)
        self.block_2 = ResnetBlockFC(2*hidden_dim, hidden_dim)
        self.block_3 = ResnetBlockFC(2*hidden_dim, hidden_dim)
        self.block_4 = ResnetBlockFC(2*hidden_dim, hidden_dim)
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.pool = maxpool

    def forward(self, p):
        batch_size, T, D = p.size()

        # output size: B x T X F
        # net = self.fc_pos(p)

        # print(f'Weight planes {self.plane_weights}')

        net_batch = []
        for i in range(batch_size):
            # print(f'Weight planes {self.plane_weights}')

            net_sample = self.plane_conv(p[i,:,:], self.k, self.plane_weights, self.channels)
            net_batch.append(net_sample)
            # print(f'net_sample is {net_sample}')

        # print(f'net_batch: {net_batch}')
        net = torch.stack(net_batch)
        net = self.mlp(net)


        net = self.block_0(net)
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.block_1(net)
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.block_2(net)
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.block_3(net)
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.block_4(net)

        # Recude to  B x F
        net = self.pool(net, dim=1)

        c = self.fc_c(self.actvn(net))

        return c

As you can see here, I added self.plane_weights as new parameters which will be used in plane_conv/convolution function.

But when I start training, I don’t see plane_weights being updated:

from im2mesh.encoder import point_plane_net
new_model = point_plane_net.PointPlaneResnet(k = 5).cuda()

#random input
input = torch.randn(10,10, 3).cuda()
labels = torch.randn(10,128).cuda() * 1000
input.size()

optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
optimizer.zero_grad()
criterion = nn.BCELoss()

for i in range(5):
    a = list(new_model.parameters())[0]
    print(a)
    outputs = new_model(input)
    outputs.shape

    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    b = list(new_model.parameters())[0]
    print(b)

    print(torch.equal(a.data, b.data))

Can someone see from here if I am doing something wrong, shouldn’t autograd be responsible for updating the weights or should I maybe switch to different training loss?

1 Like

I’m not sure how plane_conv is defined, but this dummy example works:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn(1, 1, 3, 3))

    def forward(self, x):
        x = F.conv2d(x, self.weights)
        return x

    
model = MyModel()
w0 = model.weights.clone()

optimizer = torch.optim.SGD(model.parameters(), lr=10.)

x = torch.randn(1, 1, 5, 5)
out = model(x)
out.mean().backward()

optimizer.step()

w1 = model.weights.clone()
print(w1 - w0)

(I’ve removed all undefined methods from your model and just used the defined parameter)

1 Like

Thing is plane_conv does only multiplying plane_weights by some other vector and returns something that might not have established backprop for it (I think). Do I need to add definition of a backprop for it. Here is the code:

def convolution(points, k, weights, channels):
    """creates features for every point
    Parameters
        ----------
        points   : torch.tensor
                   set of 3d points [N*3]
        k        : int
                   number of nearest neighbours
        weights  : torch.tensor
                   set of weights [channels*4]
        channels : int
                   number of channels

        Returns
        -------
        tensor [N*channels] - for every point "channels" features
    
    """
    number_points = points.shape[0]
    # array_features = torch.zeros([number_points, channels], dtype=torch.int32)
    array_features = torch.zeros([number_points, channels], dtype=torch.float).cuda()
    for i in range(number_points):
        dist = torch.norm(points - points[i], dim=1, p=None).cuda()
        #For PyTorch version 1.0.0  https://pytorch.org/docs/1.0.0/torch.html?highlight=topk#torch.topk
        id_neighbours = dist.topk(k+1, largest=False)[1]
        # id_neighbours = torch.Tensor(dist.topk(k+1, largest=False)).indices
        # array_features[i] = torch.tensor([final_kernel(
        #     points, i, weights, channel, k, id_neighbours) for channel in np.arange(0, channels, 1)])
        array_features[i] = torch.tensor([final_kernel(
            points, i, weights, channel, k, id_neighbours) for channel in np.arange(0, channels, 1)],dtype=torch.float)

    return array_features

This might break the computation graph, since you are recreating a new tensor array_features.
Could you instead store the outputs of final_kernel in a list and use torch.stack(array_features) after the loop to create a tensor?

I tried doing this:

def convolution(points, k, weights, channels):
    """creates features for every point
    Parameters
        ----------
        points   : torch.tensor
                   set of 3d points [N*3]
        k        : int
                   number of nearest neighbours
        weights  : torch.tensor
                   set of weights [channels*4]
        channels : int
                   number of channels

        Returns
        -------
        tensor [N*channels] - for every point "channels" features
    
    """
    number_points = points.shape[0]
    # array_features = torch.zeros([number_points, channels], dtype=torch.int32)
    # array_features = torch.zeros([number_points, channels], dtype=torch.float).cuda()
    array_features = []
    for i in range(number_points):
        dist = torch.norm(points - points[i], dim=1, p=None).cuda()
        #For PyTorch version 1.0.0  https://pytorch.org/docs/1.0.0/torch.html?highlight=topk#torch.topk
        id_neighbours = dist.topk(k+1, largest=False)[1]
        array_feature = torch.tensor([final_kernel(
            points, i, weights, channel, k, id_neighbours) for channel in np.arange(0, channels, 1)],dtype=torch.float).cuda()
        # array_feature = [final_kernel(
        #     points, i, weights, channel, k, id_neighbours) for channel in np.arange(0, channels, 1)]
        array_features.append(array_feature)
    array_features = torch.stack(array_features).cuda()

    return array_features

And you can see I used here stacking after loop and yet it still fails in training using
code above (down below once again):

from im2mesh.encoder import point_plane_net
new_model = point_plane_net.PointPlaneResnet(k = 5).cuda()

#random input
input = torch.randn(10,10, 3).cuda()
labels = torch.randn(10,128).cuda() * 1000
input.size()

optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
optimizer.zero_grad()
criterion = nn.BCELoss()

for i in range(5):
    a = list(new_model.parameters())[0]
    print(a)
    outputs = new_model(input)
    outputs.shape

    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    b = list(new_model.parameters())[0]
    print(b)

    print(torch.equal(a.data, b.data))

@ptrblck Any feedback on what else I might do?

Hi, what is the final_kernel defined?
I wonder if there is operations in it that breaks the computation graph (as @ptrblck said)

Hi Terence (@b02202050), here is the code

 def kernel(weights, point):
    """calculates H(W, X)
    """
    new_point = torch.cat((torch.tensor([1.]).cuda(), point), dim=0)
    answer = weights.dot(new_point)/(weights[1:].norm(p=2))
    return answer


def final_kernel(points, i, weights, channel, k, id_neighbours):
    """calculates 1/(1 + exp(H(W, X)))
    """
#     print(weights[channel])
    pc_first = 1/k*sum([kernel(weights[channel], points[id_neighbour] - points[i])
                        for id_neighbour in id_neighbours if id_neighbour != i])
    # pc_final = 1/(1. + np.power(2.73, pc_first.numpy()))
    #Previous doesn't work for cuda 
    pc_final = 1/(1. + np.power(2.73, pc_first.item()))
    return pc_final

Good catch @b02202050!

final_kernel will break the computation graph, as you are using numpy functions, which cannot be tracked by Autograd, and you are also detaching pc_first via pc_first.item(), which creates a Python literal.

Try to replace the numpy methods with their PyTorch equivalents and don’t call .item() on your tensors.

2 Likes

I had the same problem and I think it might also be productive for others. Although I set “requires_grad” equal to True, the model didn’t change the parameter value. I used all functions from Pytorch.
I used torch.nn.Parameter and the model consider it this time among parameters that should be passed through the optimizer.
So, I think using torch.nn.Parameter is necessary.
Thanks for your informative answer :blush: