Directory structure

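A layout sketch inferred from the code below (breed folder names are placeholders; images.csv and label.txt are generated by dogsData.py on the first run, and the scripts are assumed to be run from inside dogs_train/):

dogs_train/
    dogsData.py
    utils.py
    train.py
    Images_Data_Dog/
        <breed_1>/          # *.jpg / *.jpeg / *.png images of that breed
        <breed_2>/
        ...
        images.csv          # (path, label) cache, created on first run
        label.txt           # breed-name -> label-index mapping, created on first run
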
dogsData.py

import csv
import glob
import json
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class Dogs(Dataset):

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize
        # map each class folder name (one folder per breed) to an integer label
        self.nameLable = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable)

        # persist the name -> label mapping so it can be reused later (e.g. at inference time)
        if not os.path.exists(os.path.join(self.root, 'label.txt')):
            with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        # print(self.nameLable)
        self.images, self.labels = self.load_csv('images.csv')
        # print(self.labels)

        # split the shuffled file list 80% / 10% / 10% into train / val / test
        if mode == 'train':
            self.images = self.images[:int(0.8*len(self.images))]
            self.labels = self.labels[:int(0.8*len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
            self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
        else:  # 'test'
            self.images = self.images[int(0.9*len(self.images)):]
            self.labels = self.labels[int(0.9*len(self.labels)):]

    def load_csv(self, filename):

        if not os.path.exists(os.path.join(self.root, filename)):
            # first run: collect every image path and cache (path, label) pairs in a csv
            images = []
            for name in self.nameLable.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # print(len(images))

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # the parent folder name is the class name
                    name = img.split(os.sep)[-2]
                    label = self.nameLable[name]
                    writer.writerow([img, label])
            print('csv write successful')

        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)

        return images, labels

    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Normalize did: x_hat = (x - mean) / std
        # so the inverse is: x = x_hat * std + mean
        # x: [c, h, w]; mean/std: [3] => [3, 1, 1] so they broadcast over h and w
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

        x = x_hat * std + mean
        return x


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # print(idx, len(self.images), len(self.labels))
        img, label = self.images[idx], self.labels[idx]

        # turn the image path string into a normalized tensor (ImageNet mean/std)
        # print(self.resize, type(self.resize))
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)

        label = torch.tensor(label)

        return img, label



def main():

    import visdom
    import time

    viz = visdom.Visdom()

    # Approach 1: the custom Dataset defined above (general-purpose)
    db = Dogs('Images_Data_Dog', 224, 'train')
    # fetch a single sample
    # x,y = next(iter(db))
    # print(x.shape, y)
    # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # fetch one batch
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(db.nameLable)
    # for x, y in loader:
    #     # print(x.shape, y)
    #     viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)

    # # Approach 2: torchvision.datasets.ImageFolder (simpler when no custom logic is needed)
    # import torchvision
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.RandomRotation(15),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    # print(len(loader))
    # for x, y in loader:
    #     # print(x.shape, y)
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)

if __name__ == '__main__':
    main()
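
Note that the pipeline in __getitem__ applies RandomRotation and the 1.25x resize in every mode, including 'val' and 'test'. If deterministic evaluation is preferred, one possible variation (a sketch only; it assumes mode is additionally stored as self.mode in __init__) is to branch on the mode:

# inside Dogs.__getitem__, replacing the single Compose above
if self.mode == 'train':
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),
        transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
        transforms.RandomRotation(15),
        transforms.CenterCrop(self.resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
else:
    # no random augmentation for 'val' / 'test'
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),
        transforms.Resize((self.resize, self.resize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])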

 

utils.py

import torch
from torch import nn


class Flatten(nn.Module):
    """Flattens every dimension except the batch one, e.g. [b, 512, 1, 1] -> [b, 512]."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # product of all non-batch dimensions
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
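
As a side note, PyTorch 1.2 and later ship nn.Flatten, which behaves the same way as this helper. A quick sanity check, assuming it runs in the same module as (or imports) the Flatten class above:

import torch
from torch import nn

x = torch.randn(2, 512, 1, 1)   # the shape resnet18 produces after its avgpool layer
print(Flatten()(x).shape)       # torch.Size([2, 512])
print(nn.Flatten()(x).shape)    # torch.Size([2, 512]) -- built-in equivalent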

 

train.py

import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision

from torch.utils.data import DataLoader

from dogs_train.utils import Flatten
from dogsData import Dogs

from torchvision.models import resnet18




viz = visdom.Visdom()

batchsz = 32
lr = 1e-3
epochs = 20

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)


train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)


def evalute(model, loader):
    model.eval()  # disable dropout / use running BN stats during evaluation
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():

    # model = ResNet18(5).to(device)

    # transfer learning: reuse pretrained resnet18 up to its global-average-pool layer
    # and attach a new head for this dataset's classes
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # drop the original fc layer
                          Flatten(),          # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 27)  # 27 classes in this dataset
                          ).to(device)

    # sanity check: expect output shape [2, 27]
    x = torch.randn(2, 3, 224, 224).to(device)
    print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        model.train()  # evalute() switches to eval mode, so switch back each epoch
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        # validate every other epoch; keep the checkpoint with the best val accuracy
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    # reload the best checkpoint and report accuracy on the held-out test split
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

if __name__ == '__main__':
    main()
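
For completeness, a minimal inference sketch that reloads best.mdl together with the label.txt mapping written by Dogs. The image path is a placeholder and the package layout is assumed to match the imports in train.py:

import json
import os

import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18

from dogs_train.utils import Flatten

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# rebuild exactly the architecture used in train.py, then load the checkpoint
backbone = resnet18(pretrained=False)
model = nn.Sequential(*list(backbone.children())[:-1],
                      Flatten(),
                      nn.Linear(512, 27)).to(device)
model.load_state_dict(torch.load('best.mdl', map_location=device))
model.eval()

# label.txt holds {breed_name: index}; invert it to get readable predictions
with open(os.path.join('Images_Data_Dog', 'label.txt'), encoding='utf-8') as f:
    idx2name = {v: k for k, v in json.load(f).items()}

# deterministic version of the training preprocessing
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = tf(Image.open('some_dog.jpg').convert('RGB')).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()
print('predicted breed:', idx2name[pred])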