目录
1. dataset
2. dataloader
3. model
3.1 teacher和student网络结构
3.2 autoencoder
3.3 create model
4. train
4.1 求训练数据集的特征均值和标准差
4.2 student loss L_ST
4.3 autoencoder loss (teacher and ae)
4.4 student additional loss (student and ae)
1. dataset
# Training dataset: only the (transformed) image is needed, labels are discarded.
class ImageFolderWithoutTarget(ImageFolder):
    """ImageFolder that yields just the transformed sample, without its class index."""

    def __getitem__(self, index):
        # ImageFolder returns (sample, target); keep only the sample.
        return super().__getitem__(index)[0]
# Test dataset: returns the (transformed) image, its class index and its file path.
class ImageFolderWithPath(ImageFolder):
    """ImageFolder that additionally yields the file path of each sample."""

    def __getitem__(self, index):
        # self.samples holds (path, target) pairs; only the path is taken from it.
        file_path = self.samples[index][0]
        image, label = super().__getitem__(index)
        return image, label, file_path
# Data transforms.
# Default preprocessing: square resize to image_size, tensor conversion,
# then normalisation with the ImageNet channel mean/std.
default_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Extra colour augmentation for the autoencoder input: each call applies one
# randomly chosen jitter (brightness, contrast or saturation, strength 0.2).
transform_ae = transforms.RandomChoice(
    [transforms.ColorJitter(**{attribute: 0.2})
     for attribute in ('brightness', 'contrast', 'saturation')]
)
def train_transform(image):
    """Produce the two training views of one image.

    Returns:
        (plain, jittered): ``plain`` uses only the default transform and feeds
        the teacher and student; ``jittered`` is colour-augmented with
        ``transform_ae`` first and feeds the autoencoder.
    """
    plain = default_transform(image)
    jittered = default_transform(transform_ae(image))
    return plain, jittered
# 1. load data
# Training images: the transform returns a pair (default view for teacher/student,
# colour-jittered view for the autoencoder); class labels are dropped.
full_train_set = ImageFolderWithoutTarget(
    os.path.join(dataset_path, config.subdataset, 'train'),
    transform=transforms.Lambda(train_transform))
# Test images: no transform is given here; each item is (sample, class index, file path).
test_set = ImageFolderWithPath(
    os.path.join(dataset_path, config.subdataset, 'test'))
if config.dataset == 'mvtec_ad':
    # Hold out 10% of the training images for validation (as recommended by the paper).
    train_size = int(0.9 * len(full_train_set))
    validation_size = len(full_train_set) - train_size
    # Seeded generator so the train/validation split is reproducible.
    rng = torch.Generator().manual_seed(seed)
    train_set, validation_set = torch.utils.data.random_split(
        full_train_set, [train_size, validation_size], rng)
elif config.dataset == 'mvtec_loco':
    # MVTec LOCO provides its own 'validation' folder, so the whole train folder is used.
    train_set = full_train_set
    validation_set = ImageFolderWithoutTarget(
        os.path.join(dataset_path, config.subdataset, 'validation'),
        transform=transforms.Lambda(train_transform))
else:
    # Fail early: otherwise train_set / validation_set would be undefined below.
    raise ValueError(f'Unknown config.dataset: {config.dataset!r}')
