Hi — inside the ModuleList I set requires_grad=False only on the Haar-transform parameters,
yet the summary reports every parameter in the ModuleList as non-trainable, and I can't work out why.
class RNVP(nn.Module):
    """RealNVP-style invertible network laid out as a UNet.

    Builds a message-diffusion front end (three Linear / ExpandNet /
    reverse-Linear triples) followed by a flat ``nn.ModuleList`` of
    operations: Haar downsampling stages interleaved with RNVP coupling
    blocks on the way down, then Haar upsampling stages interleaved with
    coupling blocks on the way up.

    Args:
        dims_in: nested list like ``[[C, H, W]]`` describing the input
            dims (assumed from the ``dims_in[0][1]`` indexing below —
            TODO confirm against the caller).
        message_length: length of the message vector fed to the Linear
            diffusion layers.
        diffusion_length: diffused message length; its integer square
            root is stored as ``diffusion_size``.
        down_num: number of down (and up) sampling stages.
        block_num: coupling blocks per down stage; for the up path the
            list is reversed with the deepest entry dropped and a
            trailing 0 appended (0 blocks at full resolution).
    """

    def __init__(self, dims_in, message_length, diffusion_length=256,
                 down_num=3, block_num=(4, 4, 6)):
        super(RNVP, self).__init__()
        # --- message diffusion block ---
        self.in_channel = 3
        self.dims_in = dims_in
        self.H = dims_in[0][1]
        self.W = dims_in[0][2]
        self.diffusion_length = diffusion_length
        self.diffusion_size = int(self.diffusion_length ** 0.5)
        self.linear1 = nn.Linear(message_length, self.diffusion_length)
        self.linear2 = nn.Linear(message_length, self.diffusion_length)
        self.linear3 = nn.Linear(message_length, self.diffusion_length)
        self.msg_up1 = ExpandNet(1, 1, 3)
        self.msg_up2 = ExpandNet(1, 1, 3)
        self.msg_up3 = ExpandNet(1, 1, 3)
        self.linear_rev1 = nn.Linear(self.H * self.W, message_length)
        self.linear_rev2 = nn.Linear(self.H * self.W, message_length)
        self.linear_rev3 = nn.Linear(self.H * self.W, message_length)
        self.HaarDown = HaarDownsampling(dims_in)

        # --- RNVP blocks arranged as a UNet ---
        # BUG FIX: the original did `current_dims = dims_in`, aliasing the
        # caller's nested list — the in-place mutations below then
        # corrupted `dims_in` (and `self.dims_in`). Copy one level deep so
        # the dim bookkeeping stays private to this constructor.
        current_dims = [list(d) for d in dims_in]
        # FIX: mutable default argument replaced with a tuple default;
        # work on a private list copy (accepts lists or tuples).
        block_num = list(block_num)
        operations = []

        # down path
        for i in range(down_num):
            if i != 0:
                operations.append(HaarDownsampling(current_dims))
                current_dims[0][0] *= 4    # Haar: C -> 4C
                current_dims[0][1] //= 2   # H -> H/2
                current_dims[0][2] //= 2   # W -> W/2
            else:
                # First stage: extra factor 2 in channels — presumably the
                # image and message streams are concatenated; confirm
                # against forward().
                current_dims[0][0] *= 4 * 2
                current_dims[0][1] //= 2
                current_dims[0][2] //= 2
            for _ in range(block_num[i]):
                operations.append(
                    RNVPCouplingBlock(current_dims,
                                      subnet_constructor=ResidualDenseBlock,
                                      clamp=1.0))

        # up path: reuse the down-path block counts in reverse, dropping
        # the deepest stage and ending with 0 blocks at full resolution.
        block_num = block_num[:-1][::-1]
        block_num.append(0)
        for i in range(down_num):
            # FIX: was `if i != 2`, hard-coding down_num == 3; this is
            # identical for the default and correct for other depths.
            if i != down_num - 1:
                operations.append(HaarUpsampling(current_dims))
                current_dims[0][0] //= 4
                current_dims[0][1] *= 2
                current_dims[0][2] *= 2
            for _ in range(block_num[i]):
                operations.append(
                    RNVPCouplingBlock(current_dims,
                                      subnet_constructor=ResidualDenseBlock,
                                      clamp=1.0))

        # NOTE(review): registering these in an nn.ModuleList does not by
        # itself freeze anything — if the summary shows every ModuleList
        # parameter as non-trainable, check for a `requires_grad_(False)`
        # or `torch.no_grad()` applied to the whole module elsewhere, or
        # inside the block constructors; not determinable from this file.
        self.operations = nn.ModuleList(operations)
======================================================================
Layer (type:depth-idx) Param #
======================================================================
├─Model: 1-1 --
| └─RNVP: 2-1 --
| | └─Linear: 3-1 7,936
| | └─Linear: 3-2 7,936
| | └─Linear: 3-3 7,936
| | └─ExpandNet: 3-4 21
| | └─ExpandNet: 3-5 21
| | └─ExpandNet: 3-6 21
| | └─Linear: 3-7 491,550
| | └─Linear: 3-8 491,550
| | └─Linear: 3-9 491,550
| | └─HaarDownsampling: 3-10 (48)
| | └─ModuleList: 3-11 (28,595,840)
======================================================================
Total params: 30,094,409
Trainable params: 1,498,521
Non-trainable params: 28,595,888
======================================================================