Hi,
I defined a new layer with the code below.
After training a network that uses this layer, I would like to reload the model and continue training.
My questions are:
- torch.load() and model.load_state_dict() will load the parameters of this layer. After reloading, will they be of type tensor or of type parameter?
- If they are plain tensors, the mask_weight in the code below couldn't be optimized, right?
- How do I make the reloaded model trainable?
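import torch
import torch.nn as nn
import torch.nn.functional as F

class SMConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1):
        super(SMConv2d, self).__init__()
        # forward() reads these, so they must be stored on the module
        self.stride = stride
        self.padding = padding
        # Learnable mask, registered as a parameter and initialized to 1
        self.mask_weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.constant_(self.mask_weight, 1)

    def compute_mask(self, temp):
        # Scale the learnable mask by temp; mask_weight itself is left untouched
        self.mask = self.mask_weight * temp
        # self.mask_weight = self.mask_weight * temp
        return self.mask

    def forward(self, x, temp=1):
        masked_weight = self.compute_mask(temp)
        out = F.conv2d(x, masked_weight, stride=self.stride, padding=self.padding)
        return out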
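
A rough way to check the first two questions empirically is to save a state_dict, load it into a freshly constructed layer, and inspect mask_weight. The sketch below is only an illustration under some assumptions: TinyMaskedConv is a hypothetical, cut-down stand-in for SMConv2d, an in-memory buffer replaces the checkpoint file, and the layer sizes and learning rate are arbitrary. As far as I understand, load_state_dict() copies the saved values into the parameters that the new instance already registered, so mask_weight should stay an nn.Parameter with requires_grad=True.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMaskedConv(nn.Module):
    """Hypothetical stand-in for SMConv2d, kept small for this check."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.mask_weight = nn.Parameter(
            torch.ones(out_channels, in_channels, kernel_size, kernel_size)
        )

    def forward(self, x, temp=1):
        return F.conv2d(x, self.mask_weight * temp, padding=1)


# Save the layer's state_dict (an in-memory buffer stands in for a file).
trained = TinyMaskedConv(3, 4, 3)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Rebuild the layer first, then copy the saved values into its parameters.
reloaded = TinyMaskedConv(3, 4, 3)
reloaded.load_state_dict(torch.load(buffer))

# Inspect the type and the gradient flag after reloading.
is_parameter = isinstance(reloaded.mask_weight, nn.Parameter)
requires_grad = reloaded.mask_weight.requires_grad

# Take one optimizer step to see whether mask_weight actually updates.
before = reloaded.mask_weight.detach().clone()
optimizer = torch.optim.SGD(reloaded.parameters(), lr=0.1)
loss = reloaded(torch.ones(1, 3, 8, 8)).sum()
loss.backward()
optimizer.step()
changed = not torch.equal(before, reloaded.mask_weight.detach())

print(is_parameter, requires_grad, changed)
```

If all three flags come out true, the reloaded layer can be passed to an optimizer as usual. To resume training exactly where it stopped, the optimizer's own state_dict would presumably need to be saved and restored in the same way.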