class Net1(nn.Module):
    """Linear classifier: flattens each sample to 3072 features, maps to 10 logits."""

    def __init__(self):
        super(Net1, self).__init__()
        self.linear = nn.Linear(3072, 10)

    def forward(self, x):
        # Collapse every non-batch dimension into one feature axis.
        flat = x.view(x.size(0), -1)
        return self.linear(flat)
class lin(nn.Module):
    """Linear layer whose weight/bias are externally supplied tensors.

    The supplied tensors are kept as plain attributes (NOT nn.Parameter):
    wrapping a non-leaf tensor such as ``param * theta`` in ``nn.Parameter``
    either raises or detaches it from the autograd graph, which would make
    ``grad(loss, theta)`` impossible downstream. Using ``F.linear`` on the
    raw tensors preserves the graph back to ``theta``.

    Args:
        wt: weight tensor of shape (output, inp).
        bs: bias tensor of shape (output,).
        inp: input feature count (kept for interface compatibility; the
             shapes of ``wt``/``bs`` are what actually matter).
        output: output feature count (kept for interface compatibility).
    """

    def __init__(self, wt, bs, inp, output):
        super(lin, self).__init__()
        self.wt = wt
        self.bs = bs

    def forward(self, input):  # NOTE: name kept for compatibility; shadows builtin
        # Flatten everything but the batch dimension, then apply the
        # functional linear map so gradients flow through wt and bs.
        out = input.view(input.size(0), -1)
        return F.linear(out, self.wt, self.bs)
class Net_test(nn.Module):
    """Classifier built from an externally supplied parameter dict.

    Args:
        paramsel: mapping with keys 'linear.weight' (shape (10, 3072)) and
                  'linear.bias' (shape (10,)) — typically scaled copies of
                  another network's parameters.
    """

    def __init__(self, paramsel):
        super(Net_test, self).__init__()
        self.lnr = lin(paramsel['linear.weight'], paramsel['linear.bias'], 3072, 10)

    def forward(self, x):
        # lin.forward already flattens its input, so the extra view() the
        # original code did here was redundant; debug prints removed.
        return self.lnr(x)
# Probe dL/d(theta) where theta globally scales every parameter of classf.
classf = Net1()
classf = classf.cuda()
paramNettest = {}
# theta must be created directly on the GPU: `.cuda()` on a leaf tensor
# returns a NON-leaf copy, and torch.autograd.grad() on a non-leaf tensor
# fails ("not a leaf Tensor / not require grad" style errors).
theta = torch.randn(1, requires_grad=True, device='cuda')
for i, data in enumerate(train_loader, 0):
    inputs, labels = data
    # Variable() is a deprecated no-op since PyTorch 0.4 — plain .cuda() suffices.
    inputs, labels = inputs.cuda(), labels.cuda()
    # Rebuild the scaled-parameter dict each step so the graph to theta is fresh.
    for name, param in classf.named_parameters():
        paramNettest[name] = param * theta
    classft = Net_test(paramNettest)
    classft = classft.cuda()
    op2 = classft(inputs)
    loss = F.cross_entropy(op2, labels, reduction='mean')
    print(grad(loss, theta))