Weights are Not converging

Hi

I am trying to take weighted average of weights for last 5 epochs but all of the wights (where require_grad = True) are same.

> class resnet34(nn.Module):
>     def __init__(self):
>         super(resnet34,self).__init__()
>         self.arch = models.resnet34(pretrained=True)
>         self.arch.fc = nn.Linear(self.arch.fc.in_features,32)
>         self.fc1 = nn.Linear(32,10)
>         self.fc2 = nn.Linear(10,1)
>     def forward(self, x):
>         x = self.arch(x)
>         x = self.fc1(x)
>         x = self.fc2(x)
>         return x
> 
> for param in model.arch.parameters():
>     param.requires_grad = False
> 
> pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
> print('Number of parameters',pytorch_total_params)
> Ans is **341**
epochs = 10
model_weights = list()
for epoch in range(epochs):
    train_running_loss = 0
    print("Epoch: {}/{}.. ".format(epoch+1, epochs))
    model.train()
    org_labels_train = list()
    pred_labels_train = list()
    roc_auc_train = 0
    for index,(images,labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device).float()
        ypred = torch.sigmoid(model(images))
        loss = loss_func(ypred,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_running_loss +=loss.item()
        train_losses.append(train_running_loss/len(train_loader))
        org_labels_train.append(labels)
        pred_labels_train.append(ypred)
    ## roc_auc_score(original,predicted)
    roc_auc_train = roc_auc_score(torch.cat(org_labels_train).detach().cpu(),torch.cat(pred_labels_train).detach().cpu())
    print('Roc-auc of Train',epoch+1,'is',np.round(roc_auc_train,4))
    if epoch>=4:
        model_weights.append(model.state_dict())

model_weights[1].get(‘fc2.weight’)

tensor([[ 0.0474, 0.2263, 0.0451, 0.0278, -0.0078, -0.0179, -0.0933, -0.0282,
** 0.0024, -0.0024]], device=‘cuda:0’)**

model_weights[4].get(‘fc2.weight’)

tensor([[ 0.0474, 0.2263, 0.0451, 0.0278, -0.0078, -0.0179, -0.0933, -0.0282,
** 0.0024, -0.0024]], device=‘cuda:0’)**

model.state_dict() holds the references to all parameters and buffers, so you would need to use copy.deepcopy(model.state_dict()) to store the current parameters.

1 Like

@ptrblck Thanks bro.!! Another question how weights are used for prediction the only thing I see is whenever we need to predict we used model like below:

for index,(valid_images,valid_labels) in enumerate(validation_loader):
    valid_images = valid_images.to(device)
    valid_labels = valid_labels.to(device).float()
    ypred_valid =  torch.sigmoid(model(valid_images))

Here model is use is there any way model.state_dict() combine with model for predictions.
ypred_valid = torch.sigmoid(model(valid_images))

Usually I have multiple weights for same model So how I need to combine different models weights with same model archeitecture.

Thanks.!!

After you’ve stored the state_dicts you could iterate the keys of them and create a new state_dict using the mean (or any other reduction) for all parameters.
This code snippet shows a small example:

# Setup
state_dicts = []
model = models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=1.)

# Train and store state_dict
for _ in range(5):
    optimizer.zero_grad()
    out = model(torch.randn(1, 3, 224, 224))
    out.mean().backward()
    optimizer.step()
    state_dicts.append(copy.deepcopy(model.state_dict()))

# Create new state_dict with mean of all params
new_state_dict = collections.OrderedDict()

for key in model.state_dict():
    if 'num_batches_tracked' in key: # handle this separately and reuse last value
        param = state_dicts[-1][key]
    else:
        param = torch.mean(torch.stack([sd[key] for sd in state_dicts]), dim=0)
    new_state_dict[key] = param  

# Load into model
model.load_state_dict(new_state_dict)
1 Like