Here is the Dataset I defined. Its `__index()` method, which converts a flat
sample index into (parameter, time, spatial-point) indices, seems slow.
class myset(Dataset):
    """Point-wise dataset over a (parameter, time, x, y) grid.

    A flat sample index ``i`` is laid out as
    ``i = idp * (num_t * Nx * Ny) + idt * (Nx * Ny) + idf`` with
    ``idf = ix * Ny + iy``. This is exactly the row order of ``self.u``,
    which ``__init__`` builds by permuting ``u`` to (P, T, Nx, Ny, C) and
    flattening the leading four axes.

    Each item is ``((x, y, t, uin, pin), ulabel)`` where ``uin`` is the
    time-index-0 field of parameter ``idp``, ``pin`` is ``p[idp]`` and
    ``ulabel`` is the C-channel value of ``u`` at that grid point.

    NOTE(review): fetching one point per ``__getitem__`` call is usually
    what makes loading slow. Only ``//``, ``%`` and tensor indexing are
    used, so ``__getitem__`` also accepts a 1-D LongTensor of indices and
    returns batched tensors. An example is feeding it from a ``BatchSampler``
    with ``DataLoader(batch_size=None)``. Verify this with your device setup.
    """

    def __init__(self,
                 u: torch.Tensor,
                 p: torch.Tensor,
                 length=(2, 3.2, 3.2),
                 num_points: int = 216*101*257*257,
                 batch_size: int = 50000,
                 device: torch.device = torch.device('cpu')):
        """Build the coordinate grids and the flattened sample table.

        Args:
            u: field tensor indexed as (parameter, time, channel, x, y). The
                permute below fixes this 5-D layout.
            p: per-parameter tensor. Row ``p[idp]`` is returned with each sample.
            length: (time span, x extent, y extent). These are used as the
                upper bounds of the uniform ``linspace`` grids for t, x and y.
            num_points: value reported by ``__len__``. It is NOT checked
                against ``u``. The default presumably corresponds to a
                216 x 101 x 257 x 257 (P, T, Nx, Ny) grid, so pass
                ``P * T * Nx * Ny`` for other data.
            batch_size: stored on the instance. This class does not use it.
            device: where ``p``, the coordinate grids and ``uicin`` are
                placed. ``self.u`` and ``self.uicout`` stay on ``u``'s
                original device.
        """
        self.N = num_points
        self.num_p = u.shape[0]
        self.num_t = u.shape[1]
        self.Nx = u.shape[3]
        self.Ny = u.shape[4]
        self.num_f = self.Nx * self.Ny          # spatial points per snapshot
        self.num_tf = self.num_t * self.num_f   # samples per parameter
        self.p = p.to(device)

        xmesh = torch.linspace(0, length[1], self.Nx, device=device)
        ymesh = torch.linspace(0, length[2], self.Ny, device=device)
        # 'ij' indexing + row-major flatten puts (ix, iy) at row ix * Ny + iy,
        # the same spatial ordering as the flattened ``u`` below.
        xmesh, ymesh = torch.meshgrid(xmesh, ymesh, indexing='ij')
        self.xmesh, self.ymesh = xmesh.reshape(-1, 1), ymesh.reshape(-1, 1)
        self.t = torch.linspace(0, length[0], self.num_t,
                                device=device).reshape(-1, 1)

        # Field at time index 0 for every parameter: shape (P, C, Nx, Ny).
        self.uicin = u[:, 0].to(device)
        self.num_initp = u.shape[0]
        self.length = length
        self.device = device
        self.batch_size = batch_size

        # Sketch kept from the original author (presumably the edge numbering
        # of the square domain). This class does not reference it.
        #
        #          0
        #     +---------+
        #     |         |
        #   3 |         | 1
        #     |         |
        #     +---------+
        #          2

        # (P, T, C, Nx, Ny) -> (P, T, Nx, Ny, C). Flattening the first four
        # axes then yields one row per flat sample index.
        u = torch.permute(u, (0, 1, 3, 4, 2))
        num_c = u.shape[4]
        self.u = u.reshape(-1, num_c)
        # Use the actual channel count rather than a hard-coded 2.
        self.uicout = u[:, 0].reshape(self.num_p, self.num_f, num_c)

    def __index(self, i):
        """Split flat index ``i`` into ``(idp, idt, idf)``.

        idp: parameter index, idt: time index, idf: flattened spatial index.
        Only ``//`` and ``%`` are used, so this works for Python ints and
        element-wise for integer tensors.
        """
        idp = i // self.num_tf
        idt = (i // self.num_f) % self.num_t
        idf = i % self.num_f
        return idp, idt, idf

    def __getitem__(self, index):
        """Return ``((x, y, t, uin, pin), ulabel)`` for flat ``index``."""
        idp, idt, idf = self.__index(index)
        uin = self.uicin[idp]
        pin = self.p[idp]
        # ``self.u`` rows follow the flat layout directly, so no decomposition is needed.
        ulabel = self.u[index]
        x = self.xmesh[idf]
        y = self.ymesh[idf]
        t = self.t[idt]
        return (x, y, t, uin, pin), ulabel

    def __len__(self):
        """Return ``num_points`` as given to the constructor (not derived from ``u``)."""
        return self.N