import copy

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class Mymodel(nn.Module):
    def __init__(self,
                 base_encoder,
                 num_classes=21,
                 moving_average_decay=0.99,
                 augment_fn=None,
                 augment_fn2=None,
                 use_momentum=True):
        super(Mymodel, self).__init__()
        self.m = moving_average_decay
        self.online_encoder = base_encoder
        self.target_encoder = copy.deepcopy(self.online_encoder)
        for param_q, param_k in zip(self.online_encoder.parameters(),
                                    self.target_encoder.parameters()):
            param_k.data.copy_(param_q.data)  # initialize target from online weights
            param_k.requires_grad = False     # target is not updated by gradient

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum (EMA) update of the target (key) encoder."""
        for param_q, param_k in zip(self.online_encoder.parameters(),
                                    self.target_encoder.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    def forward(self, x, rotate_type=0):
        assert rotate_type in [0, 1, 2, 3]
        # The online branch sees the rotated input; the target branch sees
        # the original input and its output is rotated to match.
        online_pred = self.online_encoder(TF.rotate(x, rotate_type * 90))
        with torch.no_grad():
            self._momentum_update_key_encoder()
            target_pred = TF.rotate(self.target_encoder(x), rotate_type * 90)
        return online_pred, target_pred
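For context, a simplified sketch of how the model is driven; the ResNet-18 encoder, the batch shape, and the DataParallel wrapping are placeholders for illustration, not the exact training script:

import torch
import torch.nn as nn
import torchvision

# Hypothetical encoder; any backbone whose head matches num_classes would do.
encoder = torchvision.models.resnet18(num_classes=21)
model = nn.DataParallel(Mymodel(encoder).cuda())

x = torch.randn(8, 3, 224, 224).cuda()
online_pred, target_pred = model(x, rotate_type=1)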
When I try to write the code like this, I can't update the parameters of self.online_encoder: inside both forward and _momentum_update_key_encoder, list(self.online_encoder.parameters()) turns out to be an empty list.
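A minimal repro sketch of the empty parameter list (Probe is a toy module made up for illustration; this assumes at least two visible GPUs, since with a single device DataParallel skips replication and calls the module directly):

import torch
import torch.nn as nn

class Probe(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # On a DataParallel replica this prints 0: replication copies the
        # parameters over as plain attributes rather than registered
        # parameters, so any in-forward parameter update is a silent no-op.
        print(len(list(self.parameters())))
        return self.fc(x)

probe = nn.DataParallel(Probe().cuda())
probe(torch.randn(8, 4).cuda())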
I use nn.DataParallel to train the model on a multi-GPU machine. DataParallel replicates the module on every forward pass and accumulates gradients on the main GPU, and the replicas do not register their parameters, so the function inside the class that updates the parameters has no effect. When I train on a single GPU or with DistributedDataParallel (DDP) instead, it works fine.
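One possible workaround sketch for the DataParallel case: call the EMA update on model.module from the training loop, where the real leaf parameters are still visible, instead of inside forward on a replica. This assumes the _momentum_update_key_encoder() call is removed from forward; the MSE loss and the loader below are stand-ins for the real objective and data pipeline.

import torch
import torch.nn.functional as F

# `model` is the nn.DataParallel-wrapped Mymodel from the sketch above.
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

for x in loader:  # `loader` is a placeholder DataLoader
    x = x.cuda()
    online_pred, target_pred = model(x, rotate_type=1)
    loss = F.mse_loss(online_pred, target_pred)  # stand-in objective

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # model.module is the original Mymodel instance, whose parameters()
    # is intact, so the momentum update actually takes effect here and the
    # updated weights are re-broadcast to the replicas on the next forward.
    model.module._momentum_update_key_encoder()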