Hey everyone,
I attempt to train a model using Opacus, the layers’ weight of the model share the same basis.
Could anyone help me to figure out why I get these errors?
Here’s my source code
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from opacus import PrivacyEngine
import torch.nn as nn
from data import get_data
from train_utils import get_device, train, test
import math
device = get_device()
class testModule(nn.Module):
def __init__(self):
super(testModule, self).__init__()
self.fc1 = nn.Linear(784, 10, bias=True)
self.fc2 = nn.Linear(10, 10, bias=True)
self.fc3 = nn.Linear(10, 10, bias=True)
self.fc1_t0 = self.fc1.weight.clone().to(device)
self.fc2_t0 = self.fc2.weight.clone().to(device)
self.fc3_t0 = self.fc3.weight.clone().to(device)
del self.fc1.weight
del self.fc2.weight
del self.fc3.weight
Dim = 200
self.fc2_base_weights = Parameter(torch.zeros((Dim, 10)))
self.rand_matrix_1 = torch.randn((784, Dim), requires_grad=False).to(device)
self.rand_matrix_2 = torch.randn((10, Dim), requires_grad=False).to(device)
self.rand_matrix_3 = torch.randn((10, Dim), requires_grad=False).to(device)
def forward(self, x):
x = x.reshape(x.size(0), -1)
#breakpoint()
self.fc1.weight = self.fc1_t0+torch.matmul(self.rand_matrix_1, self.fc2_base_weights.clone()).transpose(0,1)
self.fc2.weight = self.fc2_t0+torch.matmul(self.rand_matrix_2, self.fc2_base_weights.clone())
self.fc3.weight = self.fc3_t0+torch.matmul(self.rand_matrix_3, self.fc2_base_weights.clone())
x = self.fc1(x)
x = nn.ReLU()(x)
x = self.fc2(x)
x = nn.ReLU()(x)
x = self.fc3(x)
return x
train_data, test_data = get_data("mnist")
model = testModule()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=128,
shuffle=False,
num_workers=3,
pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_data,
batch_size=1024,
shuffle=False,
num_workers=3,
pin_memory=True)
privacy_engine = PrivacyEngine(
model,
sample_rate=1024 / len(train_data),
noise_multiplier=3,
max_grad_norm=0.1
)
scattering = None
#privacy_engine.attach(optimizer)
for epoch in range(100):
print('epoch ', epoch+1)
train_loss, train_acc = train(model=model,
train_loader=train_loader,
optimizer=optimizer,
scattering=scattering)
print('train acc: ',train_acc, 'train loss: ',train_loss)
test_loss, test_acc = test(model=model,
test_loader=test_loader,
scattering=scattering)
print('test acc: ',test_acc, 'test loss: ',test_loss)
and the error msg
Thanks in advance