1. 实战Kaggle比赛：图像分类（CIFAR-10）
① 比赛的网址是 https://www.kaggle.com/c/cifar-10
In [1]:
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
In [2]:
# A small-scale sample of the Kaggle CIFAR-10 data: the first 1000 training
# images plus 5 random test images. DATA_HUB entries are (archive URL, SHA-1
# hash that d2l uses to verify the download).
# With `demo = False` the full dataset is expected to already be extracted at
# '../data/cifar-10' (nothing is downloaded in that case).
d2l.DATA_HUB['cifar10_tiny'] = (
    d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
    '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd',
)
demo = True
data_dir = d2l.download_extract('cifar10_tiny') if demo else '../data/cifar-10'
In [3]:
# Organize the dataset.
def read_csv_labels(fname):
    """Read an ``id,label`` CSV file and return a dict mapping id -> label.

    The first line is treated as a header and skipped. Blank (or
    whitespace-only) lines, such as an extra newline at the end of the file,
    are ignored instead of causing an unpacking error.

    Args:
        fname: path to the CSV file (e.g. 'trainLabels.csv').

    Returns:
        dict[str, str]: image id (as a string) -> class name.
    """
    with open(fname, 'r') as f:
        lines = f.readlines()[1:]  # drop the header row
    # One "id,label" pair per non-empty line.
    tokens = [l.rstrip().split(',') for l in lines if l.strip()]
    return {name: label for name, label in tokens}
# Build the image-id -> class-name mapping from the training CSV and display it.
labels = read_csv_labels(
    fname=os.path.join(data_dir, 'trainLabels.csv'))
