Thank you for your reply! I followed your suggestion and found that the two new inputs cause the hang, but I still don't know why this happens. The code that includes the two inputs (ltw, rtw) is below.
class StateCVRP(NamedTuple):
    """Immutable bundle of the fixed problem tensors and the evolving state.

    Instances are meant to be created through :meth:`initialize`, which
    builds every field from a batched input dictionary.
    """

    # Fixed input
    coords: torch.Tensor  # depot row followed by the loc_id rows, cast to long
    demand: torch.Tensor
    ltw: torch.Tensor  # input['ltw'], stored unchanged (presumably left time-window bound -- confirm)
    rtw: torch.Tensor  # input['rtw'], stored unchanged (presumably right time-window bound -- confirm)
    ids: torch.Tensor  # Keeps track of original fixed data index of rows

    # State
    prev_a: torch.Tensor
    used_capacity: torch.Tensor
    visited_: torch.Tensor  # Keeps track of tasks that have been visited
    lengths: torch.Tensor
    cur_index: torch.Tensor
    i: torch.Tensor  # Keeps track of step

    VEHICLE_CAPACITY = 1.0  # Hardcoded
    VEHICLE_V = 3

    @staticmethod
    def initialize(input, visited_dtype=torch.uint8):
        """Build the initial state for a batch of instances.

        Declared as a static method because it takes no ``self``; without
        the decorator, calling it on an existing state object would pass
        that state as ``input``.

        Args:
            input: Mapping with the keys ``'demand'``, ``'depot'``,
                ``'loc_id'``, ``'ltw'`` and ``'rtw'``. ``loc_id`` must be a
                3-D tensor, since its size is unpacked into three values.
            visited_dtype: ``torch.uint8`` allocates one mask entry per node
                (``n_loc + 1`` columns). Any other value allocates an int64
                tensor with ``(n_loc + 63) // 64`` columns (presumably a
                bit-packed mask -- confirm against the code that updates it).

        Returns:
            A ``StateCVRP`` whose tensors all live on ``loc_id``'s device.
        """
        demand = input['demand']
        depot = input['depot']
        loc_id = input['loc_id']
        batch_size, n_loc, _ = loc_id.size()
        device = loc_id.device

        # NOTE(review): 10 is the hard-coded second dimension of prev_a and
        # used_capacity; its meaning is not visible in this snippet.
        state_width = 10

        if visited_dtype == torch.uint8:
            visited = torch.zeros(
                batch_size, 1, n_loc + 1, dtype=torch.uint8, device=device
            )
        else:
            visited = torch.zeros(
                batch_size, 1, (n_loc + 63) // 64, dtype=torch.int64, device=device
            )

        return StateCVRP(
            # The depot is prepended as node 0 in front of the locations.
            coords=torch.cat((depot[:, None, :], loc_id), -2).type(torch.long),
            demand=demand,
            ltw=input['ltw'],
            rtw=input['rtw'],
            ids=torch.arange(batch_size, dtype=torch.int64, device=device)[:, None],  # Add steps dimension
            prev_a=torch.zeros([batch_size, state_width], dtype=torch.long, device=device),
            used_capacity=demand.new_zeros(batch_size, state_width, 1),
            cur_index=depot[:, None, :].type(torch.long),
            visited_=visited,
            lengths=torch.zeros(batch_size, 1, device=device),
            i=torch.zeros(1, dtype=torch.int64, device=device),  # Vector with length num_steps
        )
def make_instance(args, grid_size=1):
    """Convert one raw instance record into a dict of float tensors.

    Args:
        args: Mapping with the keys ``'depot'``, ``'loc_id'``, ``'demand'``,
            ``'ltw'``, ``'rtw'`` and ``'CAPACITIES'``. The first five may be
            lists, arrays or tensors; ``'CAPACITIES'`` is the divisor applied
            to ``demand``.
        grid_size: Divisor applied to ``loc_id``, ``depot``, ``ltw`` and
            ``rtw``. Defaults to 1, which leaves those values unscaled.

    Returns:
        A dict with the keys ``'loc_id'``, ``'demand'``, ``'depot'``,
        ``'ltw'`` and ``'rtw'``, each a new ``torch.float`` tensor.
    """
    capacity = args['CAPACITIES']

    def _as_scaled_float(key, divisor):
        # as_tensor accepts lists, arrays and tensors alike; the division
        # always allocates a new tensor, so the caller's data is never aliased.
        return torch.as_tensor(args[key], dtype=torch.float) / divisor

    return {
        'loc_id': _as_scaled_float('loc_id', grid_size),
        'demand': _as_scaled_float('demand', capacity),
        'depot': _as_scaled_float('depot', grid_size),
        'ltw': _as_scaled_float('ltw', grid_size),
        'rtw': _as_scaled_float('rtw', grid_size),
    }
class VRPDataset(Dataset):
    """Dataset of routing instances, loaded from a pickle file or generated.

    Every item is a dict with the keys ``'depot'``, ``'loc_id'``,
    ``'demand'``, ``'ltw'`` and ``'rtw'``. Both construction paths yield
    float tensors, so items can be concatenated and embedded regardless of
    where they came from.
    """

    def __init__(self, filename=None, size=50, num_samples=1000000, offset=0, distribution=None):
        """Load or generate the instances.

        Args:
            filename: Path to a ``.pkl`` file holding a sequence of raw
                records accepted by ``make_instance``. When ``None``, random
                instances are generated instead.
            size: Number of locations per generated instance. Must be a key
                of the capacity table below (10, 20, 50 or 100). Ignored
                when ``filename`` is given.
            num_samples: Number of instances to load or generate.
            offset: Index of the first record taken from the file.
            distribution: Accepted but not used by this class.
        """
        super(VRPDataset, self).__init__()
        # Never filled in this class; kept because outside code may read it.
        self.data_set = []

        if filename is not None:
            assert os.path.splitext(filename)[1] == '.pkl'

            # NOTE(review): pickle.load can execute arbitrary code; only
            # load dataset files from trusted sources.
            with open(filename, 'rb') as f:
                data = pickle.load(f)
            self.data = [make_instance(args) for args in data[offset:offset + num_samples]]
        else:
            CAPACITIES = {
                10: 20.,
                20: 30.,
                50: 40.,
                100: 50.
            }

            # torch.randint returns int64 by default. Casting to float keeps
            # the same random draws but matches the dtype produced by
            # make_instance for file-based data.
            self.data = [
                {
                    'depot': torch.full([2], 2, dtype=torch.float),
                    'loc_id': torch.randint(0, 57, (size, 2)).float(),
                    'demand': (torch.FloatTensor(size).uniform_(0, 9).int() + 1).float() / CAPACITIES[size],
                    'ltw': torch.randint(0, 100, (size,)).float(),
                    'rtw': torch.randint(500, 600, (size,)).float(),
                }
                for _ in range(num_samples)
            ]

        self.size = len(self.data)

    def __len__(self):
        """Return the number of instances held by the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return the instance dict stored at position ``idx``."""
        return self.data[idx]
def _init_embed(self, input):
features = ('demand', 'ltw', 'rtw')
return torch.cat(
(
self.init_embed_depot(input['depot'])[:, None, :],
self.init_embed(torch.cat((
input['loc_id'],
*(input[feat][:, :, None] for feat in features)
), -1))
),
1
)