For example:
```python
import torch
import torch.nn as nn


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=200, out_features=1, bias=True)
        )
        # uninitialized parameter, never used in forward
        self.net1 = nn.Parameter(torch.Tensor(1, 200))

    def forward(self, feat):
        feat.requires_grad_()
        m = torch.rand(1, 200).to(feat.device)
        b = feat * m
        return b


class SANNet(nn.Module):
    def __init__(self):
        super(SANNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=200, out_features=1, bias=True)
        )
        self.net1 = nn.Parameter(torch.Tensor(1, 200))
        self.net2 = nn.ModuleList()
        self.net2.append(Net())

    def forward(self, feat):
        feat.requires_grad_()
        m = torch.rand(1, 200).to(feat.device)
        # annotate the submodule with the JIT interface before calling it
        net: ModuleInterface = self.net[0]
        c = net(m)
        b = feat * c
        return b


torch.set_default_dtype(torch.float64)  # float32 works; float64 fails
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SANNet()
model = model.to(device)
model.train()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
f = nn.MSELoss()
aa = torch.rand(1, 200)
data = torch.utils.data.DataLoader(aa, batch_size=1, shuffle=True, num_workers=1)
for i, x in enumerate(data):
    # the loaded batch is replaced with fresh random input each step
    x = torch.rand(1, 200).to(device)
    y = model(x)
    loss = f(x, y) / 10.0
    loss.backward()
    opt.step()
    opt.zero_grad()
```
With `torch.set_default_dtype(torch.float32)` the script runs without problems, but with `torch.set_default_dtype(torch.float64)` it fails with the following error:

```
Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding
```
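To see which parameter/gradient pair actually trips this check, one can dump the dtype and device of every parameter and its gradient right before `opt.step()` (a diagnostic sketch added for illustration, not part of the original repro):

```python
# Print dtype/device for each parameter and its gradient just before
# opt.step(), to spot the tensor pair failing the dtype/device grouping.
for name, p in model.named_parameters():
    g = p.grad
    print(name,
          "param:", p.dtype, p.device,
          "grad:", None if g is None else (g.dtype, g.device))
```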
I have narrowed the problem down to these two lines:

```python
net: ModuleInterface = self.net[0]
c = net(m)
```
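If the mismatch indeed comes from Adam's multi-tensor (foreach) path grouping these tensors together, forcing the single-tensor implementation might sidestep the check; this is an untested guess, not a confirmed fix:

```python
# Untested workaround sketch: force Adam's single-tensor implementation
# instead of the foreach path that performs the device/dtype grouping.
opt = torch.optim.Adam(model.parameters(), lr=0.001, foreach=False)
```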