I am performing classification to identify which phenotype the person in the image belongs to.
There are eight classes in total.
My dataset is organized as follows
Images
Character_class (contains .txt files; each .txt file tells us which class the corresponding image belongs to). A label looks like this:
m_la01
No of images 800
No of labels corresponding to the images 800
Given my current dataset class, shown below, what changes should I make to my data loader so that I can train the classifier with a standard classification loss (cross-entropy)? Do I need to one-hot encode the label and return that as a tensor?
My custom dataset
class BlenderPoseDataset(Dataset):
    """Classification dataset pairing rendered images with class-label files.

    Each item is ``(image, label)``:

    * ``image`` -- float32 RGB tensor in channel-first layout, resized to
      ``image_shape`` and scaled to ``[0, 1]``.
    * ``label`` -- scalar ``torch.long`` tensor holding the class index
      (position of the class name in ``CLASSES``). ``nn.CrossEntropyLoss``
      consumes class indices directly, so no one-hot encoding is required.

    Args:
        paths: Paths of the ``.txt`` label files, one per image.
        batch_size: Stored on the instance; batching itself is done by the
            ``DataLoader`` wrapping this dataset.

    Raises:
        ValueError: If the number of images and label files differ.
    """

    # The position of a name in this tuple is its class index.
    CLASSES = ('f_af01', 'f_as01', 'f_ca01', 'f_la01',
               'm_af01', 'm_as01', 'm_ca01', 'm_la01')
    _CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASSES)}

    def __init__(self, paths, batch_size=16):
        super().__init__()
        self.img_dir = 'blender_data_1/images_blender/*'
        self.pose_files = paths
        self.batch_size = batch_size
        self.transforms = None
        self.image_shape = (224, 224)
        self.min_data = np.load('char_min_data.npy')
        self.max_data = np.load('char_max_data.npy')
        # One label per image. The previous nested loop produced every
        # image x label combination (N*N mostly mislabelled samples).
        self.data = self._pair_samples(glob.glob(self.img_dir), self.pose_files)

    @staticmethod
    def _pair_samples(img_paths, label_paths):
        """Return ``[img_path, label_path]`` pairs, matched by sorted order.

        NOTE(review): assumes image and label file names sort into the same
        order (e.g. ``0001.png`` <-> ``0001.txt``) -- confirm for your data.
        """
        img_paths = sorted(img_paths)
        label_paths = sorted(label_paths)
        if len(img_paths) != len(label_paths):
            raise ValueError(
                f'found {len(img_paths)} images but {len(label_paths)} label files'
            )
        return [[img, lbl] for img, lbl in zip(img_paths, label_paths)]

    def __len__(self):
        """Return the number of (image, label) samples."""
        return len(self.data)

    def image_process(self, rgb_img):
        """Convert an HxWxC image array to a CxHxW float32 array in [0, 1]."""
        rgb_img = np.transpose(rgb_img.astype('float32'), (2, 0, 1)) / 255.0
        return rgb_img

    def _read_label(self, label_path):
        """Read the class name stored in *label_path* and return its index.

        Raises:
            ValueError: If the file content is not one of ``CLASSES``.
        """
        with open(label_path, 'r') as f:
            # strip() drops the trailing newline so the name matches CLASSES.
            name = f.read().strip()
        try:
            return self._CLASS_TO_IDX[name]
        except KeyError:
            raise ValueError(
                f'unknown class label {name!r} in {label_path}'
            ) from None

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image_tensor, class_index_tensor)``.

        Raises:
            FileNotFoundError: If the image cannot be read.
            ValueError: If the label file holds an unknown class name.
        """
        img_path, label_path = self.data[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f'could not read image: {img_path}')
        # cv2 loads BGR; reversing the channel axis yields RGB.
        img = cv2.resize(img, self.image_shape)[:, :, ::-1].astype(np.float32)
        img = self.image_process(img)
        img = torch.from_numpy(np.ascontiguousarray(img)).float()
        if self.transforms is not None:
            # Applied to the loaded tensor -- assumes tensor-compatible transforms.
            img = self.transforms(img)
        label = torch.tensor(self._read_label(label_path), dtype=torch.long)
        return img, label
# Instantiate the custom dataset; the abstract torch `Dataset` base class
# takes no constructor arguments and provides no samples.
dataset = BlenderPoseDataset(labels_paths, batch_size=32)
I have also implemented my own one-hot encoding that converts the class label to a one-hot vector and returns it, which can later be used during training. Please let me know if that approach sounds reasonable.
# One-hot encode the class name stored in `file_name`.
# The position of a name in `classes` is its class index.
classes = ['f_af01', 'f_as01', 'f_ca01', 'f_la01',
           'm_af01', 'm_as01', 'm_ca01', 'm_la01']
label_arr = np.zeros(len(classes))
with open(file_name, 'r') as f:
    # readline() keeps the trailing newline; without strip() the comparison
    # against `classes` never matches and the vector stays all zeros.
    label = f.readline().strip()
print(label)
# An unrecognised label leaves `label_arr` all zeros, as before.
if label in classes:
    label_arr[classes.index(label)] = 1
# NOTE: torch.nn.CrossEntropyLoss accepts the integer class index directly
# (classes.index(label)), so the one-hot vector is optional for training.