2. dataloader
# One image per training step (batch_size=1), reshuffled every epoch.
train_loader = DataLoader(
    train_set,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
# Wrapper around train_loader so the training loop can keep drawing batches
# beyond a single pass over the dataset.
train_loader_infinite = InfiniteDataloader(train_loader)
validation_loader = DataLoader(validation_set, batch_size=1)
3. model
3.1 teacher和student网络结构
(1)teacher和student网络结构是一样的,只是student的最后一个卷积输出通道翻倍(2 * out_channels):前一半通道用于模仿teacher,后一半通道用于模仿autoencoder。
(2)注意:没有BN(batch normalization)操作。
(3)teacher和student两者结合用来检测局部异常(local anomaly)。
small pdn
def get_pdn_small(out_channels=384, padding=False):
    """Build the small patch description network (PDN).

    Four convolutions and two 2x2 average pools, ReLU after every convolution
    except the last, no batch norm. Maps 3 input channels to ``out_channels``.

    Args:
        out_channels: number of channels produced by the final convolution.
        padding: when true, the first three convolutions and both pools are
            padded, so the spatial size shrinks less.

    Returns:
        nn.Sequential whose module indices match the pretrained state_dict keys.
    """
    p = 1 if padding else 0
    # Keep this exact layer order: state_dict keys ('0.weight', '3.weight', ...)
    # depend on the Sequential indices.
    layers = [
        nn.Conv2d(3, 128, kernel_size=4, padding=3 * p),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=p),
        nn.Conv2d(128, 256, kernel_size=4, padding=3 * p),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=p),
        nn.Conv2d(256, 256, kernel_size=3, padding=p),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, out_channels, kernel_size=4),
    ]
    return nn.Sequential(*layers)
medium pdn
def get_pdn_medium(out_channels=384, padding=False):
    """Build the medium patch description network (PDN).

    Six convolutions and two 2x2 average pools (two more convolutions than the
    small PDN), ReLU after every convolution except the last, no batch norm.

    Args:
        out_channels: number of channels produced by the final two convolutions.
        padding: when true, the 4x4 convs before the pools, both pools and the
            3x3 conv are padded.

    Returns:
        nn.Sequential whose module indices match the pretrained state_dict keys.
    """
    p = 1 if padding else 0
    # Layer order fixes the state_dict keys; do not reorder.
    layers = [
        nn.Conv2d(3, 256, kernel_size=4, padding=3 * p),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=p),
        nn.Conv2d(256, 512, kernel_size=4, padding=3 * p),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2, padding=p),
        nn.Conv2d(512, 512, kernel_size=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(512, 512, kernel_size=3, padding=p),
        nn.ReLU(inplace=True),
        nn.Conv2d(512, out_channels, kernel_size=4),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
    ]
    return nn.Sequential(*layers)
3.2 autoencoder
(1)encoder共6个卷积:5个stride=2的4x4卷积做下采样,最后一个8x8卷积把特征图压缩到很小的尺寸;decoder共7次双线性上采样(依次到3、8、15、32、63、127、56)和8个卷积,最终输出56x56的特征图。有dropout,没有BN;
(2)autoencoder和student两者结合用来检测全局(逻辑)异常(global anomaly)。
def get_autoencoder(out_channels=384):
    """Build the EfficientAD autoencoder (no batch norm, dropout in the decoder).

    Encoder: five 4x4 stride-2 convolutions followed by an 8x8 convolution
    (presumably collapsing a 256x256 input to 1x1 -- confirm against image_size).
    Decoder: six blocks of bilinear upsample -> 4x4 conv -> ReLU -> Dropout(0.2)
    to sizes 3, 8, 15, 32, 63, 127, then a final upsample to 56x56 and two 3x3
    convolutions producing ``out_channels`` channels.

    Args:
        out_channels: channels of the output feature map (same as the teacher).
    """
    layers = []
    # Encoder: strided 4x4 convolutions, each followed by ReLU.
    for c_in, c_out in ((3, 32), (32, 32), (32, 64), (64, 64), (64, 64)):
        layers += [
            nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]
    layers.append(nn.Conv2d(64, 64, kernel_size=8))
    # Decoder: upsample to a fixed size, then conv + ReLU + dropout.
    for target_size in (3, 8, 15, 32, 63, 127):
        layers += [
            nn.Upsample(size=target_size, mode='bilinear'),
            nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        ]
    # Final resize to the 56x56 feature resolution and projection to out_channels.
    layers += [
        nn.Upsample(size=56, mode='bilinear'),
        nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
    ]
    return nn.Sequential(*layers)
