I want to train a convolutional neural network (CNN) in PyTorch to predict frequency spectrum data related to an input image. Rather than assigning one label to each image (Dog, Cat, Car, Airplane, etc.), I would like to assign a matrix of labels (one label per frequency) to each image. In PyTorch, how do I assign a matrix of data as a label to each input image in my dataset? I have been trying to do this using ImageFolder. Thanks!
Could you explain a bit what this matrix would look like?
Are you passing these frequency images as your input to the model and would like to assign a label to each “pixel” (similar to a segmentation use case)?
I am trying to train a neural network to simulate the frequency response of transmission through a physical structure. I have a dataset of images where each image of a structure has an associated transmission response that varies with frequency.
Rather than having the neural network output the transmission response for a single frequency, I would like to have a fully-connected layer with N outputs that correspond to the transmission response at the N frequencies contained in the spectrum. My problem is that ImageFolder and DatasetFolder use the folder structure by default to assign a single label to each image. How do I assign multiple labels to each image (where the label is the transmission response across N frequencies)? Thanks!
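A minimal sketch of the setup described above, assuming the spectrum is treated as a regression target: a small CNN whose last fully-connected layer has one output per frequency, trained with mean-squared error. The class name `SpectrumCNN`, the layer sizes, the 64x64 single-channel input and N = 100 are hypothetical placeholders, not part of the original question.

```python
import torch
import torch.nn as nn


class SpectrumCNN(nn.Module):
    """Hypothetical CNN mapping one image to N spectrum values (one per frequency)."""

    def __init__(self, num_freqs: int, in_channels: int = 1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            # Fixed 4x4 output so the linear layer size does not depend on image size
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        # Final fully-connected layer: N outputs, one per frequency
        self.head = nn.Linear(32 * 4 * 4, num_freqs)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.head(x)


# Placeholder batch: 8 single-channel 64x64 images and their 100-point spectra
model = SpectrumCNN(num_freqs=100)
images = torch.randn(8, 1, 64, 64)
targets = torch.randn(8, 100)

preds = model(images)
loss = nn.MSELoss()(preds, targets)  # regression loss over all N frequencies
loss.backward()
print(preds.shape)  # torch.Size([8, 100])
```

Because the target is a vector of real values rather than a class index, a regression loss such as `nn.MSELoss` fits more naturally here than `nn.CrossEntropyLoss`.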
Thanks for the information.
The easiest way would be to create a custom `Dataset` and load the data and target as necessary.
I assume you have the input data stored as images and the targets as some kind of arrays?
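If the images and spectra already fit in memory as tensors, `torch.utils.data.TensorDataset` may be enough: it pairs tensors along their first dimension, so item `i` is `(images[i], spectra[i])`, and the `DataLoader` stacks the spectra into a `[batch_size, N]` target. The random tensors and sizes below are placeholders standing in for real data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

num_images, num_freqs = 10, 100

# Placeholder data: in practice, the preprocessed images and the
# transmission spectra (row i is the spectrum belonging to image i)
images = torch.randn(num_images, 1, 64, 64)
spectra = torch.randn(num_images, num_freqs)

# TensorDataset indexes every tensor along dim 0, keeping images and spectra aligned
dataset = TensorDataset(images, spectra)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

x_batch, y_batch = next(iter(loader))
print(x_batch.shape)  # torch.Size([4, 1, 64, 64])
print(y_batch.shape)  # torch.Size([4, 100])
```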
Here is a small dummy example of a custom `Dataset`:
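```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, image_paths, target_paths, transform=None):
        self.image_paths = image_paths    # list of image file paths
        self.target_paths = target_paths  # list of .npy files, same order as image_paths
        self.transform = transform        # e.g. torchvision.transforms.ToTensor()

    def __getitem__(self, index):
        x = Image.open(self.image_paths[index])
        # The target is a whole array (e.g. the spectrum over N frequencies), not a class index;
        # cast to float32 so it matches the dtype of typical model outputs
        y = torch.from_numpy(np.load(self.target_paths[index])).float()
        if self.transform:
            x = self.transform(x)  # convert the PIL image to a tensor before batching
        return x, y

    def __len__(self):
        return len(self.image_paths)
```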
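The class above expects two index-aligned lists of paths. One way to build them, assuming a hypothetical layout with `images/<name>.png` and a matching `targets/<name>.npy` holding each spectrum, is to pair files by their stem. The dummy files and the helper names `make_dummy_files` and `pair_paths` are illustrative only.

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image


def make_dummy_files(root: Path, num_samples: int = 3, num_freqs: int = 100):
    """Write placeholder data in a hypothetical layout:
    root/images/<name>.png and root/targets/<name>.npy"""
    (root / "images").mkdir(parents=True, exist_ok=True)
    (root / "targets").mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(0)
    for i in range(num_samples):
        name = f"sample_{i:03d}"
        pixels = rng.integers(0, 256, size=(64, 64), dtype=np.uint8)
        Image.fromarray(pixels).save(root / "images" / f"{name}.png")
        spectrum = rng.random(num_freqs).astype(np.float32)
        np.save(root / "targets" / f"{name}.npy", spectrum)


def pair_paths(root: Path):
    """Match each image with the target file that has the same stem."""
    image_paths = sorted((root / "images").glob("*.png"))
    target_paths = [root / "targets" / f"{p.stem}.npy" for p in image_paths]
    missing = [t for t in target_paths if not t.exists()]
    if missing:
        raise FileNotFoundError(f"missing targets: {missing}")
    return image_paths, target_paths


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    make_dummy_files(root)
    image_paths, target_paths = pair_paths(root)
    image_names = [p.name for p in image_paths]
    target_names = [t.name for t in target_paths]
    print(image_names)
    print(target_names)
```

With real data, the resulting lists can be passed directly as `MyDataset(image_paths, target_paths, transform=...)`.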
Thank you very much! This is very useful and solved my problem.