Why is SWA not changing the learning rate?

Hi, I am using Stochastic Weight Averaging from torchcontrib for the first time. I noticed that my WaveNet model's learning rate is not changing.

Model code:

# from https://www.kaggle.com/hanjoonchoe/wavenet-lstm-pytorch-ignite-ver

class Wave_Block(nn.Module):

    def __init__(self,in_channels,out_channels,dilation_rates):
        super(Wave_Block,self).__init__()
        self.num_rates = dilation_rates
        self.convs = nn.ModuleList()
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()

        self.convs.append(nn.Conv1d(in_channels,out_channels,kernel_size=1))
        dilation_rates = [2**i for i in range(dilation_rates)]
        for dilation_rate in dilation_rates:
            self.filter_convs.append(nn.Conv1d(out_channels,out_channels,kernel_size=3,padding=dilation_rate,dilation=dilation_rate))
            self.gate_convs.append(nn.Conv1d(out_channels,out_channels,kernel_size=3,padding=dilation_rate,dilation=dilation_rate))
            self.convs.append(nn.Conv1d(out_channels,out_channels,kernel_size=1))

    def forward(self,x):
        x = self.convs[0](x)
        res = x
        for i in range(self.num_rates):
            x = F.tanh(self.filter_convs[i](x))*F.sigmoid(self.gate_convs[i](x))
            x = self.convs[i+1](x)
            res = res + x
        return res


class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        input_size = 128
        self.LSTM = nn.GRU(input_size=input_size,hidden_size=64,num_layers=2,batch_first=True,bidirectional=True)
        #self.attention = Attention(input_size,4000)
        #self.rnn = nn.RNN(input_size, 64, 2, batch_first=True, nonlinearity='relu')
        self.wave_block1 = Wave_Block(19,16,12)
        self.wave_block2 = Wave_Block(16,32,8)
        self.wave_block3 = Wave_Block(32,64,4)
        self.wave_block4 = Wave_Block(64, 128, 1)
        self.fc = nn.Linear(128, 11)

    def forward(self,x):
        x = x.permute(0, 2, 1)
        x = self.wave_block1(x)
        x = self.wave_block2(x)
        x = self.wave_block3(x)
        #x,_ = self.LSTM(x)
        x = self.wave_block4(x)
        x = x.permute(0, 2, 1)
        x,_ = self.LSTM(x)
        #x = self.conv1(x)
        #print(x.shape)
        #x = self.rnn(x)
        #x = self.attention(x)
        x = self.fc(x)
        return x


class EarlyStopping:
    def __init__(self, patience=5, delta=0, checkpoint_path='checkpoint.pt', is_maximize=True):
        self.patience, self.delta, self.checkpoint_path = patience, delta, checkpoint_path
        self.counter, self.best_score = 0, None
        self.is_maximize = is_maximize

    def load_best_weights(self, model):
        model.load_state_dict(torch.load(self.checkpoint_path))

    def __call__(self, score, model):
        if self.best_score is None or \
                (score > self.best_score + self.delta if self.is_maximize else score < self.best_score - self.delta):
            torch.save(model.state_dict(), self.checkpoint_path)
            self.best_score, self.counter = score, 0
            return 1
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return 2
        return 0

Train & validation loop:

test_y = np.zeros([int(2000000/GROUP_BATCH_SIZE), GROUP_BATCH_SIZE, 1])
test_dataset = IronDataset(test, test_y, flip=False)
test_dataloader = DataLoader(test_dataset, NNBATCHSIZE, shuffle=False, num_workers=8, pin_memory=True)
test_preds_all = np.zeros((2000000, 11))

oof_score = []
for index, (train_index, val_index, _) in enumerate(new_splits[0:], start=0):
    print("Fold : {}".format(index))
    train_dataset = IronDataset(train[train_index], train_tr[train_index], seq_len=GROUP_BATCH_SIZE, flip=flip, noise_level=noise)
    train_dataloader = DataLoader(train_dataset, NNBATCHSIZE, shuffle=True, num_workers=8, pin_memory=True)
    valid_dataset = IronDataset(train[val_index], train_tr[val_index], seq_len=GROUP_BATCH_SIZE, flip=False)
    valid_dataloader = DataLoader(valid_dataset, NNBATCHSIZE, shuffle=False, num_workers=4, pin_memory=True)

    it = 0
    model = Classifier()
    model = model.cuda()
    early_stopping = EarlyStopping(patience=40, is_maximize=True,
                                   checkpoint_path=os.path.join(outdir, "gru_clean_checkpoint_fold_{}_iter_{}.pt".format(index, it)))
    weight = None#cal_weights()
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.002)
    schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.2)
    avg_train_losses, avg_valid_losses = [], []

    for epoch in range(EPOCHS):
        print('**********************************')
        print("Folder : {} Epoch : {}".format(index, epoch))
        print("Curr learning_rate: {:0.9f}".format(optimizer.param_groups[0]['lr']))
        train_losses, valid_losses = [], []
        tr_loss_cls_item, val_loss_cls_item = [], []

        model.train()  # prep model for training
        train_preds, train_true = torch.Tensor([]).cuda(), torch.LongTensor([]).cuda()#.to(device)
        for x, y in tqdm(train_dataloader):
            x = x.cuda()
            y = y.cuda()
            #print(x.shape)

            optimizer.zero_grad()
            #loss_fn(model(input), target).backward()
            #optimizer.zero_grad()
            predictions = model(x)
            predictions_ = predictions.view(-1, predictions.shape[-1])
            y_ = y.view(-1)
            loss = criterion(predictions_, y_)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()

            #schedular.step()
            # record training lossa
            train_losses.append(loss.item())
            train_true = torch.cat([train_true, y_], 0)
            train_preds = torch.cat([train_preds, predictions_], 0)

        #model.eval()  # prep model for evaluation
        optimizer.swap_swa_sgd()
        val_preds, val_true = torch.Tensor([]).cuda(), torch.LongTensor([]).cuda()
        print('EVALUATION')
        with torch.no_grad():
            for x, y in tqdm(valid_dataloader):
                x = x.cuda()#.to(device)
                y = y.cuda()#..to(device)
                predictions = model(x)
                predictions_ = predictions.view(-1, predictions.shape[-1])
                y_ = y.view(-1)
                loss = criterion(predictions_, y_)
                valid_losses.append(loss.item())
                val_true = torch.cat([val_true, y_], 0)
                val_preds = torch.cat([val_preds, predictions_], 0)

        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        print("train_loss: {:0.6f}, valid_loss: {:0.6f}".format(train_loss, valid_loss))

        train_score = f1_score(train_true.cpu().detach().numpy(), train_preds.cpu().detach().numpy().argmax(1),
                               labels=list(range(11)), average='macro')
        val_score = f1_score(val_true.cpu().detach().numpy(), val_preds.cpu().detach().numpy().argmax(1),
                             labels=list(range(11)), average='macro')
        schedular.step(val_score)
        print("train_f1: {:0.6f}, valid_f1: {:0.6f}".format(train_score, val_score))
        res = early_stopping(val_score, model)
        #print('fres:', res)
        if res == 2:
            print("Early Stopping")
            print('folder %d global best val max f1 model score %f' % (index, early_stopping.best_score))
            break
        elif res == 1:
            print('save folder %d global val max f1 model score %f' % (index, val_score))

    print('Folder {} finally best global max f1 score is {}'.format(index, early_stopping.best_score))
    oof_score.append(round(early_stopping.best_score, 6))

    model.eval()
    pred_list = []
    with torch.no_grad():
        for x, y in tqdm(test_dataloader):
            x = x.cuda()
            y = y.cuda()
            predictions = model(x)
            predictions_ = predictions.view(-1, predictions.shape[-1]) # shape [128, 4000, 11]
            #print(predictions.shape, F.softmax(predictions_, dim=1).cpu().numpy().shape)
            pred_list.append(F.softmax(predictions_, dim=1).cpu().numpy()) # shape (512000, 11)
            #a = input()
        test_preds = np.vstack(pred_list) # shape [2000000, 11]
        test_preds_all += test_preds

How does val_score look?
Since you have set patience=3, val_score would have to stay flat or decrease for 3 epochs before ReduceLROnPlateau lowers the learning rate.
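
For reference, here is a minimal, self-contained sketch (dummy model, made-up scores, purely illustrative) of how ReduceLROnPlateau with mode='max', patience=3 and factor=0.2 behaves: the learning rate only drops once the monitored score has failed to improve for more than patience consecutive epochs.

import torch

# Toy setup just to drive the scheduler; the model and optimizer are placeholders.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.2)

# Made-up validation scores: the metric stops improving after the second epoch.
fake_val_scores = [0.80, 0.85, 0.85, 0.85, 0.85, 0.85]
for epoch, score in enumerate(fake_val_scores):
    scheduler.step(score)
    print(epoch, optimizer.param_groups[0]['lr'])
# The lr stays at 1e-3 for epochs 0-4 and is reduced to 2e-4 (factor 0.2) only at
# epoch 5, after the score has not improved for more than `patience` epochs.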

@ptrblck

Here is the updated code:

test_y = np.zeros([int(2000000/GROUP_BATCH_SIZE), GROUP_BATCH_SIZE, 1])
test_dataset = IronDataset(test, test_y, flip=False)
test_dataloader = DataLoader(test_dataset, NNBATCHSIZE, shuffle=False)
test_preds_all = np.zeros((2000000, 11))


oof_score = []
for index, (train_index, val_index, _) in enumerate(new_splits[0:], start=0):
    print("Fold : {}".format(index))
    train_dataset = IronDataset(train[train_index], train_tr[train_index], seq_len=GROUP_BATCH_SIZE, flip=flip, noise_level=noise)
    train_dataloader = DataLoader(train_dataset, NNBATCHSIZE, shuffle=True,num_workers = 16)

    valid_dataset = IronDataset(train[val_index], train_tr[val_index], seq_len=GROUP_BATCH_SIZE, flip=False)
    valid_dataloader = DataLoader(valid_dataset, NNBATCHSIZE, shuffle=False)

    it = 0
    model = Classifier()
    model = model.cuda()

    early_stopping = EarlyStopping(patience=40, is_maximize=True,
                                   checkpoint_path=os.path.join(outdir, "gru_clean_checkpoint_fold_{}_iter_{}.pt".format(index, it)))

    weight = None#cal_weights()
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=2, swa_lr=0.002)

    schedular = torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr=LR, max_lr=0.003, step_size_up=len(train_dataset)/2, cycle_momentum=False)
    avg_train_losses, avg_valid_losses = [], []

    for epoch in range(EPOCHS):
        
        train_losses, valid_losses = [], []
        tr_loss_cls_item, val_loss_cls_item = [], []

        model.train()  # prep model for training
        train_preds, train_true = torch.Tensor([]).cuda(), torch.LongTensor([]).cuda()#.to(device)
        
        print('**********************************')
        print("Folder : {} Epoch : {}".format(index, epoch))
        print("Curr learning_rate: {:0.9f}".format(optimizer.param_groups[0]['lr']))
        
            #loss_fn(model(input), target).backward()
        for x, y in tqdm(train_dataloader):
            x = x.cuda()
            y = y.cuda()
            #print(x.shape)

            optimizer.zero_grad()
            predictions = model(x)

            predictions_ = predictions.view(-1, predictions.shape[-1])
            y_ = y.view(-1)

            loss = criterion(predictions_, y_)

            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            
            schedular.step()
            # record training lossa
            train_losses.append(loss.item())
            train_true = torch.cat([train_true, y_], 0)
            train_preds = torch.cat([train_preds, predictions_], 0)

        #model.eval()  # prep model for evaluation
        
        optimizer.update_swa()
        optimizer.swap_swa_sgd()
        val_preds, val_true = torch.Tensor([]).cuda(), torch.LongTensor([]).cuda()
        print('EVALUATION')
        with torch.no_grad():
            for x, y in tqdm(valid_dataloader):
                x = x.cuda()#.to(device)
                y = y.cuda()#..to(device)

                predictions = model(x)
                predictions_ = predictions.view(-1, predictions.shape[-1])
                y_ = y.view(-1)

                loss = criterion(predictions_, y_)

                valid_losses.append(loss.item())


                val_true = torch.cat([val_true, y_], 0)
                val_preds = torch.cat([val_preds, predictions_], 0)
 
        
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        print("train_loss: {:0.6f}, valid_loss: {:0.6f}".format(train_loss, valid_loss))

        train_score = f1_score(train_true.cpu().detach().numpy(), train_preds.cpu().detach().numpy().argmax(1),
                               labels=list(range(11)), average='macro')

        val_score = f1_score(val_true.cpu().detach().numpy(), val_preds.cpu().detach().numpy().argmax(1),
                             labels=list(range(11)), average='macro')

        schedular.step(val_score)
        print("train_f1: {:0.6f}, valid_f1: {:0.6f}".format(train_score, val_score))
        res = early_stopping(val_score, model)
        #print('fres:', res)
        if  res == 2:
            print("Early Stopping")
            print('folder %d global best val max f1 model score %f' % (index, early_stopping.best_score))
            break
        elif res == 1:
            print('save folder %d global val max f1 model score %f' % (index, val_score))
    print('Folder {} finally best global max f1 score is {}'.format(index, early_stopping.best_score))
    oof_score.append(round(early_stopping.best_score, 6))
    
    model.eval()
    pred_list = []
    with torch.no_grad():
        for x, y in tqdm(test_dataloader):
            
            x = x.cuda()
            y = y.cuda()

            predictions = model(x)
            predictions_ = predictions.view(-1, predictions.shape[-1]) # shape [128, 4000, 11]
            #print(predictions.shape, F.softmax(predictions_, dim=1).cpu().numpy().shape)
            pred_list.append(F.softmax(predictions_, dim=1).cpu().numpy()) # shape (512000, 11)
            #a = input()
        test_preds = np.vstack(pred_list) # shape [2000000, 11]
        test_preds_all += test_preds

Please let me know whether the implementation is correct. I am using CyclicLR now; if it is error-free, I have two more questions.
Question 1: How do I use optimizer.bn_update() in this code? Can you tell me where that call should go in my updated training loop?
Question 2: Are swa_lr and the scheduler's learning rate the same thing? I want to change the learning rate after some epochs with the scheduler, but I don't understand whether I need to update the swa_lr parameter as well. Is there an example of such a case?
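
For context, here is a minimal sketch of the torchcontrib SWA pattern as described in its README (the manual-averaging variant), reusing train_dataloader, model, criterion, LR and EPOCHS from the code above; the warm-up length and the placement of update_swa() / bn_update() reflect my reading of the README rather than a confirmed answer to question 1:

import torch
import torchcontrib

base_opt = torch.optim.Adam(model.parameters(), lr=LR)
opt = torchcontrib.optim.SWA(base_opt)  # manual mode: no swa_start/swa_freq/swa_lr

for epoch in range(EPOCHS):
    model.train()
    for x, y in train_dataloader:
        x, y = x.cuda(), y.cuda()
        opt.zero_grad()
        predictions = model(x)
        loss = criterion(predictions.view(-1, predictions.shape[-1]), y.view(-1))
        loss.backward()
        opt.step()
    if epoch >= 10:       # arbitrary warm-up before averaging starts
        opt.update_swa()  # add the current weights to the running SWA average

opt.swap_swa_sgd()                      # copy the averaged weights into the model
opt.bn_update(train_dataloader, model)  # recompute BatchNorm running stats for the averaged
                                        # weights (a no-op here, since Classifier has no BatchNorm layers)

As far as I understand the docs, swa_lr is only used in the automatic mode (when swa_start/swa_freq are passed): after swa_start optimizer steps the wrapper keeps resetting the base optimizer's learning rate to swa_lr, which would explain a learning rate that never appears to change and which can conflict with an external scheduler. In the manual mode sketched above, swa_lr is not used at all, so the scheduler alone controls the learning rate.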