RT-DETR训练前的准备,将数据集划分成训练集、测试集验证集(附完整脚本及使用说明)
前言
在计算机视觉领域,当我们着手训练一个模型时, 数据集的处理是至关重要的第一步 。
其中,将数据集划分为 训练集 、 验证集 和 测试集 更是一项基础性且关键的操作。这篇博客分析了这一划分背后的原因、其重要意义以及如何通过代码实现,帮助大家更好地理解和运用这一技巧, 保证模型训练的效果与可靠性 。
一、为什么要对数据集进行划分
在机器学习和计算机视觉任务中,模型的目标是学习数据中的模式与规律,进而对未知数据进行准确预测。但如果我们直接将所有可用数据一股脑用于训练,模型会记住这些数据的所有细节,包括其中的噪声与异常值,这就导致模型在面对新的、真实世界中的数据时表现不佳,也就是出现 过拟合 现象。
通过将数据集划分,我们能够模拟出模型在实际应用中会遇到的不同场景。
- 训练集 用于让模型学习数据特征,就像是学生在课堂上学习知识;
- 验证集 则充当阶段性考核,在模型训练过程中,用来评估模型当前的性能,及时发现模型是否在训练集上过拟合,调整训练方向;
- 测试集 只有在模型训练完成后,用其评估模型的泛化能力,即模型对全新、从未见过的数据的处理能力,以此判断模型能否真正用于实际场景。
二、划分的作用
- 准确评估模型性能 :验证集给予我们在训练过程中调整模型的依据,避免盲目训练。测试集提供无偏差的模型性能衡量,让我们知晓模型在真实应用中的潜力。
例如在图像分类任务中,验证集能帮我们快速判断当前模型是否对训练图像过度学习,测试集则能告诉我们模型对新采集的图像分类准确率如何。
- 提高模型可靠性 :合理划分数据集促使模型学习到更具普遍性的特征,而非特定数据的特殊性质。这使得模型在不同环境、不同数据分布下都有相对稳定的表现,增强了模型的实用性与可靠性,减少因数据偏差带来的错误决策风险。
三、实现的代码
# 将图片和标注数据按比例切分为 训练集和测试集
import shutil
import random
import os
# 原始路径
image_original_path = "data/mydata/newImage/"
label_original_path = "data/mydata/newLabel/"
cur_path = os.getcwd()
# 训练集路径
train_image_path = os.path.join(cur_path, "datasets/images/train/")
train_label_path = os.path.join(cur_path, "datasets/labels/train/")
# 验证集路径
val_image_path = os.path.join(cur_path, "datasets/images/val/")
val_label_path = os.path.join(cur_path, "datasets/labels/val/")
# 测试集路径
test_image_path = os.path.join(cur_path, "datasets/images/test/")
test_label_path = os.path.join(cur_path, "datasets/labels/test/")
# 训练集目录
list_train = os.path.join(cur_path, "datasets/train.txt")
list_val = os.path.join(cur_path, "datasets/val.txt")
list_test = os.path.join(cur_path, "datasets/test.txt")
train_percent = 0.8
val_percent = 0.1
test_percent = 0.1
def del_file(path):
for i in os.listdir(path):
file_data = path + "\\" + i
os.remove(file_data)
def mkdir():
if not os.path.exists(train_image_path):
os.makedirs(train_image_path)
else:
del_file(train_image_path)
if not os.path.exists(train_label_path):
os.makedirs(train_label_path)
else:
del_file(train_label_path)
if not os.path.exists(val_image_path):
os.makedirs(val_image_path)
else:
del_file(val_image_path)
if not os.path.exists(val_label_path):
os.makedirs(val_label_path)
else:
del_file(val_label_path)
if not os.path.exists(test_image_path):
os.makedirs(test_image_path)
else:
del_file(test_image_path)
if not os.path.exists(test_label_path):
os.makedirs(test_label_path)
else:
del_file(test_label_path)
def clearfile():
if os.path.exists(list_train):
os.remove(list_train)
if os.path.exists(list_val):
os.remove(list_val)
if os.path.exists(list_test):
os.remove(list_test)
def main():
mkdir()
clearfile()
file_train = open(list_train, 'w')
file_val = open(list_val, 'w')
file_test = open(list_test, 'w')
total_txt = os.listdir(label_original_path)
num_txt = len(total_txt)
list_all_txt = range(num_txt)
num_train = int(num_txt * train_percent)
num_val = int(num_txt * val_percent)
num_test = num_txt - num_train - num_val
train = random.sample(list_all_txt, num_train)
# train从list_all_txt取出num_train个元素
# 所以list_all_txt列表只剩下了这些元素
val_test = [i for i in list_all_txt if not i in train]
# 再从val_test取出num_val个元素,val_test剩下的元素就是test
val = random.sample(val_test, num_val)
print("训练集数目:{}, 验证集数目:{}, 测试集数目:{}".format(len(train), len(val), len(val_test) - len(val)))
for i in list_all_txt:
name = total_txt[i][:-4]
srcImage = image_original_path + name + '.jpg'
srcLabel = label_original_path + name + ".txt"
if i in train:
dst_train_Image = train_image_path + name + '.jpg'
dst_train_Label = train_label_path + name + '.txt'
shutil.copyfile(srcImage, dst_train_Image)
shutil.copyfile(srcLabel, dst_train_Label)
file_train.write(dst_train_Image + '\n')
elif i in val:
dst_val_Image = val_image_path + name + '.jpg'
dst_val_Label = val_label_path + name + '.txt'
shutil.copyfile(srcImage, dst_val_Image)
shutil.copyfile(srcLabel, dst_val_Label)
file_val.write(dst_val_Image + '\n')
else:
dst_test_Image = test_image_path + name + '.jpg'
dst_test_Label = test_label_path + name + '.txt'
shutil.copyfile(srcImage, dst_test_Image)
shutil.copyfile(srcLabel, dst_test_Label)
file_test.write(dst_test_Image + '\n')
file_train.close()
file_val.close()
file_test.close()
if __name__ == "__main__":
main()
四、使用说明
在此处指定自己的数据集路径,包含 图像 和 标签
image_original_path = "data/mydata/newImage/"
label_original_path = "data/mydata/newLabel/"
修改完成后,运行即可,便会在项目的根目录下生成
datasets
文件夹,文件夹内部的文件如下,
此时图像和标签均已划分完成
,只需要配置相应的文件即可开始训练。
五、相关疑问
对于数据量比较少的自建数据集,是先做数据扩充再进行数据划分?
建议先进行数据划分后再进行数据扩充,否则可能会影响其鲁棒性。
有些公开数据集已经划分好了训练集和验证集,我是否还需要设置一个测试集?
针对公开的已划分好的数据集,可以直接使用。针对自建的数据集还是建议划分成三类。
我在新项目中重新划分数据集进行对比试验时,是否需要保证,训练集,验证集和测试集中的数据一致?如何保证?
需要的,只需要将原本以划分好的数据集拷贝过来即可。