Hi,
When I run the codes, I found the output of the model get a False requires_grad.
import torch
from torch import nn
from torch.autograd import Function
class RouteNet(torch.nn.Module):
def __init__(self):
super(RouteNet, self).__init__()
self.encode = torch.nn.Sequential(
torch.nn.Conv2d(3, 32, 9, 1, 4),
torch.nn.MaxPool2d(2, 2, 0, 1),
torch.nn.Conv2d(32, 64, 7, 1, 3),
torch.nn.MaxPool2d(2, 2, 0, 1),
torch.nn.Conv2d(64, 32, 9, 1, 4))
self.decode = torch.nn.Sequential(
torch.nn.Conv2d(32, 32, 7, 1, 3),
torch.nn.ConvTranspose2d(32, 16, 9, 2, 4, 1),
torch.nn.Conv2d(16, 16, 5, 1, 2),
torch.nn.ConvTranspose2d(16, 4, 5, 2, 2, 1),
torch.nn.Conv2d(4, 1, 3, 1, 1))
def forward(self, x):
encode_out = self.encode(x)
# res = conv_out.view(conv_out.size(0), -1)
out = self.decode(encode_out)
return out
class CongestionComputeFunction(Function):
@staticmethod
def forward(ctx):
net = RouteNet()
net.load_state_dict(torch.load("./net_H_epoch_39_batch_15_param.pkl"), strict=False)
inp = torch.rand(2, 3, 100, 100, requires_grad=True)
out = net(inp)
print(out.requires_grad) #False
@staticmethod
def backward(ctx):
return None
class CongestionCompute(nn.Module):
def __init__(self):
super(CongestionCompute, self).__init__()
def forward(self):
return CongestionComputeFunction.apply()
if __name__=='__main__':
congestion = CongestionCompute()
congestion.forward()
I want to get the grad in this model.
Looking forward to your reply.