If we combine a trainable parameter with a non-trainable parameter, does the original trainable param remain trainable?

Say I have two nets and I combine their parameters in some fancy way using only PyTorch operations. I store the result in a third net whose parameters are set to non-trainable. Then I proceed to pass data through this new net. The new net is just a placeholder for

placeholder_net.W = Op( not_trainable_net.W, trainable_net.W )

Then I pass data:

output = placeholder_net(input)

I am concerned that, since the params of the placeholder net are set to non-trainable, it won't actually train the variable that it should train. Will this happen? Or what is the result when you combine a trainable param with a non-trainable param (and then assign it where the param is non-trainable)?
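
Roughly what I mean, as a toy sketch with plain tensors standing in for the nets' weights (not my real code):

import torch

a = torch.randn(3, requires_grad=True)   # stands in for trainable_net.W
b = torch.randn(3, requires_grad=False)  # stands in for not_trainable_net.W
W = a + b                                # stands in for Op(...)
print(W.requires_grad)  # True: the combined tensor still tracks gradients w.r.t. a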


Interesting, it seems this can't even be done in 0.3.1?

(Pdb) net3.conv0.weight = net.conv0.weight + net2.conv0.weight
*** TypeError: cannot assign 'torch.autograd.variable.Variable' as parameter 'weight' (torch.nn.Parameter or None expected)

Functions using placeholder_net.W will back-propagate through to trainable_net.W. Note that you won’t be able to optimize placeholder_net.W because it will be an intermediate Variable, not a parameter.

To avoid the TypeError, do:

del net3.conv0.weight
net3.conv0.weight = net.conv0.weight + net2.conv0.weight

This will remove weight as an nn.Parameter and add it back as a (non-optimizable) Variable.
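
As a minimal sketch of the whole pattern (tiny Linear layers instead of your conv layers, names made up just for illustration):

import torch
from torch import nn

trainable_net = nn.Linear(3, 2)
frozen_net = nn.Linear(3, 2)
for p in frozen_net.parameters():
    p.requires_grad = False

placeholder_net = nn.Linear(3, 2)

# remove the registered nn.Parameter so a plain (intermediate) variable can be assigned
del placeholder_net.weight
placeholder_net.weight = trainable_net.weight + frozen_net.weight  # Op(...)

out = placeholder_net(torch.randn(4, 3))
out.sum().backward()

print(trainable_net.weight.grad is not None)  # True: the gradient flows back to the trainable net
print(frozen_net.weight.grad)                 # None: the frozen net is untouched

Only trainable_net.weight (and the placeholder's remaining real parameters, such as its bias) receive gradients; the replaced placeholder_net.weight itself is never optimizable.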

Hmmm, but will it be able to optimize trainable_net.W? That's what I'm mostly concerned about. I will use it for the output, then pass that to the criterion and THEN backprop to optimize only the things in trainable_net.W. Whether that works is my main concern…

Yes, that will work.

Hmmm, but I think I would need to pass the weights of the trainable net to the optimizer, right? Otherwise who knows what would happen… Does it matter if the params of the placeholder net are trainable or not? I assume not, since they would be substituted by a trainable param (i.e. the combination of one trainable and one non-trainable param makes the result trainable).
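
i.e., something like this (just a sketch of what I mean):

import torch.optim as optim

# hand only the trainable net's parameters to the optimizer
optimizer = optim.SGD(trainable_net.parameters(), lr=0.001, momentum=0.9)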

Do we really need the del command? This seems rather inefficient… or is it? What do you recommend to make this code efficient?

I guess I was worried that, if the original weights were non-trainable, assigning W_trainable + W_non_trainable would end up non-trainable… but the del actually deletes the old instance/object, so the newly assigned combination keeps its requires_grad set to True.

How do you implement the code you suggested if what I have is the string name of the layer, conv0?

perhaps:

setattr(self, f'bn2D_conv{i}', bn)

Use delattr instead of del then
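
For example, a small sketch; delattr/setattr with a string attribute name behave the same as del and plain assignment:

name = 'weight'
layer = net3.conv0  # or getattr(net3, 'conv0') if the layer name is a string too

delattr(layer, name)  # same as: del net3.conv0.weight
setattr(layer, name, net.conv0.weight + net2.conv0.weight)  # same as: net3.conv0.weight = ...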

Do I really need to delete it before using setattr?

Oh no, now I can't, because the attribute is 'conv0.weight' XD and it can't find it…

Do we actually need to delete the attribute, though? What goes wrong if I don't?

Isn't what we really need to make sure that the right variables are inserted in the computation graph, so that the backward computations are done correctly? Can't this be achieved by setting the placeholder net to eval, and likewise the non-trainable net?

The placeholder net is only needed so that the forward computation is done right, because it holds the combination.

OK, I filed a bug with reproducible code:

code:

import torch
from torch import nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from collections import OrderedDict

import copy

def dont_train(net):
    '''
    Set requires_grad to False for all parameters (freeze the net).
    '''
    for param in net.parameters():
        param.requires_grad = False
    return net

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    return trainloader,classes

def combine_nets(net_train,net_no_train,net_place_holder):
    '''
        Combine nets in such a way that only the train net is trainable.
    '''
    params_train = net_train.named_parameters()
    dict_params_place_holder = dict(net_place_holder.named_parameters())
    dict_params_no_train = dict(net_no_train.named_parameters())
    for name,param_train in params_train:
        if name in dict_params_place_holder:
            param_no_train = dict_params_no_train[name]
            delattr(net_place_holder, name)
            W_new = param_train + param_no_train # notice addition is just chosen for the sake of an example
            setattr(net_place_holder, name, W_new)
    return net_place_holder

def combining_nets_lead_to_error():
    '''
    The intention is to only train the net with trainable params.
    The placeholder net is a dummy net; it doesn't actually do anything except hold the combination of params, and it's
    the net that does the forward pass on the data.
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ''' create three musketeers '''
    net_train = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ])).to(device)
    net_no_train = dont_train(copy.deepcopy(net_train)).to(device)  # freeze this copy's params
    net_place_holder = copy.deepcopy(net_train).to(device)
    ''' prepare train, hyperparams '''
    trainloader,classes = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net_train.parameters(), lr=0.001, momentum=0.9)
    ''' train '''
    net_train.train()
    net_no_train.eval()
    net_place_holder.eval()
    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader, 0):
            optimizer.zero_grad() # zero the parameter gradients
            inputs, labels = inputs.to(device), labels.to(device)
            # combine nets
            net_place_holder = combine_nets(net_train,net_no_train,net_place_holder)
            #
            outputs = net_place_holder(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    ''' DONE '''
    print('Done \a')

if __name__ == '__main__':
    combining_nets_lead_to_error()

@ptrblck I know it's sort of invasive to just tag you here… but I was wondering if you knew how to answer this question? I haven't been able to get it to work for some time and was hoping some PyTorch expert had some insight into what's going wrong… :frowning: no pressure though… :slight_smile:

Nested attributes are sometimes hard to handle.
You could change the inner code of combine_nets to this as a workaround:

if name in dict_params_place_holder:
    param_no_train = dict_params_no_train[name]
    parent, child = name.split('.')
    delattr(getattr(net_place_holder, parent), child)
    W_new = param_train + param_no_train # notice addition is just chosen for the sake of an example
    setattr(getattr(net_place_holder, parent), child, W_new)

This should solve the error.
However, it seems your input has 3 channels, while your first conv layer just takes 1.
Probably you should change this as well.
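
For example, just that one change (keeping the rest of the net as it is):

net_train = nn.Sequential(OrderedDict([
      ('conv1', nn.Conv2d(3, 20, 5)),  # CIFAR10 images have 3 channels, so in_channels should be 3
      ('relu1', nn.ReLU()),
      ('conv2', nn.Conv2d(20, 64, 5)),
      ('relu2', nn.ReLU())
    ])).to(device)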
