import os
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class TSPDataset(Dataset):
    """In-memory dataset of randomly generated TSP instances.

    Every sample holds ``size`` cities whose (x, y) coordinates are drawn
    with ``torch.rand``, plus an all-zero "dynamic" tensor of matching
    length.

    Args:
        size: Number of cities (nodes) in each instance.
        num_samples: Number of instances to generate. Integral floats such
            as the default ``1e6`` are accepted and converted to ``int``.
        seed: Seed applied to both the NumPy and the torch global RNGs.
            When ``None`` a seed is drawn from NumPy's RNG.
    """

    def __init__(self, size=50, num_samples=1e6, seed=None):
        super(TSPDataset, self).__init__()

        if seed is None:
            seed = np.random.randint(123456789)

        # torch.rand/torch.zeros require integer sizes and __len__ must
        # return an int, so the float default (1e6) has to be converted
        # before it is used anywhere.
        num_samples = int(num_samples)

        # NOTE: this reseeds the *global* NumPy and torch generators.
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Static input: coordinates, laid out as (num_samples, 2, size).
        self.dataset = torch.rand((num_samples, 2, size))
        # Dynamic input: zeros, laid out as (num_samples, 1, size).
        self.dynamic = torch.zeros(num_samples, 1, size)
        self.num_nodes = size
        self.size = num_samples

    def __len__(self):
        """Return the number of generated instances."""
        return self.size

    def __getitem__(self, idx):
        """Return ``(static, dynamic, start_loc)`` for instance ``idx``.

        ``start_loc`` is the first city's coordinates (column 0 of
        ``static``), sliced with ``0:1`` so the trailing axis is kept.
        """
        # (static, dynamic, start_loc)
        return (self.dataset[idx], self.dynamic[idx], self.dataset[idx, :, 0:1])
def update_mask(mask, dynamic, chosen_idx):
    """Zero out the mask entry of the city that was just chosen.

    The mask is modified in place and the same tensor object is returned,
    so a visited city cannot be selected a second time.

    Args:
        mask: Tensor indexed along dim 1 by city; presumably shaped
            (batch, num_cities) -- confirm against the caller.
        dynamic: Not used by this function.
        chosen_idx: Integer tensor of chosen city indices; presumably one
            entry per batch row -- confirm against the caller.

    Returns:
        The same ``mask`` tensor, after the in-place update.
    """
    # scatter_ needs an index with the same rank as mask, so add an axis
    # at position 1 to turn the indices into a column.
    visited_column = chosen_idx.unsqueeze(1)
    mask.scatter_(1, visited_column, 0)
    return mask
I don’t quite understand how the question in the title relates to the posted code snippet.
Could you explain your use case as well as the problem you are facing a bit more?
1 Like
OK, I am looking to create a TSP environment to be solved with deep reinforcement learning.
The class TSPDataset is my env_tsp, and __getitem__ is the method that selects the action at every step; I assume it is the key to fixing the start city in each predicted tour. TSPDataset is a combination of the dataset, which contains the (x, y) coordinates, and a dynamic vector, which contains the city index.
In brief, I would like help fixing the start city, that is, the city with index 1, and using it as a rule for the agent's policy.
For example, I want it to give me a permutation of cities like this:
0,4,8,1,4,9…,0
0,1,3,9,5,7,…,0
Thanks.