labels
Out[3]:
{'1': 'frog',
'2': 'truck',
'3': 'truck',
'4': 'deer',
'5': 'automobile',
'6': 'automobile',
'7': 'bird',
'8': 'horse',
'9': 'ship',
'10': 'cat',
'11': 'deer',
'12': 'horse',
'13': 'horse',
'14': 'bird',
'15': 'truck',
'16': 'truck',
'17': 'truck',
'18': 'cat',
'19': 'bird',
'20': 'frog',
'21': 'deer',
'22': 'cat',
'23': 'frog',
'24': 'frog',
'25': 'bird',
'26': 'frog',
'27': 'cat',
'28': 'dog',
'29': 'deer',
'30': 'airplane',
'31': 'airplane',
'32': 'truck',
'33': 'automobile',
'34': 'cat',
'35': 'deer',
'36': 'airplane',
'37': 'cat',
'38': 'horse',
'39': 'cat',
'40': 'cat',
'41': 'dog',
'42': 'bird',
'43': 'bird',
'44': 'horse',
'45': 'automobile',
'46': 'automobile',
'47': 'automobile',
'48': 'bird',
'49': 'bird',
'50': 'airplane',
'51': 'truck',
'52': 'dog',
'53': 'horse',
'54': 'truck',
'55': 'bird',
'56': 'bird',
'57': 'dog',
'58': 'bird',
'59': 'deer',
'60': 'cat',
'61': 'automobile',
'62': 'automobile',
'63': 'ship',
'64': 'bird',
'65': 'automobile',
'66': 'automobile',
'67': 'deer',
'68': 'truck',
'69': 'horse',
'70': 'ship',
'71': 'dog',
'72': 'truck',
'73': 'frog',
'74': 'horse',
'75': 'cat',
'76': 'automobile',
'77': 'truck',
'78': 'airplane',
'79': 'cat',
'80': 'automobile',
'81': 'cat',
'82': 'dog',
'83': 'deer',
'84': 'dog',
'85': 'horse',
'86': 'horse',
'87': 'deer',
'88': 'horse',
'89': 'truck',
'90': 'deer',
'91': 'bird',
'92': 'cat',
'93': 'ship',
'94': 'airplane',
'95': 'automobile',
'96': 'frog',
'97': 'automobile',
'98': 'automobile',
'99': 'deer',
'100': 'automobile',
'101': 'ship',
'102': 'cat',
'103': 'truck',
'104': 'frog',
'105': 'frog',
'106': 'automobile',
'107': 'ship',
'108': 'dog',
'109': 'bird',
'110': 'truck',
'111': 'truck',
'112': 'ship',
'113': 'automobile',
'114': 'horse',
'115': 'horse',
'116': 'airplane',
'117': 'airplane',
'118': 'frog',
'119': 'truck',
'120': 'automobile',
'121': 'bird',
'122': 'bird',
'123': 'truck',
'124': 'bird',
'125': 'frog',
'126': 'frog',
'127': 'automobile',
'128': 'truck',
'129': 'dog',
'130': 'airplane',
'131': 'deer',
'132': 'horse',
'133': 'frog',
'134': 'horse',
'135': 'automobile',
'136': 'ship',
'137': 'automobile',
'138': 'automobile',
'139': 'bird',
'140': 'ship',
'141': 'automobile',
'142': 'cat',
'143': 'cat',
'144': 'frog',
'145': 'bird',
'146': 'deer',
'147': 'truck',
'148': 'truck',
'149': 'dog',
'150': 'deer',
'151': 'cat',
'152': 'frog',
'153': 'horse',
'154': 'deer',
'155': 'frog',
'156': 'ship',
'157': 'dog',
'158': 'dog',
'159': 'deer',
'160': 'cat',
'161': 'automobile',
'162': 'ship',
'163': 'deer',
'164': 'horse',
'165': 'frog',
'166': 'airplane',
'167': 'truck',
'168': 'dog',
'169': 'automobile',
'170': 'cat',
'171': 'ship',
'172': 'bird',
'173': 'horse',
'174': 'dog',
'175': 'cat',
'176': 'deer',
'177': 'automobile',
'178': 'dog',
'179': 'horse',
'180': 'airplane',
'181': 'deer',
'182': 'horse',
'183': 'dog',
'184': 'dog',
'185': 'automobile',
'186': 'airplane',
'187': 'truck',
'188': 'frog',
'189': 'truck',
'190': 'airplane',
'191': 'ship',
'192': 'horse',
'193': 'ship',
'194': 'ship',
'195': 'bird',
'196': 'dog',
'197': 'bird',
'198': 'cat',
'199': 'dog',
'200': 'airplane',
'201': 'frog',
'202': 'automobile',
'203': 'truck',
'204': 'cat',
'205': 'frog',
'206': 'truck',
'207': 'automobile',
'208': 'cat',
'209': 'truck',
'210': 'frog',
'211': 'frog',
'212': 'horse',
'213': 'automobile',
'214': 'airplane',
'215': 'truck',
'216': 'dog',
'217': 'ship',
'218': 'dog',
'219': 'bird',
'220': 'truck',
'221': 'airplane',
'222': 'ship',
'223': 'ship',
'224': 'airplane',
'225': 'frog',
'226': 'truck',
'227': 'automobile',
'228': 'automobile',
'229': 'frog',
'230': 'cat',
'231': 'horse',
'232': 'frog',
'233': 'frog',
'234': 'airplane',
'235': 'frog',
'236': 'frog',
'237': 'automobile',
'238': 'horse',
'239': 'automobile',
'240': 'dog',
'241': 'ship',
'242': 'cat',
'243': 'frog',
'244': 'frog',
'245': 'ship',
'246': 'frog',
'247': 'ship',
'248': 'deer',
'249': 'frog',
'250': 'frog',
'251': 'automobile',
'252': 'cat',
'253': 'ship',
'254': 'cat',
'255': 'deer',
'256': 'automobile',
'257': 'horse',
'258': 'automobile',
'259': 'cat',
'260': 'ship',
'261': 'dog',
'262': 'automobile',
'263': 'automobile',
'264': 'deer',
'265': 'airplane',
'266': 'truck',
'267': 'cat',
'268': 'horse',
'269': 'deer',
'270': 'truck',
'271': 'truck',
'272': 'bird',
'273': 'deer',
'274': 'truck',
'275': 'truck',
'276': 'automobile',
'277': 'airplane',
'278': 'dog',
'279': 'truck',
'280': 'airplane',
'281': 'ship',
'282': 'bird',
'283': 'automobile',
'284': 'bird',
'285': 'airplane',
'286': 'dog',
'287': 'frog',
'288': 'cat',
'289': 'bird',
'290': 'horse',
'291': 'ship',
'292': 'ship',
'293': 'frog',
'294': 'airplane',
'295': 'horse',
'296': 'truck',
'297': 'deer',
'298': 'dog',
'299': 'frog',
'300': 'deer',
'301': 'bird',
'302': 'automobile',
'303': 'automobile',
'304': 'bird',
'305': 'automobile',
'306': 'dog',
'307': 'truck',
'308': 'truck',
'309': 'airplane',
'310': 'ship',
'311': 'deer',
'312': 'automobile',
'313': 'automobile',
'314': 'frog',
'315': 'cat',
'316': 'cat',
'317': 'truck',
'318': 'airplane',
'319': 'horse',
'320': 'truck',
'321': 'horse',
'322': 'horse',
'323': 'truck',
'324': 'automobile',
'325': 'dog',
'326': 'automobile',
'327': 'frog',
'328': 'frog',
'329': 'ship',
'330': 'horse',
'331': 'automobile',
'332': 'cat',
'333': 'airplane',
'334': 'cat',
'335': 'cat',
'336': 'bird',
'337': 'deer',
'338': 'dog',
'339': 'horse',
'340': 'dog',
'341': 'truck',
'342': 'airplane',
'343': 'cat',
'344': 'deer',
'345': 'airplane',
'346': 'deer',
'347': 'deer',
'348': 'frog',
'349': 'airplane',
'350': 'airplane',
'351': 'frog',
'352': 'frog',
'353': 'airplane',
'354': 'ship',
'355': 'automobile',
'356': 'frog',
'357': 'bird',
'358': 'truck',
'359': 'bird',
'360': 'dog',
'361': 'truck',
'362': 'frog',
'363': 'horse',
'364': 'deer',
'365': 'automobile',
'366': 'ship',
'367': 'horse',
'368': 'cat',
'369': 'frog',
'370': 'truck',
'371': 'cat',
'372': 'airplane',
'373': 'deer',
'374': 'airplane',
'375': 'dog',
'376': 'automobile',
'377': 'airplane',
'378': 'cat',
'379': 'deer',
'380': 'ship',
'381': 'dog',
'382': 'deer',
'383': 'horse',
'384': 'bird',
'385': 'cat',
'386': 'truck',
'387': 'horse',
'388': 'frog',
'389': 'horse',
'390': 'automobile',
'391': 'deer',
'392': 'horse',
'393': 'airplane',
'394': 'automobile',
'395': 'horse',
'396': 'cat',
'397': 'automobile',
'398': 'ship',
'399': 'deer',
'400': 'deer',
'401': 'bird',
'402': 'airplane',
'403': 'bird',
'404': 'bird',
'405': 'airplane',
'406': 'airplane',
'407': 'truck',
'408': 'airplane',
'409': 'truck',
'410': 'frog',
'411': 'ship',
'412': 'bird',
'413': 'horse',
'414': 'horse',
'415': 'deer',
'416': 'airplane',
'417': 'cat',
'418': 'airplane',
'419': 'ship',
'420': 'truck',
'421': 'deer',
'422': 'bird',
'423': 'horse',
'424': 'bird',
'425': 'dog',
'426': 'bird',
'427': 'dog',
'428': 'automobile',
'429': 'truck',
'430': 'deer',
'431': 'ship',
'432': 'dog',
'433': 'automobile',
'434': 'horse',
'435': 'deer',
'436': 'deer',
'437': 'airplane',
'438': 'frog',
'439': 'truck',
'440': 'airplane',
'441': 'horse',
'442': 'ship',
'443': 'ship',
'444': 'truck',
'445': 'truck',
'446': 'cat',
'447': 'cat',
'448': 'deer',
'449': 'airplane',
'450': 'deer',
'451': 'dog',
'452': 'frog',
'453': 'frog',
'454': 'airplane',
'455': 'automobile',
'456': 'airplane',
'457': 'ship',
'458': 'airplane',
'459': 'deer',
'460': 'ship',
'461': 'ship',
'462': 'automobile',
'463': 'dog',
'464': 'bird',
'465': 'frog',
'466': 'ship',
'467': 'automobile',
'468': 'airplane',
'469': 'airplane',
'470': 'horse',
'471': 'horse',
'472': 'dog',
'473': 'truck',
'474': 'frog',
'475': 'bird',
'476': 'ship',
'477': 'cat',
'478': 'deer',
'479': 'horse',
'480': 'cat',
'481': 'truck',
'482': 'airplane',
'483': 'automobile',
'484': 'bird',
'485': 'deer',
'486': 'ship',
'487': 'automobile',
'488': 'ship',
'489': 'frog',
'490': 'deer',
'491': 'deer',
'492': 'dog',
'493': 'horse',
'494': 'automobile',
'495': 'cat',
'496': 'truck',
'497': 'ship',
'498': 'airplane',
'499': 'automobile',
'500': 'horse',
'501': 'dog',
'502': 'ship',
'503': 'bird',
'504': 'ship',
'505': 'airplane',
'506': 'deer',
'507': 'automobile',
'508': 'ship',
'509': 'truck',
'510': 'ship',
'511': 'bird',
'512': 'truck',
'513': 'truck',
'514': 'bird',
'515': 'horse',
'516': 'dog',
'517': 'horse',
'518': 'cat',
'519': 'ship',
'520': 'ship',
'521': 'deer',
'522': 'deer',
'523': 'bird',
'524': 'horse',
'525': 'automobile',
'526': 'frog',
'527': 'deer',
'528': 'airplane',
'529': 'deer',
'530': 'frog',
'531': 'truck',
'532': 'horse',
'533': 'frog',
'534': 'bird',
'535': 'dog',
'536': 'dog',
'537': 'automobile',
'538': 'horse',
'539': 'bird',
'540': 'bird',
'541': 'bird',
'542': 'truck',
'543': 'dog',
'544': 'deer',
'545': 'bird',
'546': 'horse',
'547': 'ship',
'548': 'automobile',
'549': 'cat',
'550': 'deer',
'551': 'cat',
'552': 'horse',
'553': 'frog',
'554': 'truck',
'555': 'ship',
'556': 'airplane',
'557': 'frog',
'558': 'airplane',
'559': 'bird',
'560': 'bird',
'561': 'bird',
'562': 'automobile',
'563': 'ship',
'564': 'deer',
'565': 'airplane',
'566': 'automobile',
'567': 'ship',
'568': 'ship',
'569': 'automobile',
'570': 'dog',
'571': 'horse',
'572': 'frog',
'573': 'deer',
'574': 'dog',
'575': 'ship',
'576': 'horse',
'577': 'automobile',
'578': 'truck',
'579': 'automobile',
'580': 'truck',
'581': 'ship',
'582': 'deer',
'583': 'horse',
'584': 'cat',
'585': 'ship',
'586': 'ship',
'587': 'bird',
'588': 'frog',
'589': 'frog',
'590': 'horse',
'591': 'automobile',
'592': 'frog',
'593': 'ship',
'594': 'automobile',
'595': 'truck',
'596': 'horse',
'597': 'ship',
'598': 'cat',
'599': 'airplane',
'600': 'automobile',
'601': 'airplane',
'602': 'ship',
'603': 'ship',
'604': 'cat',
'605': 'airplane',
'606': 'airplane',
'607': 'automobile',
'608': 'dog',
'609': 'airplane',
'610': 'ship',
'611': 'ship',
'612': 'horse',
'613': 'truck',
'614': 'truck',
'615': 'airplane',
'616': 'truck',
'617': 'deer',
'618': 'automobile',
'619': 'cat',
'620': 'frog',
'621': 'frog',
'622': 'deer',
'623': 'deer',
'624': 'horse',
'625': 'dog',
'626': 'frog',
'627': 'airplane',
'628': 'ship',
'629': 'airplane',
'630': 'cat',
'631': 'bird',
'632': 'ship',
'633': 'deer',
'634': 'frog',
'635': 'truck',
'636': 'truck',
'637': 'horse',
'638': 'airplane',
'639': 'cat',
'640': 'cat',
'641': 'frog',
'642': 'horse',
'643': 'deer',
'644': 'truck',
'645': 'automobile',
'646': 'frog',
'647': 'bird',
'648': 'horse',
'649': 'bird',
'650': 'bird',
'651': 'airplane',
'652': 'frog',
'653': 'horse',
'654': 'dog',
'655': 'horse',
'656': 'frog',
'657': 'ship',
'658': 'truck',
'659': 'airplane',
'660': 'truck',
'661': 'deer',
'662': 'deer',
'663': 'horse',
'664': 'airplane',
'665': 'truck',
'666': 'deer',
'667': 'truck',
'668': 'frog',
'669': 'truck',
'670': 'deer',
'671': 'dog',
'672': 'horse',
'673': 'truck',
'674': 'bird',
'675': 'deer',
'676': 'dog',
'677': 'automobile',
'678': 'deer',
'679': 'cat',
'680': 'truck',
'681': 'frog',
'682': 'dog',
'683': 'frog',
'684': 'truck',
'685': 'cat',
'686': 'cat',
'687': 'dog',
'688': 'airplane',
'689': 'horse',
'690': 'bird',
'691': 'automobile',
'692': 'cat',
'693': 'frog',
'694': 'deer',
'695': 'airplane',
'696': 'airplane',
'697': 'bird',
'698': 'dog',
'699': 'airplane',
'700': 'automobile',
'701': 'airplane',
'702': 'bird',
'703': 'cat',
'704': 'truck',
'705': 'ship',
'706': 'deer',
'707': 'truck',
'708': 'ship',
'709': 'airplane',
'710': 'bird',
'711': 'frog',
'712': 'deer',
'713': 'deer',
'714': 'airplane',
'715': 'automobile',
'716': 'ship',
'717': 'ship',
'718': 'cat',
'719': 'frog',
'720': 'truck',
'721': 'frog',
'722': 'frog',
'723': 'horse',
'724': 'ship',
'725': 'bird',
'726': 'deer',
'727': 'dog',
'728': 'horse',
'729': 'frog',
'730': 'dog',
'731': 'cat',
'732': 'airplane',
'733': 'dog',
'734': 'airplane',
'735': 'dog',
'736': 'airplane',
'737': 'ship',
'738': 'bird',
'739': 'frog',
'740': 'horse',
'741': 'cat',
'742': 'ship',
'743': 'bird',
'744': 'automobile',
'745': 'horse',
'746': 'frog',
'747': 'horse',
'748': 'automobile',
'749': 'airplane',
'750': 'truck',
'751': 'dog',
'752': 'dog',
'753': 'airplane',
'754': 'automobile',
'755': 'horse',
'756': 'frog',
'757': 'truck',
'758': 'airplane',
'759': 'deer',
'760': 'horse',
'761': 'horse',
'762': 'automobile',
'763': 'dog',
'764': 'truck',
'765': 'deer',
'766': 'airplane',
'767': 'ship',
'768': 'dog',
'769': 'truck',
'770': 'truck',
'771': 'frog',
'772': 'horse',
'773': 'automobile',
'774': 'ship',
'775': 'cat',
'776': 'bird',
'777': 'cat',
'778': 'ship',
'779': 'bird',
'780': 'bird',
'781': 'deer',
'782': 'frog',
'783': 'airplane',
'784': 'airplane',
'785': 'dog',
'786': 'cat',
'787': 'ship',
'788': 'bird',
'789': 'cat',
'790': 'horse',
'791': 'bird',
'792': 'truck',
'793': 'cat',
'794': 'ship',
'795': 'horse',
'796': 'ship',
'797': 'bird',
'798': 'horse',
'799': 'truck',
'800': 'airplane',
'801': 'bird',
'802': 'cat',
'803': 'bird',
'804': 'bird',
'805': 'bird',
'806': 'cat',
'807': 'cat',
'808': 'frog',
'809': 'bird',
'810': 'cat',
'811': 'bird',
'812': 'ship',
'813': 'airplane',
'814': 'dog',
'815': 'dog',
'816': 'automobile',
'817': 'deer',
'818': 'dog',
'819': 'frog',
'820': 'frog',
'821': 'bird',
'822': 'horse',
'823': 'airplane',
'824': 'automobile',
'825': 'horse',
'826': 'horse',
'827': 'ship',
'828': 'bird',
'829': 'truck',
'830': 'bird',
'831': 'bird',
'832': 'deer',
'833': 'bird',
'834': 'automobile',
'835': 'automobile',
'836': 'automobile',
'837': 'frog',
'838': 'frog',
'839': 'frog',
'840': 'dog',
'841': 'automobile',
'842': 'automobile',
'843': 'horse',
'844': 'airplane',
'845': 'deer',
'846': 'cat',
'847': 'cat',
'848': 'horse',
'849': 'automobile',
'850': 'bird',
'851': 'cat',
'852': 'dog',
'853': 'dog',
'854': 'dog',
'855': 'frog',
'856': 'automobile',
'857': 'deer',
'858': 'cat',
'859': 'horse',
'860': 'ship',
'861': 'ship',
'862': 'cat',
'863': 'frog',
'864': 'frog',
'865': 'bird',
'866': 'cat',
'867': 'airplane',
'868': 'truck',
'869': 'deer',
'870': 'cat',
'871': 'ship',
'872': 'airplane',
'873': 'airplane',
'874': 'automobile',
'875': 'automobile',
'876': 'dog',
'877': 'deer',
'878': 'truck',
'879': 'cat',
'880': 'automobile',
'881': 'ship',
'882': 'truck',
'883': 'cat',
'884': 'truck',
'885': 'truck',
'886': 'bird',
'887': 'truck',
'888': 'deer',
'889': 'ship',
'890': 'bird',
'891': 'truck',
'892': 'ship',
'893': 'ship',
'894': 'automobile',
'895': 'dog',
'896': 'cat',
'897': 'frog',
'898': 'ship',
'899': 'horse',
'900': 'frog',
'901': 'truck',
'902': 'ship',
'903': 'airplane',
'904': 'frog',
'905': 'deer',
'906': 'airplane',
'907': 'airplane',
'908': 'bird',
'909': 'dog',
'910': 'ship',
'911': 'bird',
'912': 'airplane',
'913': 'bird',
'914': 'horse',
'915': 'frog',
'916': 'truck',
'917': 'horse',
'918': 'automobile',
'919': 'dog',
'920': 'dog',
'921': 'frog',
'922': 'frog',
'923': 'cat',
'924': 'frog',
'925': 'bird',
'926': 'deer',
'927': 'horse',
'928': 'airplane',
'929': 'dog',
'930': 'frog',
'931': 'deer',
'932': 'frog',
'933': 'dog',
'934': 'bird',
'935': 'deer',
'936': 'frog',
'937': 'automobile',
'938': 'frog',
'939': 'airplane',
'940': 'deer',
'941': 'airplane',
'942': 'cat',
'943': 'automobile',
'944': 'ship',
'945': 'dog',
'946': 'deer',
'947': 'deer',
'948': 'automobile',
'949': 'horse',
'950': 'cat',
'951': 'truck',
'952': 'deer',
'953': 'horse',
'954': 'truck',
'955': 'horse',
'956': 'cat',
'957': 'horse',
'958': 'bird',
'959': 'ship',
'960': 'deer',
'961': 'frog',
'962': 'frog',
'963': 'automobile',
'964': 'bird',
'965': 'truck',
'966': 'airplane',
'967': 'deer',
'968': 'ship',
'969': 'horse',
'970': 'cat',
'971': 'truck',
'972': 'ship',
'973': 'horse',
'974': 'horse',
'975': 'airplane',
'976': 'bird',
'977': 'deer',
'978': 'automobile',
'979': 'automobile',
'980': 'deer',
'981': 'automobile',
'982': 'dog',
'983': 'deer',
'984': 'airplane',
'985': 'dog',
'986': 'frog',
'987': 'bird',
'988': 'ship',
'989': 'dog',
'990': 'airplane',
'991': 'bird',
'992': 'automobile',
'993': 'cat',
'994': 'dog',
'995': 'horse',
'996': 'cat',
'997': 'dog',
'998': 'automobile',
'999': 'cat',
'1000': 'dog'}
In [4]:
# Split a validation set out of the original training set.
# The `train` folder holds all training images and the `test` folder all test images.
# Every training image is copied into a folder named after its class.
def copyfile(filename, target_dir):
    """Copy the file `filename` into directory `target_dir`, creating the directory if needed."""
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(filename, target_dir)

