因为读入的两张图片含有不同数量的object,导致当我设置batchsize大于1时,总会提示错误,后来我按照网上的教程自定义了collate_fn方法,但是出现新的问题,代码无法显示shape,如果可以的话,请审阅一下我的代码并提出您宝贵的建议。代码如下:
import torch
import xml.etree.ElementTree as ET
import os.path as osp
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms
import torch
import torch.utils.data as data
import numpy as np
import cv2
import xml # 标注是xml格式
try:
import xml.etree.cElementTree as ET # 解析xml的c语言版的模块
except ImportError:
import xml.etree.ElementTree as ET
VOC_CLASSES = {
# 记得加上背景类
“hat”,
“No_hat”,
}
把str映射为int
dict_classes = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
print(dict_classes[‘aeroplane’])
transforms = transforms.Compose([
transforms.Resize((640, 640)),
transforms.ToTensor()
])
class ReadVOC(data.Dataset):
def init(self, root, transform = None):
print(“reading voc…”)
self.root = root
self.transform = transforms
self.img_idx = []
self.ano_idx = []
self.bbox = []
self.obj_name = [] # 类别
train_txt_path = self.root + "/ImageSets/Main/train.txt" # train这个文件夹里面数量太少 换掉
self.img_path = self.root + "/JPEGImages/"
self.ano_path = self.root + "/Annotations/"
# 首先读取txt文件进行训练集图片索引
train_txt = open(train_txt_path)
lines = train_txt.readlines()
for line in lines:
name = line.strip().split()[0] # 0意思是按右边的分割,1意思是按左边分割
# print(name) # name is in str type
self.img_idx.append(self.img_path + name + '.jpg')
self.ano_idx.append(self.ano_path + name + '.xml') # 最好是在这直接解析出bbox
def __getitem__(self, item):
# print("getitem...")
# print(self.img_idx[item])
img = cv2.imread(self.img_idx[item])
# img = Image.open(self.img_idx[item]).convert('RGB')
# img = self.transform(img)
height, width, channels = img.shape
# img = torch.tensor(img)
targrts = ET.parse(self.ano_idx[item]) # .getroot() # 运行时解析 逻辑更加清晰
res = np.empty((0, 5)) # 标注输出
# res = []
# find all obj in xml
for obj in targrts.iter("object"): # 便利物体
name = obj.find('name').text.strip()
class_idx = dict_classes[name]
bbox = obj.find('bndbox')
pts = ['xmin', 'ymin', 'xmax', 'ymax']
obj_bbox = []
for i, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text) - 1
# cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height # scale height or width
obj_bbox.append(cur_pt)
obj_bbox.append(class_idx)
res = np.vstack((res, obj_bbox))
# 当前obj的所有bboxdf
# res.append(obj_bbox)
# res.append(class_idx)
img, res = self.data_trans(img, res)
return img, res
def __len__(self):
data_lenth = len(self.img_idx)
# print('data lenth is ', data_lenth)
return data_lenth
# 标注输入使用w h归一化的相对坐标
def data_trans(self, img_input, res_input):
# print("trans...")
goal_size = (640, 640)
# 在这时候,图像尺寸可以变化,只要目标不发生平移等等
img = cv2.resize(img_input, goal_size)
# pre-process input img
img = torch.from_numpy(img).permute(2, 0, 1).float()
# 把bbox转换成绝对坐标
# bbox = [bbox_input[0] * goal_size[0], bbox_input[1] * goal_size[1], bbox_input[2] * goal_size[0], bbox_input[3] * goal_size[1]]
# bbox = list(map(int, bbox))
# half_len = int(len(res_input)/2)
# print(half_len)
# print(len(bbox_input))
res = []
for i in range(len(res_input)):
res.append(torch.tensor(res_input[i]))
# res.append(res_input[2*i+1])
# return img, res
# res = torch.tensor(res_input[0])
return img, res
def collate_fn(batch):
# batch是一个列表,其中是一个一个的元组,每个元组是dataset中_getitem__的结果
batch1 = list(zip(*batch))
half_batch = int(len(batch)/2)
for i in range(half_batch):
labels = torch.tensor(batch1[2*i], dtype=torch.int32)
texts = batch1[2*i+1]
del batch
return labels, texts
def collate_fn1(batch):
images = []
reses = []
for img, res in batch:
images.append(img)
reses.append(res)
# images = np.array(images)
# images = np.array(torch.tensor(img) for img in images)
images = torch.from_numpy(np.array(img).astype(dtype=np.float32)).type(torch.FloatTensor) #这里有一个错误,但是我不知道怎么改
# bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
return images, reses
if name == “main”:
# ReadVOC(root=‘D:/python_project_model/learn_pttorch_4_myproject/VOC2028’)
train_ = ReadVOC(root=‘D:/python_project_model/learn_pttorch_4_myproject/VOC2028’)
train_dalo = DataLoader(train_, 2, collate_fn=collate_fn1)
for i, (img, target) in enumerate(train_dalo):
img, target = img, target
print(img.shape)
print(target)
# img, target = train_[0] # 返回第一张图像及box和对应的类别
# print(img.shape)
# print(target)
#
#
# # 这里简单做一下可视化
# # 由于opencv读入是矩阵,而img现在是tensor,因此,首先将tensor转成numpy.array
# img_ = (img.numpy()).astype(np.uint8).transpose(1, 2, 0) # 注意由于图像像素分布0-255,所以转成uint8
# print(img_.shape)
# cv2.imshow('test', img_)
# cv2.waitKey(0)