ValueError: optimizer got an empty parameter list

Hi PyTorch Friends,

I’m trying to build customized layers by following the Extending PyTorch tutorial and use them to replace the nn.Conv2d and nn.Linear layers in the official MNIST example (main.py, lines 55-59).

However, after replacing them with my own customized layers, the testing step (forward pass) works without error, but training the new model raises “ValueError: optimizer got an empty parameter list”. Also, new_model.parameters() does not return any items.
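
For reference, a minimal check like the following (assuming the usual import torch.optim as optim from the MNIST example) reproduces the symptom: an empty parameter list, followed by the ValueError when the optimizer is built.

new_model = Decomp_Net()
print(list(new_model.parameters()))  # prints: []
optimizer = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.5)  # ValueError raised here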

The following is my modified Net (nn.Module)

class Decomp_Net(nn.Module):
    def __init__(self, path_pretrained_model="mymodel.pth"):
        super(Decomp_Net, self).__init__()
        # Load the pretrained model
        # Load the saved weights
        self.path_pretrained_model = path_pretrained_model
        try:
            params = torch.load(self.path_pretrained_model)
            print("Loaded pretrained model.")
        except IOError:
            raise RuntimeError("No pretrained model saved.")

        # Conv Layer 1
        self.W_conv1 = params.items()[0]
        self.B_conv1 = params.items()[1][1]
        self.W_conv1 = self.W_conv1[1].view(10, 25)
        self.W_conv1 = self.W_conv1.t()
        self.D_conv1, self.X_a_conv1 = create_dic_fuc.create_dic(A=self.W_conv1, M=25, N=10, Lmax=9, Epsilon=0.7, mode=1)

        # Conv Layer 2
        self.W_conv2 = params.items()[2]
        self.B_conv2 = params.items()[3][1]
        self.W_conv2 = self.W_conv2[1].view(200, 25)
        self.W_conv2 = self.W_conv2.t()
        self.D_conv2, self.X_a_conv2 = create_dic_fuc.create_dic(A=self.W_conv2, M=25, N=200, Lmax=199, Epsilon=0.7, mode=1)

        # Layer FC1
        self.W_fc1 = params.items()[4]
        self.B_fc1 = params.items()[5][1]
        self.D_fc1, self.X_a_fc1 = create_dic_fuc.create_dic(A=self.W_fc1[1], M=50, N=320, Lmax=319, Epsilon=0.8, mode=1)

        # Layer FC2
        self.W_fc2 = params.items()[6] # Fetching the last fully connected layer of the original model
        self.B_fc2 = params.items()[7][1] 
        self.D_fc2, self.X_a_fc2 = create_dic_fuc.create_dic(A=self.W_fc2[1], M=10, N=50, Lmax=49, Epsilon=0.5, mode=1)

        self.conv1 = ConvDecomp2d(coefs=self.X_a_conv1, dictionary=self.D_conv1, bias_val=self.B_conv1, input_channels=1, output_channels=10, kernel_size=5, bias=True)
        self.conv2 = ConvDecomp2d(coefs=self.X_a_conv2, dictionary=self.D_conv2, bias_val=self.B_conv2, input_channels=10, output_channels=20, kernel_size=5, bias=True)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = FCDecomp(coefs=self.X_a_fc1, dictionary=self.D_fc1, bias_val=self.B_fc1, input_features=320, output_features=50)
        self.fc2 = FCDecomp(coefs=self.X_a_fc2, dictionary=self.D_fc2, bias_val=self.B_fc2, input_features=50, output_features=10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

I defined the customized function and layer as follows:

class LinearDecomp(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, coefs, dictionary, bias=None):
        weight = torch.mm(dictionary, coefs).cuda()  # reconstruct the weight
        ctx.save_for_backward(input, weight, dictionary, coefs, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, dictionary, coefs, bias = ctx.saved_variables
        grad_input = grad_coefs = grad_dictionary = grad_bias = None
        # Gradient w.r.t. the reconstructed weight; used only as an intermediate,
        # not returned because weight is not an input of forward
        grad_weight = grad_output.t().mm(input)

        # needs_input_grad follows the forward inputs: (input, coefs, dictionary, bias)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)

        if ctx.needs_input_grad[1]:
            grad_coefs = dictionary.t().mm(grad_weight)

        if ctx.needs_input_grad[2]:
            grad_dictionary = grad_weight.mm(coefs.t())

        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_coefs, grad_dictionary, grad_bias

The customized layer is defined as:

class FCDecomp(nn.Module):
    def __init__(self, coefs, dictionary, bias_val, input_features, output_features, bias=True):
        super(FCDecomp, self).__init__()
        self.dictionary = nn.Parameter(dictionary, requires_grad=False).cuda()
        self.coefs = nn.Parameter(coefs, requires_grad=True).cuda()
        if bias:
            self.bias = nn.Parameter(bias_val, requires_grad=True).cuda()
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        return LinearDecomp.apply(input, self.coefs, self.dictionary, self.bias)

Could anyone provide me with some suggestions or hints on this issue? Thank you very much!

My guess is that your saved file path_pretrained_model doesn’t contain nn.Parameters. nn.Parameter is a subclass of torch.autograd.Variable that marks the tensor as an optimizable parameter (i.e. it is returned by model.parameters()).
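
As a quick illustration (a minimal sketch; Toy is just a made-up module name), only attributes that are actual nn.Parameter instances get picked up by model.parameters():

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.w = nn.Parameter(torch.randn(3, 3))  # registered as a parameter
        self.v = torch.randn(3, 3)                # plain tensor attribute: NOT registered

print(len(list(Toy().parameters())))  # 1 -- only self.w is returned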

If your path_pretrained_model contains Tensors, change your code to something like:

self.W_conv1 = nn.Parameter(params.items()[0])

If your path_pretrained_model contains torch.autograd.Variables, change your code to something like:

self.W_conv1 = nn.Parameter(params.items()[0].data)

Also, are you sure you need a custom autograd function for LinearDecomp? Can you just write it as a normal Python function leveraging autograd?

i.e.

def linear_decomp(input, coefs, dictionary, bias=None):
    weight = torch.mm(dictionary, coefs).cuda()
    output = input.mm(weight.t())
    if bias is not None:
        output += bias.unsqueeze(0)
    return output
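
Then FCDecomp wouldn’t need a custom Function at all; something along these lines (a sketch with hypothetical names, assuming coefs and bias are registered as nn.Parameters on the module):

class FCDecompSimple(nn.Module):  # hypothetical variant, not the code from this thread
    def __init__(self, coefs, dictionary, bias_val):
        super(FCDecompSimple, self).__init__()
        self.coefs = nn.Parameter(coefs)    # trainable, shows up in .parameters()
        self.bias = nn.Parameter(bias_val)  # trainable
        self.dictionary = dictionary        # plain tensor attribute: kept fixed

    def forward(self, input):
        # plain Python function from above; autograd tracks it automatically
        return linear_decomp(input, self.coefs, self.dictionary, self.bias)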

Thank you very much for the quick response! However, the issue is not fixed on my side yet.

  1. The path_pretrained_model I pass into the Net is only used to fetch the pre-trained weights and biases of the different layers; I’m not training directly on top of the original model.
  2. I did convert the dictionary, coefs, and bias into nn.Parameter(..., requires_grad=True/False) where I define the layer FCDecomp(nn.Module).
  3. I also followed your suggestion and converted the Variables into nn.Parameter() at the Net level while defining the Net, but it still gives me the same error: ValueError: optimizer got an empty parameter list.

Here is my modified Net:

class Decomp_Net(nn.Module):
    def __init__(self, path_pretrained_model="mymodel.pth"):
        super(Decomp_Net, self).__init__()
        # Load the pretrained model
        # Load the saved weights
        self.path_pretrained_model = path_pretrained_model
        try:
            params = torch.load(self.path_pretrained_model)
            print("Loaded pretrained model.")
        except IOError:
            raise RuntimeError("No pretrained model saved.")

        # Conv Layer 1
        self.W_conv1 = params.items()[0]
        self.B_conv1 = nn.Parameter(params.items()[1][1], requires_grad=True).cuda()
        self.W_conv1 = self.W_conv1[1].view(10, 25)
        self.W_conv1 = self.W_conv1.t()
        self.D_conv1, self.X_a_conv1 = create_dic_fuc.create_dic(A=self.W_conv1, M=25, N=10, Lmax=9, Epsilon=0.7, mode=1)
        self.D_conv1 = nn.Parameter(self.D_conv1, requires_grad=False).cuda()
        self.X_a_conv1 = nn.Parameter(self.X_a_conv1, requires_grad=True).cuda()

        # Conv Layer 2
        self.W_conv2 = params.items()[2]
        self.B_conv2 = nn.Parameter(params.items()[3][1], requires_grad=True).cuda()
        self.W_conv2 = self.W_conv2[1].view(200, 25)
        self.W_conv2 = self.W_conv2.t()
        self.D_conv2, self.X_a_conv2 = create_dic_fuc.create_dic(A=self.W_conv2, M=25, N=200, Lmax=199, Epsilon=0.7, mode=1)
        self.D_conv2 = nn.Parameter(self.D_conv2, requires_grad=False).cuda()
        self.X_a_conv2 = nn.Parameter(self.X_a_conv2, requires_grad=True).cuda()

        # Layer FC1
        self.W_fc1 = params.items()[4][1]
        self.B_fc1 = nn.Parameter(params.items()[5][1], requires_grad=True).cuda()
        self.D_fc1, self.X_a_fc1 = create_dic_fuc.create_dic(A=self.W_fc1, M=50, N=320, Lmax=319, Epsilon=0.8, mode=1)
        self.D_fc1 = nn.Parameter(self.D_fc1, requires_grad=False).cuda()
        self.X_a_fc1 = nn.Parameter(self.X_a_fc1, requires_grad=True).cuda()

        # Layer FC2
        self.W_fc2 = params.items()[6][1]  # Fetching the last fully connected layer of the original model
        self.B_fc2 = nn.Parameter(params.items()[7][1], requires_grad=True).cuda()
        self.D_fc2, self.X_a_fc2 = create_dic_fuc.create_dic(A=self.W_fc2, M=10, N=50, Lmax=49, Epsilon=0.5, mode=1)
        self.D_fc2 = nn.Parameter(self.D_fc2, requires_grad=False).cuda()
        self.X_a_fc2 = nn.Parameter(self.X_a_fc2, requires_grad=True).cuda()

        self.conv1 = ConvDecomp2d(coefs=self.X_a_conv1, dictionary=self.D_conv1, bias_val=self.B_conv1, input_channels=1, output_channels=10, kernel_size=5, bias=True)
        self.conv2 = ConvDecomp2d(coefs=self.X_a_conv2, dictionary=self.D_conv2, bias_val=self.B_conv2, input_channels=10, output_channels=20, kernel_size=5, bias=True)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = FCDecomp(coefs=self.X_a_fc1, dictionary=self.D_fc1, bias_val=self.B_fc1, input_features=320, output_features=50)
        self.fc2 = FCDecomp(coefs=self.X_a_fc2, dictionary=self.D_fc2, bias_val=self.B_fc2, input_features=50, output_features=10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

Custom layer defined:

class FCDecomp(nn.Module):
    def __init__(self, coefs, dictionary, bias_val, input_features, output_features, bias=True):
        super(FCDecomp, self).__init__()
        self.dictionary = dictionary
        self.coefs = coefs
        if bias:
            self.bias = bias_val
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        return LinearDecomp.apply(input, self.coefs, self.dictionary, self.bias)

Custom defined function:

class LinearDecomp(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, coefs, dictionary, bias=None):
        weight = torch.mm(dictionary, coefs).cuda()  # reconstruct the weight
        ctx.save_for_backward(input, weight, dictionary, coefs, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, dictionary, coefs, bias = ctx.saved_variables
        grad_input = grad_coefs = grad_dictionary = grad_bias = None
        # Gradient w.r.t. the reconstructed weight; used only as an intermediate,
        # not returned because weight is not an input of forward
        grad_weight = grad_output.t().mm(input)

        # needs_input_grad follows the forward inputs: (input, coefs, dictionary, bias)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)

        if ctx.needs_input_grad[1]:
            grad_coefs = dictionary.t().mm(grad_weight)

        if ctx.needs_input_grad[2]:
            grad_dictionary = grad_weight.mm(coefs.t())

        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_coefs, grad_dictionary, grad_bias

The reason I need a custom autograd function for LinearDecomp is that I want autograd to train only coefs and bias, not the dictionary, so I need to keep the dictionary fixed with requires_grad=False. As you might notice in Decomp_Net.__init__, I have:

self.B_conv1 = nn.Parameter(params.items()[1][1], requires_grad=True).cuda()
self.D_conv1 = nn.Parameter(self.D_conv1, requires_grad=False).cuda()
self.X_a_conv1 = nn.Parameter(self.X_a_conv1, requires_grad=True).cuda()
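
To make the intent concrete, here is a minimal standalone sketch (made-up shapes, not my actual model) of the behaviour I want to guarantee: only coefs should end up with a gradient, while the dictionary stays frozen:

dictionary = nn.Parameter(torch.randn(25, 10), requires_grad=False)  # frozen
coefs = nn.Parameter(torch.randn(10, 10))                            # trainable

weight = torch.mm(dictionary, coefs)  # reconstructed weight
loss = weight.sum()
loss.backward()
print(coefs.grad is not None)   # True  -> coefs gets a gradient
print(dictionary.grad)          # None  -> dictionary is left untouched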

  1. Move the .cuda() call inside the nn.Parameter() call:

self.B_conv1 = nn.Parameter(params.items()[1][1].cuda())

Calling .cuda() on an nn.Parameter returns a plain Variable rather than a Parameter, so the attribute is no longer registered by the module and never shows up in model.parameters().

  2. If you don’t want to backprop through the dictionary, it’s easy:

def linear_decomp(input, coefs, dictionary, bias=None):
    weight = torch.mm(dictionary.detach(), coefs).cuda()
    output = input.mm(weight.t())
    if bias is not None:
        output += bias.unsqueeze(0)
    return output
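
As a quick sanity check (a sketch, assuming the usual import torch.optim as optim): once the attributes are genuine nn.Parameter instances, they show up in model.parameters() and the optimizer can be constructed:

model = Decomp_Net()
print(len(list(model.parameters())))  # > 0 after the fix
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # no more ValueError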

Thank you so much for the help! It does solve the issue. Indeed, I don’t need a custom function; autograd is quite intelligent. :)