def reorg_train_valid(data_dir, labels, valid_ratio):
    """Copy `data_dir/train` images into per-class folders under `data_dir/train_valid_test`.

    Every image goes to `train_valid/<label>`. In addition, the first
    `n_valid_per_label` images of each class (in sorted file-name order) go to
    `valid/<label>` and the rest go to `train/<label>`.

    Args:
        data_dir: dataset root containing the `train` folder.
        labels: dict mapping image id (file name without extension) -> class name.
        valid_ratio: fraction of the smallest class's size to hold out per class.

    Returns:
        int: number of validation images taken from each class.
    """
    # Size of the rarest class; every class contributes the same number of
    # validation images, at least one.
    n = collections.Counter(labels.values()).most_common()[-1][1]
    n_valid_per_label = max(1, math.floor(n * valid_ratio))
    label_count = {}
    # os.listdir returns entries in arbitrary order; sort so the
    # train/valid split is reproducible across runs and machines.
    for train_file in sorted(os.listdir(os.path.join(data_dir, 'train'))):
        label = labels[train_file.split('.')[0]]
        fname = os.path.join(data_dir, 'train', train_file)
        copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train_valid', label))
        if label not in label_count or label_count[label] < n_valid_per_label:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            copyfile(fname, os.path.join(data_dir, 'train_valid_test', 'train', label))
    return n_valid_per_label
