数据加载和处理教程
译者:@distant1219
校对者:@bringtree
在解决机器学习问题时, 我们需要付出很多努力来准备数据, 为了使代码更具可读性, PyTorch提供了许多工具来使数据加载变得简单易行. 在本教程中, 我们将要学习如何对 一个重要的数据集进行加载、预处理数据增强.
为了运行本教程, 请确保以下包已经安装:
scikit-image: 用来读取图片和图像变换pandas: 更方便地解析csv文件
from __future__ import print_function, divisionimport osimport torchimport pandas as pdfrom skimage import io, transformimport numpy as npimport matplotlib.pyplot as pltfrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transforms, utils# 忽略警告import warningswarnings.filterwarnings("ignore")plt.ion() # 交互模式
我们将要处理的是面部姿态的数据集, 其中一张人脸图像如下图一样被标注出来. 下面是一张脸的标注:
每张人脸图像上, 总共有68个不同的标注点被标记出来.
注解:
数据集中的csv文件记录着标注信息, 像下面这样:
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y0805personali01.jpg,27,83,27,98, ... 84,1341084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
让我们快速地读取CSV文件, 以(N,2)的数组形式获得标记点, 其中N表示标记点的个数.
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')n = 65img_name = landmarks_frame.iloc[n, 0]landmarks = landmarks_frame.iloc[n, 1:].as_matrix()landmarks = landmarks.astype('float').reshape(-1, 2)print('Image name: {}'.format(img_name))print('Landmarks shape: {}'.format(landmarks.shape))print('First 4 Landmarks: {}'.format(landmarks[:4]))
我们写一个函数来显示一张图片和它的标记点, 然后用这个函数来显示一个样本.
def show_landmarks(image, landmarks):"""显示带标记点的图片"""plt.imshow(image)plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')plt.pause(0.001) # 暂停一下, 使plots更新plt.figure()show_landmarks(io.imread(os.path.join('faces/', img_name)),landmarks)plt.show()
Dataset类
torch.utils.data.Dataset 是一个表示数据集的抽象类. 你自己的数据集一般应该继承Dataset, 并且重写下面的方法:
__len__使用len(dataset)可以返回数据集的大小__getitem__支持索引, 以便于使用dataset[i]可以 获取第i个样本(0索引)
我们为我们的人脸数据集创建一个数据集类. 我们使用 __init__方法来读取csv文件, 使用 __getitem__读取图片. 这样可以使内存高效利用, 因为我们并不需要在内存中一次存储所有图片, 而是按照需要读取.
数据集的一个样例是一个{'image': image, 'landmarks': landmarks}样的字典. 数据集类中有一个可选的参数 transform 这样可以对数据集做任何需要的处理. 我们将在下节看到 transform 的用处.
class FaceLandmarksDataset(Dataset):"""人脸标记数据集"""def __init__(self, csv_file, root_dir, transform=None):"""Args:csv_file (string): 带有标记点的csv文件路径root_dir (string): 图片路径transform (callable, 可选):可选择进行的图像变换"""self.landmarks_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.landmarks_frame)def __getitem__(self, idx):img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])image = io.imread(img_name)landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()landmarks = landmarks.astype('float').reshape(-1, 2)sample = {'image': image, 'landmarks': landmarks}if self.transform:sample = self.transform(sample)return sample
让我们实例化这个类 并且迭代所有的数据样本. 我们将打印前4个样本, 并显示它们的标记点.
face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',root_dir='faces/')fig = plt.figure()for i in range(len(face_dataset)):sample = face_dataset[i]print(i, sample['image'].shape, sample['landmarks'].shape)ax = plt.subplot(1, 4, i + 1)plt.tight_layout()ax.set_title('Sample #{}'.format(i))ax.axis('off')show_landmarks(**sample)if i == 3:plt.show()break
Transforms
我们可以看到上面输出的样例中的图像并不是同一尺寸的图片. 大多数神经网络需要输入 一个固定大小的图像, 因此我们需要写代码来处理. 让我们创建三个transform操作:
Rescale: 修改图片尺寸RandomCrop: 随机裁切图片, 这是数据增强的方法ToTensor: 将numpy格式的图片转为torch格式的图片(我们需要交换坐标轴)
我们不将它们写成简单的函数, 而是写成可以调用的类, 这样transform的参数不需要每次都传递 如果需要的话, 我们只需实现 __call__ 方法和__init__ 方法.之后我们可以像下面这 样使用transform:
tsfm = Transform(params)transformed_sample = tsfm(sample)
观察下面这些变换如何同时对图像和标记点改变的.
class Rescale(object):"""按照给定尺寸更改一个图像的尺寸Args:output_size (tuple or int): 要求输出的尺寸. 如果是个元组类型, 输出和output_size匹配. 如果时int类型,图片的短边和output_size匹配, 图片的长宽比保持不变."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size * h / w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size * w / helse:new_h, new_w = self.output_sizenew_h, new_w = int(new_h), int(new_w)img = transform.resize(image, (new_h, new_w))# 对于标记点, h和w需要交换位置, 因为对于图像, x和y分别时第1维和第0维landmarks = landmarks * [new_w / w, new_h / h]return {'image': img, 'landmarks': landmarks}class RandomCrop(object):"""随机裁剪图片Args:output_size (tuple or int): 期望输出的尺寸, 如果时int类型, 裁切成正方形."""def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = (output_size, output_size)else:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h - new_h)left = np.random.randint(0, w - new_w)image = image[top: top + new_h,left: left + new_w]landmarks = landmarks - [left, top]return {'image': image, 'landmarks': landmarks}class ToTensor(object):"""将ndarrays的样本转化为Tensors的样本"""def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']# 交换颜色通道, 因为# numpy图片: H x W x C# torch图片 : C X H X Wimage = image.transpose((2, 0, 1))return {'image': torch.from_numpy(image),'landmarks': torch.from_numpy(landmarks)}
Compose transforms
现在我们就将 transform 应用在一个样本上.
如果我们想将图片的短边变为256像素, 并且随后随机裁切成224像素的正方形. i.e, 我们可以组合Rescale和RandomCrop变换. torchvision.transforms.Compose 就是一个可以做这样一个组合的可调用的类.
scale = Rescale(256)crop = RandomCrop(128)composed = transforms.Compose([Rescale(256),RandomCrop(224)])# 对每个样本进行上面的每一个操作.fig = plt.figure()sample = face_dataset[65]for i, tsfrm in enumerate([scale, crop, composed]):transformed_sample = tsfrm(sample)ax = plt.subplot(1, 3, i + 1)plt.tight_layout()ax.set_title(type(tsfrm).__name__)show_landmarks(**transformed_sample)plt.show()
迭代整个数据集
让我们进行所有操作, 来创建结合了图像变换的数据集 总之, 每次迭代的数据:
- 从文件中读取图像
- 对所读的图像上应用变换
- 其中一个变换是随机的, 所以可以增强数据样本
我们可以像之前一样用 for i in range 循环从已创建的数据集中迭代
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',root_dir='faces/',transform=transforms.Compose([Rescale(256),RandomCrop(224),ToTensor()]))for i in range(len(transformed_dataset)):sample = transformed_dataset[i]print(i, sample['image'].size(), sample['landmarks'].size())if i == 3:break
然而我们用简单的for循环来迭代整个数据集会丢失很多特点, 特别地, 我们会丢失:
- 批读取数据
- 打乱数据顺序
- 使用
multiprocessing并行加载数据
torch.utils.data.DataLoader 是提供了所有上述特点的迭代器. 下面使用的参数应该很清晰. 其中一个有趣的参数是collate_fn. 你可以使用collate_fn来指定如何精确地读取一批的样本. 然而, 默认的collate在大部分的情况下都表现得很好
dataloader = DataLoader(transformed_dataset, batch_size=4,shuffle=True, num_workers=4)# 定义一个函数来查看某个 batch 的数据样本图片和标记点def show_landmarks_batch(sample_batched):"""显示指定 batch 的数据样本的图片和标记点"""images_batch, landmarks_batch = \sample_batched['image'], sample_batched['landmarks']batch_size = len(images_batch)im_size = images_batch.size(2)grid = utils.make_grid(images_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,landmarks_batch[i, :, 1].numpy(),s=10, marker='.', c='r')plt.title('Batch from dataloader')for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(),sample_batched['landmarks'].size())# 观察到第四批数据时停止if i_batch == 3:plt.figure()show_landmarks_batch(sample_batched)plt.axis('off')plt.ioff()plt.show()break
后记: torchvision
在这个教程中, 我们学习了如何写和使用数据集, 图像变换和dataloder. torchvision 提供了常用的数据集和图像变换, 或许你甚至不必写自定义的类和变换. 在torchvision中一个最经常用的数据集是ImageFolder. 它要求数据按下面的形式存放:
root/ants/xxx.pngroot/ants/xxy.jpegroot/ants/xxz.png...root/bees/123.jpgroot/bees/nsdf3.pngroot/bees/asd932_.png
‘ants’, ‘bees’ 等是图像的类标. 同样, PIL.Image 中出现的一般的图像变换像 RandomHorizontalFlip, Scale 也是可以使用的. 你可以像下面这样用这些函数来写dataloader:
import torchfrom torchvision import transforms, datasetsdata_transform = transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',transform=data_transform)dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4, shuffle=True,num_workers=4)
关于训练代码的例子, 请看 迁移学习教程.