3.3 create model
在标准的S-T框架中,
(1)teacher是图像分类数据集的预训练模型,或者是这种预训练网络的蒸馏版本;
(2)student不是在预训练数据集上进行训练,而是在具体的应用项目中无异常图像上训练。
# 2. create models.
# Teacher and student share the PDN architecture; the student's final conv has
# 2 * out_channels outputs: the first half is trained against the teacher and
# the second half against the autoencoder (see the [:, :out_channels] and
# [:, out_channels:] slices in the training loop).
if config.model_size == 'small':
    teacher = get_pdn_small(out_channels)
    student = get_pdn_small(2 * out_channels)
elif config.model_size == 'medium':
    # The medium PDN has two more convolutions than the small one.
    teacher = get_pdn_medium(out_channels)
    student = get_pdn_medium(2 * out_channels)
else:
    raise ValueError(
        f"Unknown config.model_size: {config.model_size!r} "
        f"(expected 'small' or 'medium')")
# Load the pretrained teacher weights.
# NOTE(review): torch.load unpickles the file -- only load trusted weight files.
state_dict = torch.load(config.weights, map_location='cpu')
teacher.load_state_dict(state_dict)
# The autoencoder outputs the same number of channels as the teacher.
autoencoder = get_autoencoder(out_channels)
# The teacher is not optimised (it is absent from the optimizer); it only
# provides targets for the student and the autoencoder.
teacher.eval()
student.train()
autoencoder.train()
if on_gpu:
    teacher.cuda()
    student.cuda()
    autoencoder.cuda()
4. train
4.1 求训练数据集的特征均值和标准差
@torch.no_grad()
def teacher_normalization(teacher, train_loader):
    """Compute per-channel mean and standard deviation of teacher features.

    Runs the teacher over ``train_loader`` twice: first to get the channel
    mean (average of per-batch spatial means), then the channel variance
    (average of per-batch mean squared deviations from that mean). The
    statistics are later used to standardise teacher outputs before they are
    compared with student / autoencoder outputs.

    Args:
        teacher: the teacher network.
        train_loader: yields (image, second_view) pairs; only the first element
            is used.

    Returns:
        (channel_mean, channel_std), each shaped (1, C, 1, 1) for broadcasting.
    """
    def _teacher_features(desc):
        # Yield teacher outputs one batch at a time (e.g. (1, 384, 56, 56) for
        # 256x256 inputs with the small PDN -- depends on config).
        for image, _ in tqdm(train_loader, desc=desc):
            if on_gpu:
                image = image.cuda()
            yield teacher(image)

    # Mean over batch, height and width -> one value per channel.
    batch_means = [out.mean(dim=[0, 2, 3])
                   for out in _teacher_features('Computing mean of features')]
    channel_mean = torch.stack(batch_means).mean(dim=0)[None, :, None, None]

    batch_sq_devs = [((out - channel_mean) ** 2).mean(dim=[0, 2, 3])
                     for out in _teacher_features('Computing std of features')]
    channel_var = torch.stack(batch_sq_devs).mean(dim=0)[None, :, None, None]
    return channel_mean, torch.sqrt(channel_var)
# Per-channel statistics of the teacher's features on the training set, used to
# standardise teacher outputs before comparing them with student/autoencoder outputs.
teacher_mean, teacher_std = teacher_normalization(teacher, train_loader)
# Only the student and autoencoder parameters are optimised.
trainable_params = itertools.chain(student.parameters(),
                                   autoencoder.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=1e-4, weight_decay=1e-5)
# StepLR with gamma=0.1: since scheduler.step() is called once per iteration,
# the learning rate is multiplied by 0.1 once, after 95% of config.train_steps.
lr_decay_step = int(0.95 * config.train_steps)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step,
                                            gamma=0.1)