In [5]:
# Reorganize the test set so it can be read easily at prediction time.
def reorg_test(data_dir):
    """Copy every file in `data_dir/test` into `data_dir/train_valid_test/test/unknown`.

    The test images have no labels, so they are all placed in a single
    placeholder class folder named 'unknown'. This lets the later ImageFolder
    call on the 'test' folder load them.
    """
    source_dir = os.path.join(data_dir, 'test')
    target_dir = os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')
    for image_name in os.listdir(source_dir):
        copyfile(os.path.join(source_dir, image_name), target_dir)
In [6]:
# Tie the helpers above together: the cells above only define functions, this one runs them.
def reorg_cifar10_data(data_dir, valid_ratio):
    """Reorganize the raw files under `data_dir` into train/valid/train_valid/test class folders."""
    label_map = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
    reorg_train_valid(data_dir, label_map, valid_ratio)
    reorg_test(data_dir)

batch_size = 128 if not demo else 32
# Hold out about 10% of the training images (per class) for validation.
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)
In [7]:
# Image augmentation.
# Training pipeline:
# - resize (shorter side) to 40 px;
# - take a random square crop covering 64%-100% of the area, rescaled to 32x32;
# - apply a random horizontal flip;
# - convert to a tensor and normalize each channel.
# The mean/std values are presumably CIFAR-10 per-channel RGB statistics.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=40),
    torchvision.transforms.RandomResizedCrop(size=32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
# Validation/test pipeline: no augmentation, only tensor conversion and the same normalization.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010]),
])
In [8]:
# Load the reorganized image folders as datasets.
# 'train' and 'train_valid' use the augmenting transform; 'valid' and 'test'
# only get tensor conversion and normalization.
train_ds, train_valid_ds = (
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', split), transform=transform_train)
    for split in ('train', 'train_valid')
)
valid_ds, test_ds = (
    torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train_valid_test', split), transform=transform_test)
    for split in ('valid', 'test')
)
In [9]:
# Wrap the datasets in DataLoaders (the augmentations are applied by the datasets' transforms).
# Training loaders shuffle and drop the final incomplete batch so every
# training batch has the same size.
train_iter, train_valid_iter = [
    torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True)
    for dataset in (train_ds, train_valid_ds)]
