I tried to run the following code:
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel):
super(ResBlock, self).__init__()
self.bn = nn.BatchNorm2d(in_channel)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1)
def forward(self, x):
identity = x
out = self.bn(x)
out = self.relu(out)
out = self.relu(self.conv(out))
out += identity
return out
class Net(nn.Module):
def __init__(self, h, w):
super().__init__()
self.h = h
self.w = w
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2, stride=1)
self.res1 = ResBlock(in_channel=16, out_channel=32)
self.res2 = ResBlock(in_channel=32, out_channel=64)
self.fc1 = nn.Linear(64*w*h, 4*w*h-2*w-2*h)
self.fc2 = nn.Linear(4*w*h-2*w-2*h, 2*w*h-w-h)
def forward(self, x):
x = x.unsqueeze(0)
x = x.unsqueeze(0)
out = F.relu(self.conv1(x))
out = self.res1(out)
out = self.res2(out)
out = out.view(-1, 64*self.w*self.h)
return out
net = Net(3, 3)
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
out = net(a)
but got the following error:
If I change the last line to:
out = net(a.to(torch.int32))
I will get the following error:
Any help will be appreciated!