4.2 student loss L_ST
(1)teacher_output_st = teacher(image_st),并用训练集上的通道均值和标准差进行标准化;
(2)student_output_st = student(image_st)[:, :out_channels],即student输出的前一半通道;
(3)student loss: 两个网络输出之差的平方
distance_st = (teacher_output_st - student_output_st) ** 2, distance_st.size = (1,384,56,56)。
然后取distance_st中大于等于其0.999分位数的元素(即最大的约0.1%的距离,hard feature loss),求平均值作为loss。
tqdm_obj = tqdm(range(config.train_steps))
# train_loader_infinite yields (image_st, image_ae): the default view for the
# teacher/student and the colour-augmented view for the autoencoder.
# penalty_loader_infinite may yield None (see the check below) when no penalty
# images are used.
for iteration, (image_st, image_ae), image_penalty in zip(
        tqdm_obj, train_loader_infinite, penalty_loader_infinite):
    if on_gpu:
        image_st = image_st.cuda()
        image_ae = image_ae.cuda()
        if image_penalty is not None:
            image_penalty = image_penalty.cuda()
    with torch.no_grad():
        # 1. teacher prediction, standardised with the training-set statistics.
        teacher_output_st = teacher(image_st)
        teacher_output_st = (teacher_output_st - teacher_mean) / teacher_std
    # 2. student prediction: the first out_channels channels imitate the teacher.
    student_output_st = student(image_st)[:, :out_channels]
    # Student loss L_ST (local anomalies): squared feature distance ...
    distance_st = (teacher_output_st - student_output_st) ** 2
    # ... restricted to the hardest elements, i.e. those at or above the
    # 99.9% quantile of the distance.
    d_hard = torch.quantile(distance_st, q=0.999)
    loss_hard = torch.mean(distance_st[distance_st >= d_hard])
    if image_penalty is not None:
        # Penalty term: keep the student's teacher-imitating outputs small on
        # the penalty images.
        student_output_penalty = student(image_penalty)[:, :out_channels]
        loss_penalty = torch.mean(student_output_penalty ** 2)
        loss_st = loss_hard + loss_penalty
    else:
        loss_st = loss_hard
    # The loop body continues with the autoencoder loss (4.3) and the
    # student-autoencoder loss plus optimisation step (4.4).
4.3 autoencoder loss (teacher and ae)
(1)image_ae和image_st是同样的图片,只不过image_ae多了一个颜色数据增强;
(2)teacher_output_ae = teacher(image_ae);
(3)ae_output = autoencoder(image_ae);
(4)autoencoder loss(L_AE): 标准化后的teacher输出与autoencoder输出之差的平方,再对所有元素求平均。
# 3. autoencoder prediction on the colour-augmented image.
ae_output = autoencoder(image_ae)
with torch.no_grad():
    # 3.1 teacher features of the same augmented image, standardised.
    teacher_output_ae = (teacher(image_ae) - teacher_mean) / teacher_std
# Autoencoder loss L_AE: mean squared difference to the teacher features.
distance_ae = (teacher_output_ae - ae_output) ** 2
loss_ae = distance_ae.mean()
4.4 student additional loss (student and ae)
(1)image_ae和image_st是同样的图片,只不过image_ae多了一个颜色数据增强;
(2)ae_output = autoencoder(image_ae);
(3)student_output_ae = student(image_ae)[:, out_channels:];
(4)additional student loss(L_STAE): autoencoder输出与student后一半通道输出之差的平方,再对所有元素求平均。
# 3.2 student prediction on the augmented image: the second half of the
# student's channels (from out_channels on) imitates the autoencoder.
student_output_ae = student(image_ae)[:, out_channels:]
# Student additional loss L_STAE: mean squared difference to the autoencoder output.
distance_stae = (ae_output - student_output_ae) ** 2
loss_stae = distance_stae.mean()
# Total loss and one optimisation step (the scheduler is stepped every iteration).
loss_total = loss_st + loss_ae + loss_stae
optimizer.zero_grad()
loss_total.backward()
optimizer.step()
scheduler.step()
参考:GitHub - nelson1425/EfficientAD: Unofficial implementation of EfficientAD https://arxiv.org/abs/2303.14535