# Evaluation loaders keep every example: with drop_last=True the last partial
# batch of validation images would be silently left out of the reported accuracy.
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=False)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=False)
In [10]:
# Model.
def get_net():
    """Return a freshly initialized d2l ResNet-18 with 10 output classes and 3 input channels (colour images)."""
    return d2l.resnet18(10, 3)

# reduction="none": keep one loss value per example instead of summing/averaging them.
loss = nn.CrossEntropyLoss(reduction="none")
In [11]:
# Training function.
def train(net, train_iter, valid_iter, num_epoch, lr, wd, devices, lr_period, lr_decay):
    """Train `net` with momentum SGD and a step learning-rate schedule.

    Args:
        net: model to train. It is wrapped in nn.DataParallel internally and
            moved to devices[0]; its parameters are updated in place.
        train_iter: DataLoader yielding (features, labels) training batches.
        valid_iter: DataLoader for validation, or None to skip validation.
        num_epoch: number of epochs to train.
        lr: initial learning rate.
        wd: weight decay.
        devices: list of devices passed to nn.DataParallel.
        lr_period: every `lr_period` epochs the learning rate is multiplied by `lr_decay`.
        lr_decay: multiplicative learning-rate decay factor.

    Uses the module-level `loss` defined alongside the model. Prints the final
    training loss/accuracy (and validation accuracy if available) and throughput.
    """
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    # Plot about five points per epoch. max(1, ...) avoids a modulo-by-zero
    # when an epoch has fewer than 5 batches.
    plot_every = max(1, num_batches // 5)
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    # Use the `num_epoch` argument, not a global, so the function does not
    # depend on notebook state defined in a later cell.
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epoch], legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epoch):
        net.train()
        # Running totals for this epoch: loss, correct predictions, examples seen.
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[2], None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
    measures = (f'train loss {metric[0] / metric[2]:.3f},'
                f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epoch / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
In [12]:
# Train on the training split while monitoring accuracy on the validation split.
devices = d2l.try_all_gpus()
num_epochs, lr, wd = 20, 2e-4, 5e-4
lr_period, lr_decay = 4, 0.9
net = get_net()
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
train loss 0.618,train acc 0.790, valid acc 0.359 623.1 examples/sec on [device(type='cuda', index=0)]
In [13]:
# Retrain on train + valid combined, classify the test set and write the Kaggle submission.
net, preds = get_net(), []
train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period, lr_decay)
# train() leaves the model in training mode. Switch to eval mode so BatchNorm
# uses its running statistics (and does not update them on test data), and
# disable autograd because this is inference only.
net.eval()
with torch.no_grad():
    for X, _ in test_iter:
        y_hat = net(X.to(devices[0]))
        preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())
# test_iter does not shuffle, so preds follow test_ds.samples order. Read each
# image's numeric id from its file name (e.g. '.../unknown/17.png' -> 17) instead
# of re-deriving ImageFolder's file ordering by hand.
sorted_ids = [int(os.path.splitext(os.path.basename(path))[0]) for path, _ in test_ds.samples]
df = pd.DataFrame({'id': sorted_ids, 'label': preds})
df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])
df.to_csv('submission.csv', index=False)
train loss 0.560,train acc 0.805 859.2 examples/sec on [device(type='cuda', index=0)]