I am trying to train a model where I apply a function to the current model weights and then calculate the loss with those transformed weights. Using this loss, however, I want to update the original (untransformed) weights.
I am doing something like the code below, but I am unsure whether it achieves what I intend, because the trained model does not appear to be optimized for the transformation: if I add the same noise to the trained model, it no longer performs well.
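To make the intent clearer, here is a minimal sketch of the effect I am after on a single weight tensor (toy names, not my actual code):

import torch

# one "original" weight vector that I actually want to update
w = torch.randn(3, requires_grad=True)
x = torch.randn(3)
target = torch.tensor(1.0)

# transformed copy of the weights (here: small Gaussian noise)
w_noisy = w + 0.01 * torch.randn_like(w)
# the loss is computed with the transformed weights
loss = ((w_noisy @ x) - target) ** 2

loss.backward()  # ...but the gradient should end up on the original w
with torch.no_grad():
    w -= 0.001 * w.grad  # update the original weights with that gradient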
Code
import torch
import torchvision
from torch import nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)
def create_network():
    channel = [784, 100, 100, 10]
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=channel[0], out_features=channel[1]),
        nn.Sigmoid(),
        nn.Linear(in_features=channel[1], out_features=channel[2]),
        nn.Sigmoid(),
        nn.Linear(in_features=channel[2], out_features=channel[3]),
        nn.LogSoftmax(dim=1),
    )
    return model
# function to apply a transformation to weights
def transform_weights(model):
    for name, param in model.named_parameters():
        if "weight" in name:
            # create a new random tensor with the same size as the weight tensor
            noise = torch.randn(param.shape) * 0.01
            param.data = param.data + noise.to(device)
    return model
# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(
    root="data", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
test_dataset = torchvision.datasets.MNIST(
    root="data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
model = create_network() # the model I want to train
model.to(device) # move the model to the GPU
model_orig = create_network()  # the model used to store the weights before adding noise
model_orig.to(device) # move the model to the GPU
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # only train model
for epoch in range(1, 10):
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # reset gradients
        optimizer.zero_grad()
        model_orig.load_state_dict(model.state_dict())
        # drift the weights and compute the forward pass
        model.eval()
        model = transform_weights(model)
        loss = criterion(model(images), labels)
        # Run training (backward propagation).
        loss.backward()
        # Load back the original weights
        model.load_state_dict(model_orig.state_dict())
        model.train()
        # Optimize weights.
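        # (intent: the gradients were computed with the drifted weights, but the step
        #  below should apply them to the restored original weights)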
        optimizer.step()
        # Calculate the test accuracy
        if batch_idx % 100 == 0:
            correct = 0
            total = 0
            with torch.no_grad():
                for data in test_dataset:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            print(
                f"Epoch: {epoch} Batch: {batch_idx} Loss: {loss.item()} Accuracy: {100 * correct / total}"
            )
Output
Using device: cuda:0
Epoch: 1 Batch: 0 Loss: 2.3247387409210205 Accuracy: 10.09
Epoch: 1 Batch: 100 Loss: 1.5586031675338745 Accuracy: 64.34
Epoch: 1 Batch: 200 Loss: 0.8242834210395813 Accuracy: 81.4
Epoch: 1 Batch: 300 Loss: 0.6049119234085083 Accuracy: 86.83
Epoch: 1 Batch: 400 Loss: 0.4129831790924072 Accuracy: 87.14
Epoch: 1 Batch: 500 Loss: 0.4107397794723511 Accuracy: 89.9
Epoch: 1 Batch: 600 Loss: 0.36199185252189636 Accuracy: 90.48
Epoch: 1 Batch: 700 Loss: 0.42539575695991516 Accuracy: 91.31
Epoch: 1 Batch: 800 Loss: 0.3088320195674896 Accuracy: 91.61
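For context, this is roughly how I check the trained model afterwards (a sketch, not my exact evaluation code). The accuracy with the noisy weights is much worse, which is why I suspect the training loop above is not doing what I intend:

# perturb the trained model with the same kind of noise as during training and re-test
model.eval()
model = transform_weights(model)

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_dataset:  # test_dataset is the DataLoader created above
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy with noisy weights: {100 * correct / total}")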