A Must-Have for Papers: One-Click Dataset Augmentation Before RT-DETR Training, with 9 Augmentation Methods and Synchronized Image and Label Expansion
Preface
Before training a network model, a large dataset is essential. Yet for special targets in special scenarios, large datasets are often hard to obtain and cannot meet the basic needs of model training. This article provides 9 data augmentation methods in a script that runs with a single click and generates the augmented images together with their corresponding label files.
1. The Importance of Data Augmentation
In computer vision model training, augmenting a self-built dataset matters for several reasons:
- Increasing the amount of data
  - Alleviating data scarcity: in many real-world applications, obtaining large amounts of high-quality annotated data is very difficult. Augmentation techniques such as rotation and flipping generate additional samples from a small original set, giving the model more data to learn from. For emerging research areas or niche applications, augmented data can partly make up for the shortage, so that training is not crippled by having too few examples to learn the underlying features and patterns.
  - Improving model performance: with enough training data, the model can learn more robust feature representations, and more data helps it fit complex decision boundaries.
- Strengthening generalization
  - Reducing overfitting: overfitting occurs when a model learns the training data too literally, memorizing even its noise and idiosyncrasies. By introducing variation, augmentation increases the diversity of the data.
  - Adapting to real-world variation: visual data in the wild is affected by many factors such as lighting changes, viewing angles, and occlusion; augmentation can simulate these variations.
2. Implementation of the 9 Augmentation Methods
The complete code is as follows:
# -*- coding: utf-8 -*-
"""
Created on 2023-04-01 9:08
@author: Fan yi ming
Func: data augmentation for object detection in YOLO format; the point is
      that the labels must be transformed together with the images.
Methods:
    1. Flips: horizontal and vertical, applied at random
    2. Random cropping and image scaling
    3. Color changes (brightness, contrast, saturation)
    4. Noise injection (Gaussian, salt, pepper)
Note: in each box row the class id is an int while the coordinates are
      floats; mind the types when saving.
Reference: https://github.com/REN-HT/Data-Augmentation/blob/main/data_augmentation.py
"""
import os
import random

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
from torchvision import transforms

ImageFile.LOAD_TRUNCATED_IMAGES = True
random.seed(0)


class DataAugmentationOnDetection:
    def __init__(self):
        super(DataAugmentationOnDetection, self).__init__()

    # Scale so the longer side equals target_size, keeping the aspect ratio.
    # Normalized YOLO boxes are invariant to resizing, so they pass through.
    def resize_keep_ratio(self, image, boxes, target_size):
        old_size = image.size[0:2]  # (width, height)
        ratio = min(float(target_size) / old_size[i] for i in range(len(old_size)))
        new_size = tuple(int(i * ratio) for i in old_size)
        return image.resize(new_size, Image.BILINEAR), boxes

    # Same as above, but only ever shrinks the image (ratio capped at 1).
    def resizeDown_keep_ratio(self, image, boxes, target_size):
        old_size = image.size[0:2]
        ratio = min(float(target_size) / old_size[i] for i in range(len(old_size)))
        ratio = min(ratio, 1)
        new_size = tuple(int(i * ratio) for i in old_size)
        return image.resize(new_size, Image.BILINEAR), boxes

    # Resize to a size x size square; normalized boxes are unchanged.
    def resize(self, img, boxes, size):
        return img.resize((size, size), Image.BILINEAR), boxes

    # Horizontal flip: mirror the image and map x_center -> 1 - x_center.
    def random_flip_horizon(self, img, boxes, h_rate=1):
        if np.random.random() < h_rate:
            transform = transforms.RandomHorizontalFlip(p=1)
            img = transform(img)
            if len(boxes) > 0:
                boxes[:, 1] = 1 - boxes[:, 1]
        return img, boxes

    # Vertical flip: mirror the image and map y_center -> 1 - y_center.
    def random_flip_vertical(self, img, boxes, v_rate=1):
        if np.random.random() < v_rate:
            transform = transforms.RandomVerticalFlip(p=1)
            img = transform(img)
            if len(boxes) > 0:
                boxes[:, 2] = 1 - boxes[:, 2]
        return img, boxes

    # Center-crop to a square of side min(w, h), then optionally resize to
    # target_size. Boxes are converted to pixel xyxy, shifted into the crop,
    # filtered and clipped, then converted back to normalized xywh.
    def center_crop(self, img, boxes, target_size=None):
        w, h = img.size
        size = min(w, h)
        if len(boxes) > 0:
            # Normalized xywh -> pixel xyxy.
            label = boxes[:, 0].reshape([-1, 1])
            x_, y_, w_, h_ = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
            x1 = (w * x_ - 0.5 * w * w_).reshape([-1, 1])
            y1 = (h * y_ - 0.5 * h * h_).reshape([-1, 1])
            x2 = (w * x_ + 0.5 * w * w_).reshape([-1, 1])
            y2 = (h * y_ + 0.5 * h * h_).reshape([-1, 1])
            boxes_xyxy = torch.cat([x1, y1, x2, y2], dim=1)
            # Shift into the coordinate system of the centered crop.
            if w > h:
                boxes_xyxy[:, [0, 2]] = boxes_xyxy[:, [0, 2]] - (w - h) / 2
            else:
                boxes_xyxy[:, [1, 3]] = boxes_xyxy[:, [1, 3]] - (h - w) / 2
            # Discard boxes that fall completely outside the crop.
            in_boundary = list(range(boxes_xyxy.shape[0]))
            for i in range(boxes_xyxy.shape[0]):
                if (boxes_xyxy[i, 0] < 0 and boxes_xyxy[i, 2] < 0) or \
                        (boxes_xyxy[i, 0] > size and boxes_xyxy[i, 2] > size):
                    in_boundary.remove(i)
                elif (boxes_xyxy[i, 1] < 0 and boxes_xyxy[i, 3] < 0) or \
                        (boxes_xyxy[i, 1] > size and boxes_xyxy[i, 3] > size):
                    in_boundary.remove(i)
            boxes_xyxy = boxes_xyxy[in_boundary]
            boxes = boxes_xyxy.clamp(min=0, max=size).reshape([-1, 4])
            label = label[in_boundary]
            # Pixel xyxy -> normalized xywh, relative to the crop.
            x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
            xc = ((x1 + x2) / (2 * size)).reshape([-1, 1])
            yc = ((y1 + y2) / (2 * size)).reshape([-1, 1])
            wc = ((x2 - x1) / size).reshape([-1, 1])
            hc = ((y2 - y1) / size).reshape([-1, 1])
            boxes = torch.cat([xc, yc, wc, hc], dim=1)
        transform = transforms.CenterCrop(size)
        img = transform(img)
        if target_size:
            img = img.resize((target_size, target_size), Image.BILINEAR)
        if len(boxes) > 0:
            return img, torch.cat([label.reshape([-1, 1]), boxes], dim=1)
        else:
            return img, boxes

    # ---- The methods below operate on a CHW float tensor in [0, 1]
    # ---- (e.g. the output of transforms.ToTensor()), not on a PIL image.

    # Brightness: add one random offset in [-u/255, u/255] to every pixel.
    def random_bright(self, img, u=120, p=1):
        if np.random.random() < p:
            alpha = np.random.uniform(-u, u) / 255
            img += alpha
            img = img.clamp(min=0.0, max=1.0)
        return img

    # Contrast: multiply every pixel by a random factor in [lower, upper].
    def random_contrast(self, img, lower=0.5, upper=1.5, p=1):
        if np.random.random() < p:
            alpha = np.random.uniform(lower, upper)
            img *= alpha
            img = img.clamp(min=0, max=1.0)
        return img

    # Saturation: scale channel index 1 (green in an RGB tensor) by a
    # random factor; a crude approximation of a saturation change.
    def random_saturation(self, img, lower=0.5, upper=1.5, p=1):
        if np.random.random() < p:
            alpha = np.random.uniform(lower, upper)
            img[1] = img[1] * alpha
            img[1] = img[1].clamp(min=0, max=1.0)
        return img

    # Gaussian noise: add N(mean, std^2) noise to every pixel.
    def add_gasuss_noise(self, img, mean=0, std=0.1):
        noise = torch.normal(mean, std, img.shape)
        img += noise
        img = img.clamp(min=0, max=1.0)
        return img

    # Salt noise: set a random 10%-30% of values to white (1.0).
    def add_salt_noise(self, img):
        noise = torch.rand(img.shape)
        alpha = np.random.random() / 5 + 0.7  # threshold in [0.7, 0.9]
        img[noise > alpha] = 1.0
        return img

    # Pepper noise: set a random 10%-30% of values to black (0).
    def add_pepper_noise(self, img):
        noise = torch.rand(img.shape)
        alpha = np.random.random() / 5 + 0.7  # threshold in [0.7, 0.9]
        img[noise > alpha] = 0
        return img


# Draw normalized-xywh boxes on the image with matplotlib (for inspection).
def plot_pics(img, boxes):
    plt.imshow(img)
    label_colors = [(213, 110, 89)]
    w, h = img.size
    for i in range(boxes.shape[0]):
        box = boxes[i, 1:]
        xc, yc, wc, hc = box
        x = w * xc - 0.5 * w * wc
        y = h * yc - 0.5 * h * hc
        box_w, box_h = w * wc, h * hc
        plt.gca().add_patch(plt.Rectangle(xy=(x, y), width=box_w, height=box_h,
                                          edgecolor=[c / 255 for c in label_colors[0]],
                                          fill=False, linewidth=2))
    plt.show()


# Collect the file names under image_path (assumed to sit directly in it).
def get_image_list(image_path):
    files_list = []
    for root, sub_dirs, files in os.walk(image_path):
        files_list.extend(files)
    return files_list


# Read the YOLO label file matching image_name; each row is
# [class, x_center, y_center, w, h]. Returns [] if missing or empty.
def get_label_file(label_path, image_name):
    fname = os.path.join(label_path, image_name[0: len(image_name) - 4] + ".txt")
    data2 = []
    if not os.path.exists(fname) or os.path.getsize(fname) == 0:
        return data2
    with open(fname, 'r', encoding='utf-8') as infile:
        for line in infile:
            data_line = line.strip("\n").split()
            data2.append([float(i) for i in data_line])
    return data2


# Save the augmented image and its label file, prefixing both file names.
def save_Yolo(img, boxes, save_path, prefix, image_name):
    os.makedirs(os.path.join(save_path, "images"), exist_ok=True)
    os.makedirs(os.path.join(save_path, "labels"), exist_ok=True)
    try:
        img.save(os.path.join(save_path, "images", prefix + image_name))
        with open(os.path.join(save_path, "labels",
                               prefix + image_name[0: len(image_name) - 4] + ".txt"),
                  'w', encoding="utf-8") as f:
            if len(boxes) > 0:
                for data in boxes:
                    # The class id is an int; the coordinates are floats.
                    str_in = ""
                    for i, a in enumerate(data):
                        if i == 0:
                            str_in += str(int(a))
                        else:
                            str_in += " " + str(float(a))
                    f.write(str_in + '\n')
    except Exception as e:
        print("ERROR:", image_name, "is bad.", e)


def runAugumentation(image_path, label_path, save_path):
    image_list = get_image_list(image_path)
    for image_name in image_list:
        print("dealing: " + image_name)
        # Convert to RGB so the 3-channel tensor transforms below apply.
        img = Image.open(os.path.join(image_path, image_name)).convert("RGB")
        boxes = get_label_file(label_path, image_name)
        boxes = torch.tensor(boxes)
        DAD = DataAugmentationOnDetection()

        # Geometric transforms on the PIL image: boxes change with the image.
        t_img, t_boxes = DAD.random_flip_horizon(img, boxes.clone())
        save_Yolo(t_img, t_boxes, save_path, prefix="fh_", image_name=image_name)
        t_img, t_boxes = DAD.random_flip_vertical(img, boxes.clone())
        save_Yolo(t_img, t_boxes, save_path, prefix="fv_", image_name=image_name)
        t_img, t_boxes = DAD.center_crop(img, boxes.clone(), 1024)
        save_Yolo(t_img, t_boxes, save_path, prefix="cc_", image_name=image_name)

        # Pixel-level transforms on the tensor: boxes stay unchanged.
        to_tensor = transforms.ToTensor()
        to_image = transforms.ToPILImage()
        img = to_tensor(img)
        t_img = DAD.random_bright(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rb_", image_name=image_name)
        t_img = DAD.random_contrast(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rc_", image_name=image_name)
        t_img = DAD.random_saturation(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="rs_", image_name=image_name)
        t_img = DAD.add_gasuss_noise(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="gn_", image_name=image_name)
        t_img = DAD.add_salt_noise(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="sn_", image_name=image_name)
        t_img = DAD.add_pepper_noise(img.clone())
        save_Yolo(to_image(t_img), boxes, save_path, prefix="pn_", image_name=image_name)
        print("end: " + image_name)


if __name__ == '__main__':
    image_path = r"figures/images"
    label_path = r"figures/labels"
    save_path = r"figures/results"
    runAugumentation(image_path, label_path, save_path)
Horizontal flip
Mirrors the image left to right. Because YOLO coordinates are normalized to [0, 1], the label update is simply x_center -> 1 - x_center; y_center, width, and height are unchanged.
def random_flip_horizon(self, img, boxes, h_rate=1):
    # Flip with probability h_rate and mirror the x_center of every box.
    if np.random.random() < h_rate:
        transform = transforms.RandomHorizontalFlip(p=1)
        img = transform(img)
        if len(boxes) > 0:
            boxes[:, 1] = 1 - boxes[:, 1]
    return img, boxes
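As a quick sanity check, here is a minimal sketch using the class above; the file name and box values are made-up examples, not from the article:
# Minimal sketch: confirm the label follows the flip.
from PIL import Image
import torch

DAD = DataAugmentationOnDetection()
img = Image.open("figures/images/0001.jpg")             # placeholder name
boxes = torch.tensor([[0.0, 0.30, 0.50, 0.20, 0.40]])   # class, xc, yc, w, h
t_img, t_boxes = DAD.random_flip_horizon(img, boxes.clone())
print(t_boxes[0, 1])  # xc becomes 1 - 0.30 = 0.70; yc, w, h stay the same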
Vertical flip
Mirrors the image top to bottom; the label update is y_center -> 1 - y_center, analogous to the horizontal case.
def random_flip_vertical(self, img, boxes, v_rate=1):
    # Flip with probability v_rate and mirror the y_center of every box.
    if np.random.random() < v_rate:
        transform = transforms.RandomVerticalFlip(p=1)
        img = transform(img)
        if len(boxes) > 0:
            boxes[:, 2] = 1 - boxes[:, 2]
    return img, boxes
Center crop
Crops the largest centered square (side min(w, h)). Boxes are first converted from normalized xywh to pixel xyxy, shifted into the crop's coordinate system, boxes falling entirely outside the crop are discarded, the rest are clipped, and the result is converted back to normalized xywh. An optional target_size resizes the cropped square afterwards.
def center_crop(self, img, boxes, target_size=None):
    w, h = img.size
    size = min(w, h)
    if len(boxes) > 0:
        # Normalized xywh -> pixel xyxy.
        label = boxes[:, 0].reshape([-1, 1])
        x_, y_, w_, h_ = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
        x1 = (w * x_ - 0.5 * w * w_).reshape([-1, 1])
        y1 = (h * y_ - 0.5 * h * h_).reshape([-1, 1])
        x2 = (w * x_ + 0.5 * w * w_).reshape([-1, 1])
        y2 = (h * y_ + 0.5 * h * h_).reshape([-1, 1])
        boxes_xyxy = torch.cat([x1, y1, x2, y2], dim=1)
        # Shift into the coordinate system of the centered crop.
        if w > h:
            boxes_xyxy[:, [0, 2]] = boxes_xyxy[:, [0, 2]] - (w - h) / 2
        else:
            boxes_xyxy[:, [1, 3]] = boxes_xyxy[:, [1, 3]] - (h - w) / 2
        # Discard boxes that fall completely outside the crop.
        in_boundary = list(range(boxes_xyxy.shape[0]))
        for i in range(boxes_xyxy.shape[0]):
            if (boxes_xyxy[i, 0] < 0 and boxes_xyxy[i, 2] < 0) or \
                    (boxes_xyxy[i, 0] > size and boxes_xyxy[i, 2] > size):
                in_boundary.remove(i)
            elif (boxes_xyxy[i, 1] < 0 and boxes_xyxy[i, 3] < 0) or \
                    (boxes_xyxy[i, 1] > size and boxes_xyxy[i, 3] > size):
                in_boundary.remove(i)
        boxes_xyxy = boxes_xyxy[in_boundary]
        boxes = boxes_xyxy.clamp(min=0, max=size).reshape([-1, 4])
        label = label[in_boundary]
        # Pixel xyxy -> normalized xywh, relative to the crop.
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        xc = ((x1 + x2) / (2 * size)).reshape([-1, 1])
        yc = ((y1 + y2) / (2 * size)).reshape([-1, 1])
        wc = ((x2 - x1) / size).reshape([-1, 1])
        hc = ((y2 - y1) / size).reshape([-1, 1])
        boxes = torch.cat([xc, yc, wc, hc], dim=1)
    transform = transforms.CenterCrop(size)
    img = transform(img)
    if target_size:
        img = img.resize((target_size, target_size), Image.BILINEAR)
    if len(boxes) > 0:
        return img, torch.cat([label.reshape([-1, 1]), boxes], dim=1)
    else:
        return img, boxes
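For intuition, a worked example with assumed numbers (not from the article): on a 1000x600 image, size = 600 and the crop trims 200 px from the left and right. A box (xc, yc, wc, hc) = (0.5, 0.5, 0.2, 0.3) maps to pixel corners (400, 210, 600, 390); shifting x by 200 gives (200, 210, 400, 390), and renormalizing by 600 yields (0.5, 0.5, 0.333, 0.3). The box keeps its pixel size but occupies a larger fraction of the now-smaller width, as expected.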
Brightness shift
Adds one random offset, drawn from [-u/255, u/255], to every pixel of the tensor, then clamps back to [0, 1]. The boxes are unaffected, so the labels are saved unchanged.
def random_bright(self, img, u=120, p=1):
    # With probability p, shift all pixels by a single random amount.
    if np.random.random() < p:
        alpha = np.random.uniform(-u, u) / 255
        img += alpha
        img = img.clamp(min=0.0, max=1.0)
    return img
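Note that random_bright and all of the pixel-level methods that follow expect a CHW float tensor in [0, 1] rather than a PIL image; the main script converts with torchvision on the way in and out, like so:
to_tensor = transforms.ToTensor()    # PIL image -> CHW float tensor in [0, 1]
to_image = transforms.ToPILImage()   # tensor -> PIL image, ready to save
t = to_tensor(img)
bright_img = to_image(DAD.random_bright(t.clone()))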
Contrast adjustment
Multiplies every pixel by a random factor in [lower, upper]: factors below 1 flatten the image, factors above 1 stretch it. The result is clamped to [0, 1].
def random_contrast(self, img, lower=0.5, upper=1.5, p=1):
    # With probability p, scale all pixels by a random factor.
    if np.random.random() < p:
        alpha = np.random.uniform(lower, upper)
        img *= alpha
        img = img.clamp(min=0, max=1.0)
    return img
Saturation adjustment
Scales only channel index 1 of the CHW tensor (the green channel of an RGB image) by a random factor. This is a crude approximation of a saturation change rather than a true HSV-space adjustment.
def random_saturation(self, img, lower=0.5, upper=1.5, p=1):
    # With probability p, scale the channel at index 1 by a random factor.
    if np.random.random() < p:
        alpha = np.random.uniform(lower, upper)
        img[1] = img[1] * alpha
        img[1] = img[1].clamp(min=0, max=1.0)
    return img
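If a perceptually truer saturation jitter is preferred, torchvision's ColorJitter can be applied to the PIL image before ToTensor. This is an alternative sketch, not part of the original script; pil_img stands for any PIL image:
jitter = transforms.ColorJitter(saturation=(0.5, 1.5))  # same range as above
jittered = jitter(pil_img)  # geometric layout unchanged, boxes unaffected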
Gaussian noise
Adds independent per-pixel noise drawn from N(mean, std^2) and clamps the result to [0, 1].
def add_gasuss_noise(self, img, mean=0, std=0.1):
    # Add per-pixel Gaussian noise, then clamp to the valid range.
    noise = torch.normal(mean, std, img.shape)
    img += noise
    img = img.clamp(min=0, max=1.0)
    return img
Salt noise (white dots)
Draws a uniform random field over the image, picks a threshold alpha in [0.7, 0.9], and sets every value above the threshold to 1.0 (white), so roughly 10%-30% of values are overwritten.
def add_salt_noise(self, img):
    # Set a random 10%-30% of values to white.
    noise = torch.rand(img.shape)
    alpha = np.random.random() / 5 + 0.7  # threshold in [0.7, 0.9]
    img[noise > alpha] = 1.0
    return img
Pepper noise (black dots)
Identical to salt noise, except the selected values are set to 0 (black) instead of 1.0.
def add_pepper_noise(self, img):
    # Set a random 10%-30% of values to black.
    noise = torch.rand(img.shape)
    alpha = np.random.random() / 5 + 0.7  # threshold in [0.7, 0.9]
    img[noise > alpha] = 0
    return img
3. How to Use the Script
Create a new Python file in your RT-DETR project directory, copy the code above into it, and modify only image_path, label_path, and save_path in the main function.
if __name__ == '__main__':
    image_path = r"figures/images"
    label_path = r"figures/labels"
    save_path = r"figures/results"
- image_path: the directory containing the images to augment.
- label_path: the directory containing the corresponding label files.
- save_path: the directory where the augmented images and labels will be saved.
For example, here only a single image and its corresponding label file are used.
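For reference, the script assumes a layout like the following (file names and label values here are placeholders). Each label file is in YOLO format: one object per line as class x_center y_center width height, all coordinates normalized to [0, 1]:
figures/
├── images/
│   └── 0001.jpg
├── labels/
│   └── 0001.txt        # e.g. 0 0.512 0.430 0.240 0.310
└── results/            # created automatically by save_Yolo
    ├── images/
    └── labels/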
4. Augmentation Results
After running the script, 9 augmented images and their corresponding label files are generated.
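Each output file carries the prefix that runAugumentation assigns to its method (the sample name 0001.jpg is a placeholder):
- fh_ horizontal flip, fv_ vertical flip, cc_ center crop
- rb_ brightness, rc_ contrast, rs_ saturation
- gn_ Gaussian noise, sn_ salt noise, pn_ pepper noise
So fh_0001.jpg under results/images has a matching fh_0001.txt under results/labels, and likewise for the other eight prefixes.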