I have to convert the following code from Lua/Torch to PyTorch:
function GramMatrix()
    local net = nn.Sequential()
    net:add(nn.View(-1):setNumInputDims(2))
    local concat = nn.ConcatTable()
    concat:add(nn.Identity())
    concat:add(nn.Identity())
    net:add(concat)
    net:add(nn.MM(false, true))
    return net
end
So far I have tried this:
import torch
import torch.nn as nn

class GramMatrix(nn.Module):
    def forward(self, input):
        # a = batch size (= 1), b = number of feature maps, (c, d) = spatial dims
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL
        G = torch.mm(features, features.t())  # compute the Gram product
        # We 'normalize' the values of the Gram matrix
        # by dividing by the number of elements in each feature map.
        return G.div(a * b * c * d)
But the output of GramMatrix().forward() has requires_grad = True, which later causes a problem with nn.MSELoss().forward().
What should I do?
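
One possible direction, sketched under the assumption that the Gram matrix is being passed to nn.MSELoss as the target: detaching the target from the autograd graph makes it a constant, so its requires_grad becomes False. The tensor names and shapes below (style_features, generated_features) are hypothetical and only stand in for real feature maps.

```python
import torch
import torch.nn as nn


class GramMatrix(nn.Module):
    def forward(self, input):
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)


gram = GramMatrix()
criterion = nn.MSELoss()

# Hypothetical feature maps standing in for real network activations.
style_features = torch.randn(1, 8, 16, 16, requires_grad=True)
generated_features = torch.randn(1, 8, 16, 16, requires_grad=True)

# Assumption: the Gram matrix of the style features is the loss target.
# detach() returns a tensor that is cut off from the autograd graph.
target = gram(style_features).detach()

# The prediction keeps its gradient history; the target is a constant.
loss = criterion(gram(generated_features), target)
loss.backward()
```

After this, only generated_features receives a gradient. Note also that the Lua version does not divide by the number of elements, so its output differs from the PyTorch class above by that scale factor.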