I got this error while working on an image classification task. The problem persists even if I lower the batch size to 64. Here is the training code:
import torch
from opacus import PrivacyEngine

net.train()
# train and update
optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
privacy_engine = PrivacyEngine()
net, optimizer, self.ldr_train = privacy_engine.make_private_with_epsilon(
    module=net,
    optimizer=optimizer,
    data_loader=self.ldr_train,
    target_epsilon=2,
    target_delta=1e-5,
    epochs=50,
    max_grad_norm=1.0,
)
epoch_acc = []
epoch_loss = []
for iter in range(self.local_ep):
    batch_acc = []
    batch_loss = []
    for batch_idx, (images, labels) in enumerate(self.ldr_train):
        images, labels = images.to(self.device), labels.to(self.device)
        net.zero_grad()
        # ---------forward prop-------------
        fx = net(images)
        # calculate loss
        loss = self.loss_func(fx, labels)
        # calculate accuracy
        acc = calculate_accuracy(fx, labels)
        # --------backward prop--------------
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
        batch_acc.append(acc.item())
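
If the error is memory-related (an assumption on my part, since Opacus keeps a per-sample gradient for every parameter and GPU memory therefore grows much faster with batch size than in non-private training), one thing I could try is Opacus's BatchMemoryManager, which splits each logical batch into smaller physical batches while keeping the privacy accounting for the original batch size. A rough sketch of how it would wrap the loop above, with max_physical_batch_size=16 as an illustrative value:

from opacus.utils.batch_memory_manager import BatchMemoryManager

# Wrap the DP data loader returned by make_private_with_epsilon so each
# logical batch is processed in smaller physical chunks.
with BatchMemoryManager(
    data_loader=self.ldr_train,      # loader returned by make_private_with_epsilon
    max_physical_batch_size=16,      # illustrative value; tune to fit GPU memory
    optimizer=optimizer,             # DP optimizer returned above
) as memory_safe_loader:
    for images, labels in memory_safe_loader:
        images, labels = images.to(self.device), labels.to(self.device)
        net.zero_grad()
        loss = self.loss_func(net(images), labels)
        loss.backward()
        optimizer.step()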