I’m a little lost on how it would be possible to perform weight sharing in pytorch. In my case I would like to do the following:
Essentially I would like to reuse the weights q.weight and the weights w2, concatenated together, in a loop. q.weight holds the weights of the Conv2d layer q and is also used in the loop, while the weights w2 would ONLY be used in the Conv2d layer inside the loop (however, since that layer is repeated n times, technically the weights w2 would be shared as well).
class Model(nn.Module):
    """Conv net that re-applies a shared convolution kernel in a loop.

    The loop reuses ``self.q.weight`` (the kernel of the ``q`` Conv2d layer)
    concatenated with the extra kernel ``self.w2`` along the input-channel
    axis — the PyTorch equivalent of the Theano ``conv2D_keep_shape``
    construction in the question.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config
        # nn.Parameter (not Variable) so the tensor is registered on the
        # module and returned by model.parameters() for the optimizer.
        self.w2 = nn.Parameter(torch.randn(10, 1, 3, 3))
        self.h = nn.Conv2d(in_channels=2,
                           out_channels=150,
                           kernel_size=(3, 3),
                           stride=1, padding=1)
        # 1x1 conv: padding=0 preserves the spatial size. The original
        # padding=1 would grow HxW by 2, unlike the Theano "keep shape"
        # helper this mirrors.
        self.r = nn.Conv2d(in_channels=150,
                           out_channels=1,
                           kernel_size=(1, 1),
                           stride=1, padding=0)
        self.q = nn.Conv2d(in_channels=1,
                           out_channels=10,
                           kernel_size=(3, 3),
                           stride=1, padding=1)

    def forward(self, x, n=0):
        """Forward pass.

        Args:
            x: input tensor of shape (N, 2, H, W).
            n: number of shared-weight refinement iterations (was a free
               global ``n`` in the original snippet).

        Returns:
            q: tensor of shape (N, 10, H, W).
        """
        h = self.h(x)
        r = self.r(h)
        q = self.q(r)
        # keepdim=True keeps the channel axis so v can be concatenated
        # with r along dim=1 (the original dropped it and would fail).
        v = torch.max(q, dim=1, keepdim=True)[0]
        for _ in range(n):
            rv = torch.cat([r, v], dim=1)  # (N, 2, H, W)
            # Weight sharing: concatenate q's kernel with w2 along the
            # input-channel axis -> (10, 2, 3, 3) and apply it with the
            # functional conv (this replaces the invalid pseudo-call).
            w = torch.cat([self.q.weight, self.w2], dim=1)
            q = torch.nn.functional.conv2d(rv, w, padding=1)
            # Refresh v from the new q each iteration; presumably this is
            # the intended value-iteration update — TODO confirm.
            v = torch.max(q, dim=1, keepdim=True)[0]
        return q
In Theano it could be accomplished as follows:
# Helper function
def conv2D_keep_shape(x, w, subsample=(1, 1)):
    """'Same'-size 2D convolution: full-mode conv, then center crop.

    Args:
        x: input image tensor, shape (batch, in_channels, H, W).
        w: filter tensor, shape (out_channels, in_channels, fH, fW).
        subsample: convolution stride.

    Returns:
        The full-mode convolution cropped back to the spatial size of ``x``.
        Assumes square images/filters (only dimension 2 is inspected).
    """
    fs = T.shape(w)[2] - 1  # filter size minus 1
    ims = T.shape(x)[2]     # input spatial size
    # Floor division keeps the slice bounds integral under Python 3
    # (plain / would produce a float and break indexing).
    return theano.sandbox.cuda.dnn.dnn_conv(img=x,
                                            kerns=w,
                                            border_mode='full',
                                            subsample=subsample,
                                            )[:, :, fs // 2:ims + fs // 2,
                                              fs // 2:ims + fs // 2]
# Weights
# NOTE(review): this fragment references `self`, `input`, and `n` that are
# defined elsewhere — it is an illustrative excerpt, not standalone code.
self.w = theano.shared((np.random.randn(1, 150, 1, 1)).astype(theano.config.floatX))   # 1x1 kernel: 150 -> 1 channels
self.w1 = theano.shared((np.random.randn(10, 1, 3, 3)).astype(theano.config.floatX))   # 3x3 kernel: 1 -> 10 channels; reused inside the loop
self.w2 = theano.shared((np.random.randn(10, 1, 3, 3)).astype(theano.config.floatX))   # extra 3x3 kernel used only inside the loop
# Model
self.r = conv2D_keep_shape(input, self.w)
self.q = conv2D_keep_shape(self.r, self.w1)
self.v = T.max(self.q, axis=1, keepdims=True)   # keepdims so v can be concatenated with r along the channel axis
for i in range(n):
# Weight sharing: [w1, w2] concatenated along the input-channel axis is
# applied to [r, v] concatenated along channels, n times.
self.q = conv2D_keep_shape(T.concatenate([self.r, self.v], axis=1), T.concatenate([self.w1, self.w2], axis=1))
Is this possible to achieve in PyTorch? Please also let me know if you need more explanation in case the above is not clear.