I am using the Swish activation function with a trainable β parameter, following the paper "SWISH: A Self-Gated Activation Function" by Prajit Ramachandran, Barret Zoph and Quoc V. Le. As a toy example, I am training a LeNet-5 CNN on MNIST and want to learn 'beta' instead of fixing beta = 1 as nn.SiLU() does.
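For reference, the activation from the paper is f(x) = x * sigmoid(beta * x); beta = 1 recovers the fixed SiLU, and the paper also allows beta to be learned. A minimal functional sketch of what I am after (the helper name swish is just for illustration):

import torch

def swish(x, beta):
    # Swish from the paper: f(x) = x * sigmoid(beta * x).
    # beta = 1.0 recovers nn.SiLU(); here beta should be trainable instead.
    return x * torch.sigmoid(beta * x)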
The example code is:
import torch
import torch.nn as nn
from tqdm import tqdm

# (device, train_loader/train_dataset, test_loader/test_dataset and
# num_epochs are defined elsewhere in my script.)

class LeNet5(nn.Module):
    def __init__(self, beta = 1.0):
        super(LeNet5, self).__init__()
        b = torch.tensor(data = beta, dtype = torch.float32)
        self.beta = torch.autograd.Variable(b, requires_grad = True)
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = True
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = True
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # Do not initialize bias (due to batchnorm)-
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # Standard initialization for batch normalization-
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def swish_fn(self, x):
        return x * torch.sigmoid(x * self.beta)
    def forward(self, x):
        '''
        x = nn.SiLU()(self.pool1(self.bn1(self.conv1(x))))
        x = nn.SiLU()(self.pool1(self.bn2(self.conv2(x))))
        x = x.view(-1, 256)
        x = nn.SiLU()(self.bn3(self.fc1(x)))
        x = nn.SiLU()(self.bn4(self.fc2(x)))
        '''
        x = self.pool(self.bn1(self.conv1(x)))
        x = self.swish_fn(x = x)
        x = self.pool(self.bn2(self.conv2(x)))
        x = self.swish_fn(x = x)
        x = x.view(-1, 256)
        x = self.bn3(self.fc1(x))
        x = self.swish_fn(x = x)
        x = self.bn4(self.fc2(x))
        x = self.swish_fn(x = x)
        x = self.fc3(x)
        return x
# Initialize an instance of the LeNet-5 CNN architecture-
model = LeNet5(beta = 1.0).to(device)

# Define cost function-
loss = nn.CrossEntropyLoss()

# Define SGD optimizer-
optimizer = torch.optim.SGD(
    params = model.parameters(), lr = 0.1,
    momentum = 0.9, weight_decay = 5e-4
)

# Decay lr at the 20th, 40th, 60th and 75th epochs by a factor of 10-
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer = optimizer, milestones = [20, 40, 60, 75],
    gamma = 0.1
)
def train_one_step(model, train_loader, train_dataset):
    running_loss = 0.0
    running_corrects = 0.0
    model.to(device)
    model.train()

    with tqdm(train_loader, unit = 'batch') as tepoch:
        for images, labels in tepoch:
            tepoch.set_description("Training: ")
            images = images.to(device)
            labels = labels.to(device)

            # Get model predictions-
            outputs = model(images)

            # Compute loss-
            J = loss(outputs, labels)

            # Empty accumulated gradients-
            optimizer.zero_grad()

            # Perform backprop-
            J.backward()

            # Update parameters-
            optimizer.step()

            # Compute model's performance statistics-
            running_loss += J.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            running_corrects += torch.sum(predicted == labels.data)

            tepoch.set_postfix(
                loss = running_loss / len(train_dataset),
                accuracy = (running_corrects.double().cpu().numpy() / len(train_dataset)) * 100
            )

    train_loss = running_loss / len(train_dataset)
    train_acc = (running_corrects.double() / len(train_dataset)) * 100

    # return running_loss, running_corrects
    return train_loss, train_acc.detach().cpu().item()
def test_one_step(model, test_loader, test_dataset):
    total = 0.0
    correct = 0.0
    running_loss_val = 0.0
    model.to(device)
    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, unit = 'batch') as tepoch:
            for images, labels in tepoch:
                tepoch.set_description("Validation: ")
                images = images.to(device)
                labels = labels.to(device)

                # Predict using trained model-
                outputs = model(images)
                _, y_pred = torch.max(outputs, 1)

                # Compute validation loss-
                J_val = loss(outputs, labels)
                running_loss_val += J_val.item() * labels.size(0)

                # Total number of labels-
                total += labels.size(0)

                # Total number of correct predictions-
                correct += (y_pred == labels).sum()

                tepoch.set_postfix(
                    val_loss = running_loss_val / len(test_dataset),
                    val_acc = 100 * (correct.cpu().numpy() / total)
                )

    # return (running_loss_val, correct, total)
    val_loss = running_loss_val / len(test_dataset)
    val_acc = (correct / total) * 100

    return val_loss, val_acc.detach().cpu().item()


# Python3 dict to contain training metrics-
train_history = {}

# Variable to store the 'best' model-
best_val_acc = 0
While training the model, I am printing ‘beta’ as:
for epoch in range(1, num_epochs + 1):
    # One epoch of training-
    train_loss, train_acc = train_one_step(
        model = model, train_loader = train_loader,
        train_dataset = train_dataset
    )

    # Get validation metrics after 1 epoch of training-
    val_loss, val_acc = test_one_step(
        model = model, test_loader = test_loader,
        test_dataset = test_dataset
    )

    scheduler.step()
    current_lr = optimizer.param_groups[0]["lr"]

    print(f"Epoch: {epoch}; loss = {train_loss:.4f}, acc = {train_acc:.2f}%",
          f" val loss = {val_loss:.4f}, val acc = {val_acc:.2f}%,"
          f" beta = {model.beta:.6f}, beta grad = {model.beta.grad:.6f}"
          f" & LR = {current_lr:.5f}"
    )

    # Save training metrics to Python3 dict-
    train_history[epoch] = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_acc': train_acc, 'val_acc': val_acc,
        'lr': current_lr
    }

    # Save model with best validation accuracy-
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"Saving model with highest val_acc = {val_acc:.2f}%\n")
        torch.save(model.state_dict(), "LeNet5_MNIST_best_val_acc.pth")
The problem is that the "beta" parameter of the LeNet5() instance is never updated. What am I doing wrong? Why isn't beta training as expected?
Across all epochs, beta stays fixed at 1.0 even though the printed beta grad shows non-zero gradients. So the gradient for "beta" is being computed, but the parameter itself never changes.
This post is SOLVED!
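In case it helps anyone hitting the same symptom: my understanding is that a tensor wrapped in torch.autograd.Variable (which is deprecated) and assigned as a plain attribute is not registered with the module, so model.parameters() never returns it and the optimizer never applies the gradient that backward() computes for it. A minimal sketch of how a learnable scalar is usually registered, using nn.Parameter (the module name SwishBeta is just for illustration):

import torch
import torch.nn as nn

class SwishBeta(nn.Module):
    """Swish with a learnable scalar beta."""
    def __init__(self, beta = 1.0):
        super().__init__()
        # nn.Parameter registers the tensor with the module, so it appears in
        # model.parameters() / named_parameters() and is updated by the optimizer.
        self.beta = nn.Parameter(torch.tensor(beta, dtype = torch.float32))

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

# Sanity check: 'beta' should now show up among the named parameters.
print([name for name, _ in SwishBeta().named_parameters()])   # ['beta']

A quick way to confirm whether this is the issue in the model above is to print [n for n, _ in model.named_parameters()] and check whether 'beta' appears in the list.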