目录结构
dogsData.py
import csv
import glob
import json
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Per-channel mean/std used by Normalize and undone by Dogs.denormalize
# (the standard ImageNet statistics expected by torchvision's pretrained models).
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]


class Dogs(Dataset):
    """Image-classification dataset stored as one sub-directory per class.

    Every sub-directory of ``root`` is a class; its label is its index in the
    sorted list of sub-directory names.  The mapping is written once to
    ``root/label.txt`` (JSON) and the shuffled ``(path, label)`` list is
    cached in ``root/images.csv`` so every split sees the same order.

    Args:
        root: Directory containing the class sub-directories.
        resize: Side length (pixels) of the square tensors returned.
        mode: ``'train'`` selects the first 80% of the cached list,
            ``'val'`` the next 10%; any other value selects the last 10%.
    """

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Class name -> integer label, in sorted directory order.
        self.nameLable = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable)

        # Persist the mapping only once; an existing label.txt is left alone.
        label_file = os.path.join(self.root, 'label.txt')
        if not os.path.exists(label_file):
            with open(label_file, 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        self.images, self.labels = self.load_csv('images.csv')

        # 80 / 10 / 10 split over the cached (already shuffled) order.
        total = len(self.images)
        train_end, val_end = int(0.8 * total), int(0.9 * total)
        if mode == 'train':
            keep = slice(None, train_end)
        elif mode == 'val':
            keep = slice(train_end, val_end)
        else:
            keep = slice(val_end, None)
        self.images = self.images[keep]
        self.labels = self.labels[keep]

        # Build the pipeline once instead of once per sample.  Random rotation
        # is an augmentation, so it is applied to the training split only;
        # validation/test samples are processed deterministically.
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=_MEAN, std=_STD),
        ]
        self.transform = transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        If the file does not exist yet, it is created first: all ``*.png``,
        ``*.jpg`` and ``*.jpeg`` files of every class directory are collected,
        shuffled, and written as ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.nameLable:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.nameLable[name]])
                print('csv write succesful')

        images, labels = [], []
        with open(csv_path) as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        return images, labels

    def denormalize(self, x_hat):
        """Undo the Normalize step: ``x = x_hat * std + mean``.

        The statistics are reshaped to ``[3, 1, 1]`` so they broadcast over
        the channel dimension, and are created on ``x_hat``'s device so the
        method also works for tensors that live on the GPU.
        """
        mean = torch.tensor(_MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(_STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        path, label = self.images[idx], self.labels[idx]
        # convert() returns a fully loaded copy, so the file handle can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        return self.transform(img), torch.tensor(label)


def main():
    """Smoke test: build the training split and print loader/label info."""
    import visdom
    import time

    viz = visdom.Visdom()

    # Approach 1: the custom dataset (general purpose).
    db = Dogs('Images_Data_Dog', 224, 'train')

    # Show a single sample:
    # x, y = next(iter(db))
    # print(x.shape, y)
    # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # Iterate in batches.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(db.nameLable)
    # for x, y in loader:
    #     viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)

    # Approach 2: torchvision's ImageFolder.
    # import torchvision
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.RandomRotation(15),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    # print(len(loader))
    # for x, y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)


if __name__ == '__main__':
    main()
utils.py
import torch
from torch import nn


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    ``[b, d1, d2, ...]`` becomes ``[b, d1 * d2 * ...]``.
    """

    def forward(self, x):
        # A tensor with fewer than two dimensions has nothing to merge; give
        # it an explicit feature dimension of size 1.
        if x.dim() < 2:
            return x.reshape(-1, 1)
        # flatten() also handles non-contiguous inputs (e.g. after permute),
        # where view() would raise, and avoids building a helper tensor on
        # every call just to compute the feature size.
        return x.flatten(start_dim=1)
train.py
import os
import sys

# Make both this directory and its parent importable so that
# `dogsData` and `dogs_train.utils` resolve when run as a script.
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

import torch
import visdom
from torch import optim, nn
import torchvision
from torch.utils.data import DataLoader

from dogs_train.utils import Flatten
from dogsData import Dogs
from torchvision.models import resnet18

viz = visdom.Visdom()

batchsz = 32
lr = 1e-3
epochs = 20

# Fall back to the CPU when no GPU is available instead of failing outright.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode for the duration of the call so that
    layers such as BatchNorm use their running statistics (and do not update
    them), then restored to whatever mode it was in before.

    Args:
        model: Module mapping an image batch to per-class logits.
        loader: DataLoader yielding ``(images, labels)`` batches.

    Returns:
        Fraction of correctly classified samples; ``0.0`` for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total


def main():
    """Fine-tune a pretrained ResNet-18 on the dog dataset and report test accuracy.

    Validation runs every second epoch; the weights with the best validation
    accuracy are saved to ``best.mdl`` and reloaded for the final test pass.
    """
    trained_model = resnet18(pretrained=True)
    # One output per class directory found by the dataset.
    num_classes = len(train_db.nameLable)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # drop the original fc layer
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, num_classes)
                          ).to(device)

    # Shape sanity check.  Run it in eval mode without gradients so the random
    # input does not leak into the BatchNorm running statistics.
    model.eval()
    with torch.no_grad():
        x = torch.randn(2, 3, 224, 224).to(device)
        print(model(x).shape)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    # map_location keeps the checkpoint loadable on a machine whose device
    # differs from the one it was saved on.
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loader from ckpt')

    test_acc = evalute(model, test_loader)
    print(test_acc)


if __name__ == '__main__':
    main()