What does next(self.parameters()).data mean?

I found the following piece of code in one of the PyTorch examples. It's from an EncoderRNN class.

from torch.autograd import Variable  # pre-0.4 PyTorch style

def init_weights(self, bsz):
    """Initialize weight parameters for the encoder."""
    weight = next(self.parameters()).data
    num_directions = 2 if self.bidirectional else 1
    if self.rnn_type == 'LSTM':
        # LSTM needs a (h_0, c_0) pair of zero tensors
        return (Variable(weight.new(self.n_layers * num_directions, bsz, self.hidden_size).zero_()),
                Variable(weight.new(self.n_layers * num_directions, bsz, self.hidden_size).zero_()))
    else:
        # GRU / vanilla RNN need only h_0
        return Variable(weight.new(self.n_layers * num_directions, bsz, self.hidden_size).zero_())
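For comparison, here is a sketch of the same method in current PyTorch, assuming version 0.4 or later, where Variable and Tensor were merged and Variable is no longer needed. Tensor.new_zeros allocates a zero-filled tensor with the template parameter's dtype and device in one call. The TinyEncoder class and its attribute values are hypothetical stand-ins for the EncoderRNN attributes used in the snippet above.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical encoder exposing the attributes that init_weights uses."""

    def __init__(self, rnn_type="LSTM", input_size=8, hidden_size=16,
                 n_layers=2, bidirectional=True):
        super().__init__()
        self.rnn_type = rnn_type
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        rnn_cls = nn.LSTM if rnn_type == "LSTM" else nn.GRU
        self.rnn = rnn_cls(input_size, hidden_size, n_layers,
                           bidirectional=bidirectional)

    def init_weights(self, bsz):
        # Any parameter works as a template: only its dtype and device are used.
        weight = next(self.parameters())
        num_directions = 2 if self.bidirectional else 1
        shape = (self.n_layers * num_directions, bsz, self.hidden_size)
        if self.rnn_type == "LSTM":
            # LSTM expects a (h_0, c_0) pair
            return (weight.new_zeros(shape), weight.new_zeros(shape))
        # GRU / vanilla RNN expect only h_0
        return weight.new_zeros(shape)


enc = TinyEncoder()
h0, c0 = enc.init_weights(bsz=3)
print(h0.shape)  # torch.Size([4, 3, 16])

# The hidden state plugs straight into the RNN: (seq_len=5, batch=3, input=8)
out, _ = enc.rnn(torch.randn(5, 3, 8), (h0, c0))
print(out.shape)  # torch.Size([5, 3, 32]), 2 directions * hidden_size
```

Because new_zeros reads dtype and device from the parameter, calling enc.double() or moving the model to a GPU changes the hidden state accordingly without touching init_weights.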

I have two questions.

  1. What is happening in the statement weight = next(self.parameters()).data?
  2. What does weight.new() mean? Why can't we create a Variable without using weight.new()? Are we somehow attaching the hidden variables to the model (encoder) parameters?

  1. next() retrieves the next item from an iterator by calling its __next__() method.
    • Here, self.parameters() returns a fresh generator, so next() returns the first parameter of the module.
  2. You can create a Variable however you like, but then you have to specify the data type (and device) yourself. weight.new() constructs a new tensor with the same data type and device as the first parameter. The new tensor shares no memory and no autograd history with that parameter, so the hidden state is not attached to the model's parameters.
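A minimal illustration of point 1, using a plain nn.Linear(10, 1) as a stand-in module: next() pulls items from the generator one at a time, so the first call yields the weight (registered first) and the second yields the bias. Each call to parameters() builds a new generator, which is why next(self.parameters()) always returns the first parameter.

```python
import torch.nn as nn

layer = nn.Linear(10, 1)

params = layer.parameters()   # a fresh generator over the module's parameters
first = next(params)          # weight, shape (1, 10)
second = next(params)         # bias, shape (1,)
print(first.shape, second.shape)

# The generator is now exhausted; passing a default avoids StopIteration.
print(next(params, None))     # None

# A new call to parameters() starts over from the first parameter.
print(next(layer.parameters()) is layer.weight)  # True
```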

self.parameters() is a generator method that yields the parameters of the model one by one.
So weight simply holds the data of the first parameter of the model.
Then weight.new() creates a tensor with the same data type and on the same device as that parameter.
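To make "same data type, same device" concrete: weight.new(*sizes) behaves like torch.empty(*sizes, dtype=weight.dtype, device=weight.device), returning uninitialized memory. The sketch below converts a stand-in nn.Linear to float64 with .double() to show that the new tensor follows the parameter's dtype rather than the global default (assumed here to be the usual float32).

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1).double()      # parameters are now float64
weight = next(model.parameters()).data

a = weight.new(4, 5)                   # uninitialized, dtype/device taken from weight
b = torch.empty(4, 5, dtype=weight.dtype, device=weight.device)  # explicit equivalent
c = torch.empty(4, 5)                  # uses the global default dtype instead

print(a.dtype, b.dtype, c.dtype)       # torch.float64 torch.float64 torch.float32
```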

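import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.fc = nn.Linear(10, 1)

model = Model()
weight = next(model.parameters()).data
print(weight)
# Output (values are random): tensor([[ 0.0398,  0.0729, -0.2676,  0.2354, -0.0853,  0.1141,  0.0297,  0.0257,
#          -0.1303,  0.2208]])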

print(type(weight), weight.device, weight.dtype, weight.requires_grad)
# Output: <class 'torch.Tensor'> cpu torch.float32 False

var = weight.new(4, 5)  # uninitialized 4x5 tensor with weight's dtype and device

print(type(var), var.device, var.dtype, var.requires_grad)
# Output: <class 'torch.Tensor'> cpu torch.float32 False
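On the last part of question 2: the tensor returned by new() only borrows dtype and device from the parameter. It gets its own storage and has no autograd history, so the hidden state is not attached to the model's parameters. A quick check, again using a stand-in nn.Linear(10, 1):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
param = next(model.parameters())   # the nn.Parameter itself
weight = param.data                # same storage, but detached from autograd

h = weight.new(4, 5).zero_()       # independent zero-filled tensor

print(param.requires_grad, weight.requires_grad, h.requires_grad)  # True False False
print(weight.data_ptr() == param.data_ptr())  # True: .data shares storage
print(h.data_ptr() == param.data_ptr())       # False: new() allocates fresh memory

before = param.clone()
h.fill_(1.0)                       # writing to h leaves the parameter untouched
print(torch.equal(param, before))  # True
```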
