EfficientAD Source Code Reading

2024/7/24 4:22:44 Tags: machine learning, artificial intelligence

Contents

1. dataset

2. dataloader

3. model

3.1 Teacher and student network architecture

3.2 autoencoder

3.3 create model

4. train

4.1 Computing the feature mean and standard deviation of the training set

4.2 student loss L_ST

4.3 autoencoder loss (teacher and ae)

4.4 student additional loss (student and ae)


 

1. dataset

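import os

import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder


# Training dataset: only the images are used; each image is transformed and returned without its label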
class ImageFolderWithoutTarget(ImageFolder):
    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        return sample


# Test dataset: applies the transform (if any) and returns the image, its class index and its file path
class ImageFolderWithPath(ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample, target = super().__getitem__(index)
        return sample, target, path
# data transform
default_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_ae = transforms.RandomChoice([
    transforms.ColorJitter(brightness=0.2),
    transforms.ColorJitter(contrast=0.2),
    transforms.ColorJitter(saturation=0.2)
])


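# image_st: default transform only; fed to the teacher and the student (L_ST).
# image_ae: default transform plus one random color jitter; fed to the autoencoder,
# and also to the teacher and the student for L_AE and L_STAE.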
def train_transform(image):
    return default_transform(image), default_transform(transform_ae(image))

# 1. load data
full_train_set = ImageFolderWithoutTarget(  # images are resized to a fixed size; only the images are returned
    os.path.join(dataset_path, config.subdataset, 'train'),
    transform=transforms.Lambda(train_transform))  # each sample is a pair: default-transformed image, color-augmented image

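# No transform here: each item is the raw PIL image, its class index
# (class folders such as broken_large, broken_small, contamination, good) and its file path
test_set = ImageFolderWithPath(
    os.path.join(dataset_path, config.subdataset, 'test'))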

if config.dataset == 'mvtec_ad':
    # mvtec dataset paper recommend 10% validation set
    train_size = int(0.9 * len(full_train_set))  # hold out 10% of the training images for validation
    validation_size = len(full_train_set) - train_size
    rng = torch.Generator().manual_seed(seed)
    # split the Dataset full_train_set into training and validation subsets
    train_set, validation_set = torch.utils.data.random_split(full_train_set,
                                                              [train_size,
                                                               validation_size],
                                                              rng)
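The 90/10 split can be illustrated without PyTorch. The minimal sketch below uses a hypothetical helper random_split_indices and an arbitrary default seed; it reproduces the size rule int(0.9 * n) and the idea of a seeded, reproducible split, but it shuffles with Python's random module instead of torch.Generator, so the actual indices will differ from torch.utils.data.random_split.

```python
import random


def random_split_indices(n, train_fraction=0.9, seed=42):
    """Split range(n) into train/validation index lists, reproducibly.

    Sketch only: uses Python's random module instead of torch.Generator,
    so the indices differ from torch.utils.data.random_split.
    """
    train_size = int(train_fraction * n)   # same size rule as the snippet above
    indices = list(range(n))
    random.Random(seed).shuffle(indices)   # seeded shuffle -> same split on every run
    return indices[:train_size], indices[train_size:]


train_idx, val_idx = random_split_indices(209)
print(len(train_idx), len(val_idx))  # 188 21
```

For a folder of 209 training images, this keeps 188 images for training and 21 for validation.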

2. dataloader 


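from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=1, shuffle=True,
                          num_workers=4, pin_memory=True)  # batch_size is 1
# Wrap train_loader so that images can be drawn indefinitely
# (InfiniteDataloader is a small helper class from the repository)
train_loader_infinite = InfiniteDataloader(train_loader)   # batch size 1, never runs out of data
validation_loader = DataLoader(validation_set, batch_size=1)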
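InfiniteDataloader itself is not shown in this post. A minimal sketch of the idea, assuming the wrapper simply restarts the underlying loader whenever one pass is exhausted (the class name InfiniteLoader is hypothetical), could look like this:

```python
class InfiniteLoader:
    """Wrap any re-iterable loader so that next() never raises StopIteration.

    Sketch of the idea behind the repository's InfiniteDataloader;
    the name InfiniteLoader is hypothetical.
    """

    def __init__(self, loader):
        self.loader = loader
        self.iterator = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # One pass (epoch) is finished: restart from the beginning.
            # A DataLoader with shuffle=True reshuffles on every new iter().
            self.iterator = iter(self.loader)
            return next(self.iterator)


stream = InfiniteLoader([1, 2, 3])
print([next(stream) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1]
```

Because tqdm_obj wraps a finite range(config.train_steps), zip() in the training loop still stops after train_steps iterations even though the wrapped loader never ends.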

3. model

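3.1 Teacher and student network architecture

(1) The teacher and the student have the same architecture; the only difference is that the student's last convolution has twice as many output channels (2 × 384 = 768 instead of 384).

(2) Note: there is no batch normalization.

(3) Together, the teacher and the student are used to detect local anomalies.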

small pdn

from torch import nn


# 4 convolutions and 2 average-pooling layers. Channels: 3 -> out_channels
def get_pdn_small(out_channels=384, padding=False):
    pad_mult = 1 if padding else 0  # whether to use padding (0 disables all padding)
    return nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),

        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),

        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3,
                  padding=1 * pad_mult),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4)
    )
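With padding=False, the shape (1, 384, 56, 56) quoted in section 4 can be checked by hand. The pure-Python sketch below (the helper names conv_out and trace are hypothetical) applies the standard output-size formula floor((n + 2p - k) / s) + 1 to the spatial layers of get_pdn_small and also accumulates the receptive field: a 256×256 input gives a 56×56 map, and each output feature vector sees a 33×33 input patch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/AvgPool2d layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride) of the spatial layers of get_pdn_small with padding=False
SMALL_PDN = [(4, 1), (2, 2), (4, 1), (2, 2), (3, 1), (4, 1)]


def trace(layers, size):
    """Return (output size, receptive field) for a stack of unpadded layers."""
    rf, jump = 1, 1  # receptive field, and input-pixel distance between neighbouring outputs
    for kernel, stride in layers:
        size = conv_out(size, kernel, stride)
        rf += (kernel - 1) * jump
        jump *= stride
    return size, rf


print(trace(SMALL_PDN, 256))  # (56, 33)
```

The 1×1 convolutions added in get_pdn_medium change neither value, so the medium PDN also maps 256×256 to 56×56 with a 33×33 receptive field.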

medium pdn 

# 6 convolutions and 2 average-pooling layers
def get_pdn_medium(out_channels=384, padding=False):
    pad_mult = 1 if padding else 0
    return nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),

        nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,
                  padding=3 * pad_mult),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),

        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1 * pad_mult),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=4),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                  kernel_size=1)
    )
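To quantify how much larger the medium variant is, the sketch below counts weights and biases (in_channels × out_channels × k² + out_channels per convolution) for the layers listed in the two functions above. The helper names are hypothetical; in PyTorch, sum(p.numel() for p in model.parameters()) should give the same totals, since pooling, ReLU and padding add no parameters.

```python
def conv_params(c_in, c_out, k):
    """Weights plus biases of an nn.Conv2d(c_in, c_out, kernel_size=k) layer."""
    return c_in * c_out * k * k + c_out


def pdn_small_params(out_channels=384):
    return (conv_params(3, 128, 4) + conv_params(128, 256, 4)
            + conv_params(256, 256, 3) + conv_params(256, out_channels, 4))


def pdn_medium_params(out_channels=384):
    return (conv_params(3, 256, 4) + conv_params(256, 512, 4)
            + conv_params(512, 512, 1) + conv_params(512, 512, 3)
            + conv_params(512, out_channels, 4)
            + conv_params(out_channels, out_channels, 1))


print(pdn_small_params(), pdn_medium_params())  # 2694144 8026624
```

With out_channels=384, the medium teacher has roughly three times as many parameters as the small one (about 8.0M vs 2.7M).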

3.2 autoencoder

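(1) Encoder: six convolutions (five 4×4 stride-2 convolutions and a final 8×8 convolution) compress the 256×256 input into a 64×1×1 code. Decoder: six stages of bilinear upsampling + convolution, then a resize to 56×56 and two 3×3 convolutions. The decoder uses dropout; there is no batch normalization.

(2) Together, the autoencoder and the student are used to detect global anomalies.

# Encoder: 6 convolutions; decoder: 6 upsample + conv stages, then a resize to 56x56 and two 3x3 convs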
def get_autoencoder(out_channels=384):
    return nn.Sequential(
        # encoder
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2,
                  padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2,
                  padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2,
                  padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2,
                  padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2,
                  padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=8),
        # decoder
        nn.Upsample(size=3, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=8, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=15, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=32, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=63, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=127, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1,
                  padding=2),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),

        nn.Upsample(size=56, mode='bilinear'),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1,
                  padding=1),
        nn.ReLU(inplace=True),

        nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3,
                  stride=1, padding=1)
    )
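The fixed Upsample sizes are chosen so that the output lands exactly on 56×56, the same spatial size as the teacher output. A pure-Python trace of the spatial dimension only (hypothetical helper names, same output-size formula as for Conv2d) makes this visible; note that each 4×4 convolution with padding 2 adds one pixel, and that the last Upsample(size=56) actually shrinks the 128×128 map.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer."""
    return (size + 2 * padding - kernel) // stride + 1


def autoencoder_sizes(size=256):
    """Trace the spatial size through get_autoencoder (sketch, spatial dims only)."""
    sizes = [size]
    # Encoder: five 4x4 stride-2 convs with padding 1, then an 8x8 conv without padding
    for _ in range(5):
        size = conv_out(size, 4, stride=2, padding=1)
        sizes.append(size)
    size = conv_out(size, 8)
    sizes.append(size)
    # Decoder: Upsample to a fixed size, then a 4x4 conv with padding 2 (adds 1 pixel)
    for target in (3, 8, 15, 32, 63, 127):
        size = conv_out(target, 4, padding=2)
        sizes.append(size)
    # Final resize to 56, then two 3x3 convs with padding 1 (size unchanged)
    size = conv_out(conv_out(56, 3, padding=1), 3, padding=1)
    sizes.append(size)
    return sizes


print(autoencoder_sizes())
# [256, 128, 64, 32, 16, 8, 1, 4, 9, 16, 33, 64, 128, 56]
```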

3.3 create model 

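In the standard student-teacher (S-T) framework:

(1) the teacher is a model pretrained on an image classification dataset, or a distilled version of such a pretrained network;

(2) the student is not trained on the pretraining dataset but on the anomaly-free images of the specific application.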

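# 2. create models
# The medium PDN has two more convolutions than the small one.
# Teacher and student share the same architecture; only the student's last convolution has twice as many output channels.
if config.model_size == 'small':
    teacher = get_pdn_small(out_channels)  # out_channels=384
    student = get_pdn_small(2 * out_channels)  # 768 channels: first half for L_ST, second half for L_STAE
elif config.model_size == 'medium':
    teacher = get_pdn_medium(out_channels)
    student = get_pdn_medium(2 * out_channels)
else:
    raise ValueError(f'unknown model_size: {config.model_size}')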

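# Load the pretrained teacher weights
state_dict = torch.load(config.weights, map_location='cpu')
teacher.load_state_dict(state_dict)

# The autoencoder has the same number of output channels as the teacher
autoencoder = get_autoencoder(out_channels)  # out_channels=384

# Teacher is frozen: eval() only switches the mode; its weights stay fixed because its
# parameters are not given to the optimizer and its forward passes run under torch.no_grad().
# It only guides the learning of the student and the autoencoder.
teacher.eval()
student.train()
autoencoder.train()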

if on_gpu:
    teacher.cuda()
    student.cuda()
    autoencoder.cuda()

4. train

4.1 Computing the feature mean and standard deviation of the training set

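import itertools

import torch
from tqdm import tqdm


@torch.no_grad()
def teacher_normalization(teacher, train_loader):
    """
    Compute the per-channel mean and standard deviation of the teacher's features on the
    training images of the current project. They are later used to normalize the teacher's
    outputs, which serve as regression targets for the student and the autoencoder.
    """
    mean_outputs = []
    # each batch is the pair (image_st, image_ae); only image_st is used
    for train_image, _ in tqdm(train_loader, desc='Computing mean of features'):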
        if on_gpu:
            train_image = train_image.cuda()  # size=(1, 3, 256, 256)
        teacher_output = teacher(train_image)  # teacher_output.size =(1, 384, 56, 56)
        # per-channel mean; dim=[0, 2, 3] averages over the batch, height and width dimensions
        mean_output = torch.mean(teacher_output, dim=[0, 2, 3])  # mean_output.size = (384,)
        mean_outputs.append(mean_output)
    channel_mean = torch.mean(torch.stack(mean_outputs), dim=0)  # (num_image,384) -> (384,)
    channel_mean = channel_mean[None, :, None, None]  # (384,) -> (1,384,1,1)

    mean_distances = []
    for train_image, _ in tqdm(train_loader, desc='Computing std of features'):
        if on_gpu:
            train_image = train_image.cuda()
        teacher_output = teacher(train_image)  # teacher_output.size=(1, 384, 56, 56)
        distance = (teacher_output - channel_mean) ** 2  # (x - mean)^2, distance.size=(1, 384, 56, 56)
        mean_distance = torch.mean(distance, dim=[0, 2, 3])  # (1, 384, 56, 56) -> (384,)
        mean_distances.append(mean_distance)
    # per-channel variance: average over images of the per-image mean squared deviation
    channel_var = torch.mean(torch.stack(mean_distances), dim=0)
    channel_var = channel_var[None, :, None, None]
    channel_std = torch.sqrt(channel_var)  # standard deviation

    return channel_mean, channel_std
# Feature mean and std of the current training set, used to normalize the teacher's outputs
teacher_mean, teacher_std = teacher_normalization(teacher, train_loader)
# Only the student and autoencoder parameters are updated
optimizer = torch.optim.Adam(itertools.chain(student.parameters(),
                                             autoencoder.parameters()),
                             lr=1e-4, weight_decay=1e-5)
# train_steps is the total number of iterations; the learning rate is multiplied by 0.1 after 95% of them
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=int(0.95 * config.train_steps), gamma=0.1)
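The two passes of teacher_normalization can be reproduced with NumPy on fake data. In this sketch, the function name channel_stats and the random arrays standing in for teacher outputs are assumptions. Because every feature map has the same spatial size, averaging per-image means and per-image mean squared deviations gives exactly the population mean and standard deviation over all pixels of all images, which the checks compare against np.std.

```python
import numpy as np


def channel_stats(feature_maps):
    """Two-pass per-channel mean/std over a list of (1, C, H, W) feature maps.

    Mirrors the structure of teacher_normalization: the first pass averages
    per-image channel means, the second averages per-image mean squared
    deviations from that global mean (population variance).
    """
    mean = np.mean([f.mean(axis=(0, 2, 3)) for f in feature_maps], axis=0)
    mean = mean[None, :, None, None]  # (C,) -> (1, C, 1, 1)
    var = np.mean([((f - mean) ** 2).mean(axis=(0, 2, 3)) for f in feature_maps], axis=0)
    return mean, np.sqrt(var)[None, :, None, None]


rng = np.random.default_rng(0)
feats = [rng.normal(loc=3.0, scale=2.0, size=(1, 4, 8, 8)) for _ in range(10)]  # fake teacher outputs
mean, std = channel_stats(feats)
normalized = [(f - mean) / std for f in feats]  # same as (teacher_output - teacher_mean) / teacher_std
```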

4.2 student loss L_ST

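(1) teacher_output_st = teacher(image_st), normalized with teacher_mean and teacher_std;

(2) student_output_st = student(image_st)[:, :out_channels], i.e. the first half of the student's output channels;

(3) student loss: the squared difference between the two outputs,

distance_st = (teacher_output_st - student_output_st) ** 2, with distance_st.size = (1, 384, 56, 56).

The loss is then the mean of only the largest distances, i.e. the elements at or above the 0.999 quantile of distance_st.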

tqdm_obj = tqdm(range(config.train_steps))
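# train_loader_infinite yields two images per step: the default-transformed one (input to teacher and student)
# and the color-augmented one (input to the autoencoder).
# penalty_loader_infinite yields ImageNet penalty images, or None at every step (e.g. itertools.repeat(None))
# when no penalty set is used; passing None itself to zip() would raise a TypeError.
for iteration, (image_st, image_ae), image_penalty in zip(
        tqdm_obj, train_loader_infinite, penalty_loader_infinite):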
    if on_gpu:
        image_st = image_st.cuda()
        image_ae = image_ae.cuda()
        if image_penalty is not None:
            image_penalty = image_penalty.cuda()
    with torch.no_grad():
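        # 1. teacher prediction. image_st (1,3,256,256) -> teacher_output_st (1,384,56,56)
        teacher_output_st = teacher(image_st)  # teacher output, then normalized with the training-set statistics
        teacher_output_st = (teacher_output_st - teacher_mean) / teacher_std
    # 2. student prediction. student_output_st.size=(1,384,56,56)
    student_output_st = student(image_st)[:, :out_channels]  # first half of the student's output channels
    # student loss L_ST (local anomaly branch)
    distance_st = (teacher_output_st - student_output_st) ** 2  # squared distance
    # 0.999 quantile of the distances: 99.9% of the elements are less than or equal to this threshold
    d_hard = torch.quantile(distance_st, q=0.999)
    # Average only the largest distances, i.e. the hardest (worst-regressed) feature elements
    loss_hard = torch.mean(distance_st[distance_st >= d_hard])
    if image_penalty is not None:
        # penalty loss: keep the student's output on the penalty images close to zero
        student_output_penalty = student(image_penalty)[:, :out_channels]
        loss_penalty = torch.mean(student_output_penalty ** 2)
        loss_st = loss_hard + loss_penalty
    else:
        loss_st = loss_hard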
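The hard feature loss can be tried out on toy data. In this NumPy sketch (hard_feature_loss is a hypothetical name), np.quantile with its default linear interpolation plays the role of torch.quantile, whose default interpolation is also linear. For a real distance map of 1 × 384 × 56 × 56 = 1,204,224 elements, only roughly 0.1% of them (about 1,200) enter the mean.

```python
import numpy as np


def hard_feature_loss(distance, q=0.999):
    """Mean of the distances at or above their q-quantile (sketch of loss_hard)."""
    d_hard = np.quantile(distance, q)  # linear interpolation between sorted values
    return distance[distance >= d_hard].mean()


distance = np.arange(1000, dtype=float)  # toy "distances" 0, 1, ..., 999
print(np.quantile(distance, 0.999))      # approximately 998.001
print(hard_feature_loss(distance))       # 999.0: only the largest element survives
```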

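4.3 autoencoder loss (teacher and ae)

(1) image_ae is the same image as image_st, with an additional color augmentation;

(2) teacher_output_ae = teacher(image_ae), normalized in the same way as teacher_output_st;

(3) ae_output = autoencoder(image_ae);

(4) autoencoder loss L_AE: the mean squared difference between the two outputs.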

# 3. autoencoder prediction
ae_output = autoencoder(image_ae)  # (1,3,256,256) -> (1,384,56,56)
with torch.no_grad():
    # 3.1 teacher prediction, then normalization
    teacher_output_ae = teacher(image_ae)  # same image as image_st, plus color augmentation
    teacher_output_ae = (teacher_output_ae - teacher_mean) / teacher_std

# autoencoder loss L_AE
distance_ae = (teacher_output_ae - ae_output) ** 2

loss_ae = torch.mean(distance_ae)

4.4 student additional loss (student and ae)

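(1) image_ae is the same image as image_st, with an additional color augmentation;

(2) ae_output = autoencoder(image_ae);

(3) student_output_ae = student(image_ae)[:, out_channels:], i.e. the second half of the student's output channels;

(4) additional student loss L_STAE: the mean squared difference between the two outputs.

# 3.2 student prediction: the second half of the student's output channels is student_output_ae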
student_output_ae = student(image_ae)[:, out_channels:]

# Student additional loss L_STAE.
distance_stae = (ae_output - student_output_ae) ** 2

loss_stae = torch.mean(distance_stae)

# total loss
loss_total = loss_st + loss_ae + loss_stae

optimizer.zero_grad()
loss_total.backward()
optimizer.step()
scheduler.step()
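Sections 4.2 to 4.4 share one piece of bookkeeping: the student's 768 output channels are split in two, the first half regressing the normalized teacher and the second half regressing the autoencoder. The NumPy sketch below computes the three losses from precomputed outputs so this split can be checked without a GPU. The function name and the constant arrays are hypothetical, the teacher outputs are assumed to be normalized already, and the penalty term is omitted.

```python
import numpy as np

OUT_CHANNELS = 384  # teacher / autoencoder channels; the student has 2 * OUT_CHANNELS


def efficientad_losses(teacher_st, student_st_full, teacher_ae, ae_out, student_ae_full, q=0.999):
    """Return (L_ST without penalty, L_AE, L_STAE, total) from precomputed outputs."""
    student_st = student_st_full[:, :OUT_CHANNELS]  # first half: imitates the teacher
    student_ae = student_ae_full[:, OUT_CHANNELS:]  # second half: imitates the autoencoder
    distance_st = (teacher_st - student_st) ** 2
    loss_st = distance_st[distance_st >= np.quantile(distance_st, q)].mean()
    loss_ae = ((teacher_ae - ae_out) ** 2).mean()
    loss_stae = ((ae_out - student_ae) ** 2).mean()
    return loss_st, loss_ae, loss_stae, loss_st + loss_ae + loss_stae


teacher = np.zeros((1, OUT_CHANNELS, 4, 4))       # fake normalized teacher output
ae = np.ones((1, OUT_CHANNELS, 4, 4))             # fake autoencoder output
student = np.zeros((1, 2 * OUT_CHANNELS, 4, 4))   # fake student output (both halves zero)
print(efficientad_losses(teacher, student, teacher, ae, student))  # L_ST=0, L_AE=1, L_STAE=1, total=2
```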

 


Reference: GitHub - nelson1425/EfficientAD: Unofficial implementation of EfficientAD https://arxiv.org/abs/2303.14535


http://www.niftyadmin.cn/n/5122104.html
