FasterRCNN训练自己的数据集

2016年提出的Faster RCNN目标检测模型是深度学习现代目标检测算法的开山之作,也是第一个真正全流程都是神经网络目标检测模型。

其主要步骤如下:

1,使用CNN对输入图片提取feature map.

2,对feature map上的每个点设计一套不同大小和长宽比的anchor作为先验框。

3,设计RPN网络从大量的anchor中筛选出一些作为目标框的proposals并用回归分支纠正它们的位置。

4,使用ROI Pooling技术对不同大小的proposals获取相同大小的对应特征图,以便后续分类模型一并处理。

5,在proposals的feature map上使用分类分支和回归分支进一步预测目标类别和更精确的定位。

anchor技巧ROI Pooling技术 是非常值得学习的技巧,在许多目标检测模型中都能看到他们的身影。

b1c2e63ddc895a32120c5b0ddb03a614.jpeg

尽管FasterRCNN历史悠久,但依然是一个非常重要的目标检测任务的baseline.

一般会把它叫做two-stage的目标检测模型,主要是如果train from scratch,   RPN网络提取proposals和后续对propasals的定位分类 这两个步骤是要分开训练的,但在微调的时候,通常可以一起训练。

本文我们主要演示调用torchvision中的faster-rcnn模型在自己的数据集上微调来检测螺丝螺母。

#!pip install torchvision,torchkeras
import numpy as np
import pandas as pd 
from matplotlib import pyplot as plt
from PIL import Image,ImageColor,ImageDraw,ImageFont 

import torch
from torch import nn
import torchvision
from torchvision import datasets, models, transforms

import datetime
import os
import copy
import json 

print(torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2.0.0+cu117

〇,预训练模型

from torchkeras.data import get_example_image
img = get_example_image('park.jpg')
img.save('park.jpg')
from torchkeras.plots import vis_detection 

# 准备数据
inputs = []
img = Image.open('park.jpg').convert("RGB")
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()
inputs.append(img_tensor)    

# 加载模型
num_classes = 91
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
    num_classes = num_classes)

if torch.cuda.is_available():
    model.to("cuda:0")
model.eval()

# 预测结果
with torch.no_grad():
    predictions = model(inputs)


# 结果可视化
class_names = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1.meta['categories']

vis_detection(img,predictions[0],class_names,min_score = 0.8)

ed52a62468ca74dbcbd773733d61e7c5.png

下面代码我们演示使用我开发的优雅的torchkeras工具在自己的数据集上对Faster-RCNN模型进行finetune。

我们使用一个非常简单的螺丝(bolt)螺母(nut)数据集作为示范。

公众号 算法美食屋 后台回复关键词:torchkeras,获取本文notebook代码和 bolt nut 数据集 下载地址。

一,准备数据

data_path = "./data/bolt_nut"

train_images_path = "./data/bolt_nut/train"
train_targets_path = './data/bolt_nut/train.txt'

val_images_path = "./data/bolt_nut/val"
val_targets_path = './data/bolt_nut/val.txt'

class_names = ['__background__','bolt','nut']
class BoltNut(torch.utils.data.Dataset):
    def __init__(self, images_path, targets_path, 
                 class_names = class_names,
                 transforms = None
                ):
        self.images_path = images_path
        self.targets_path = targets_path
        self.transforms = transforms
        self.infos_list = open(targets_path,"r").readlines()
        self.class_names = class_names

    def __getitem__(self, idx):
        
        info_str = self.infos_list[idx]
        info_arr = info_str.replace("\n","").replace("\t ","").split("\t")
        
        img_path = info_arr.pop(0)
        
        info_arr = [x for x in info_arr if x.strip()] 
        infos = [json.loads(x) for x in info_arr]

        img= Image.open(os.path.join(self.images_path,img_path)).convert("RGB")

        target = {}
        target["image_id"] = torch.tensor([int(img_path.split(".")[0])],dtype = torch.int64)  
        target["labels"] = torch.tensor([self.class_names.index(x["value"]) for x in infos],
                                        dtype = torch.int64)

        coords = [x["coordinate"]  for x in infos]
        boxes = torch.tensor([[xmin,ymin,xmax,ymax] for (xmin,ymin), (xmax,ymax)  in coords])
        target["boxes"] = boxes

        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(infos) ,), dtype=torch.int64)
        
        if self.transforms is not None:
            img, target = self.transforms(img, target)
  
        return img, target

    def __len__(self):
        return len(self.infos_list)
# 可视化数据集
ds_train = BoltNut(train_images_path,train_targets_path)
img,target = ds_train[12]

target["scores"] = torch.ones_like(target["labels"])
img_result = vis_detection(img,target,class_names,min_score = 0.8)
img_result

0f773c51efefd49c9ef239c74d8c1302.png

下面我们设计数据增强模块

import random 
from torchvision import transforms as T

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, prob):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)
            bbox = target["boxes"]
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = T.ToTensor()(image)
        return image, target
transforms_train = Compose([ToTensor(),RandomHorizontalFlip(0.5)])
transforms_val = ToTensor()

ds_train = BoltNut(train_images_path,train_targets_path,transforms=transforms_train)
ds_val = BoltNut(val_images_path,val_targets_path,transforms=transforms_val)
def collate_fn(batch):
      return tuple(zip(*batch))

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=2, 
          shuffle=True, num_workers=4,collate_fn= collate_fn)

dl_val = torch.utils.data.DataLoader(ds_val, batch_size=2, 
          shuffle=True, num_workers=4,collate_fn= collate_fn)
for batch in dl_train:
    features,labels = batch  
    break

二,定义模型

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

num_classes = 3  # 3 classes (bult,nut) + background
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

三,训练模型

from torchkeras import KerasModel
class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.train() #attention here
    
    def __call__(self, batch):
        features,labels = batch 
        
        #loss
        loss_dict = self.net(features,labels)
        loss = sum(loss_dict.values())
        
        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        #all_preds = self.accelerator.gather(preds)
        #all_labels = self.accelerator.gather(labels)
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                             momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=4)

keras_model = KerasModel(model,
                         loss_fn = None,
                         metrics_dict=None,
                         optimizer= optimizer,
                         lr_scheduler=lr_scheduler
                        )

keras_model.fit(train_data=dl_train,val_data=dl_val,
    epochs=20,patience=5,
    monitor='val_loss',
    mode='min',
    ckpt_path ='faster-rcnn.pt',
    plot=True
)

d299f12f273b84c75b01062c0740f2b0.png

024eaa1bd0ac0807dfeea0bdc3963d7c.png

四,评估模型

import torch 

from PIL import Image 
from tqdm import tqdm
from ultralytics.yolo.utils import set_logging
set_logging(verbose=False)
from ultralytics.yolo.utils.metrics import  DetMetrics, box_iou
def process_batch(predictions, targets, 
                  iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
                 ):
    ...
    return metrics
model.eval()
list_predictions = [model(x[0].to('cuda')[None,...])[0] for x in ds_val]
list_targets = [x[1] for x in ds_val]

names = {0:'bolt',1:'nut'}
metrics = eval_metrics(list_predictions,
                       list_targets,
                       names =  names)
display(metrics.results_dict)
{'metrics/precision(B)': 0.9976781395819151,
 'metrics/recall(B)': 1.0,
 'metrics/mAP50(B)': 0.995,
 'metrics/mAP50-95(B)': 0.8542317510036526,
 'fitness': 0.8683085759032874}
import pandas as pd 
df = pd.DataFrame()
df['metric'] = metrics.keys
for i,c in names.items():
    df[c] = metrics.class_result(i)
df

9dfe2e3a0af99c86ab178b6e85870002.png

五,使用模型

# 准备数据
inputs = []
img_path = os.path.join(val_images_path,os.listdir(val_images_path)[5])
img = Image.open(img_path).convert("RGB")
img_tensor = torch.from_numpy(np.array(img)/255.).permute(2,0,1).float()
if torch.cuda.is_available():
    img_tensor = img_tensor.cuda()
inputs.append(img_tensor)    

model.eval()

# 预测结果
with torch.no_grad():
    predictions = model(inputs)

# 结果可视化
vis_detection(img,predictions[0],list(idx2names.values()),min_score = 0.8)

38fed89e054034ac169df0251abc66d5.png

公众号 算法美食屋 后台回复关键词:torchkeras,获取本文notebook代码和 bolt nut 数据集 下载地址。

万水千山总是情,点个赞赞行不行?😋😋


http://www.niftyadmin.cn/n/346745.html

相关文章

Matlab - Plot in plot(图中画图)

Matlab - Plot in plot&#xff08;图中画图&#xff09; 这是在MATLAB中创建一个嵌入式图形的示例&#xff0c;可以在另一个图形中显示。 与MATLAB中的“axes”函数相关。 Coding % Create data t linspace(0,2*pi); t(1) eps; y sin(t);% Place axes at (0.1,0.1) with w…

判断传入数据是否为列表、数组、数据框等数据结构pd.api.types.is_list_like()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 判断传入数据是否为 列表、数组、数据框等数据结构 pd.api.types.is_list_like() 选择题 下列说法错误的是? import pandas as pd import numpy as np print("【执行】pd.api.ty…

GO开篇:手握Java走进Golang的世界

文章目录 一、Golang简介1、Go的诞生2、Go的官网域名3、Go的发展4、Go的设计思想5、Go的特点6、Go的性能7、Go的吉祥物 二、Go和Java的宏观对比1、编译型语言 or 解释型语言2、微观对比 三、Go应用场景1、开源上的应用 四、总结和后续 一、Golang简介 Go&#xff08;又称 Gola…

随想录训练营38/60 | 完全背包;LC 518. 零钱兑换 II;LC 377. 组合总和 Ⅳ

完全背包 什么是完全背包&#xff1f; 完全背包和01背包的区别就是&#xff0c;完全背包能将某个物品添加无数次。 在二维dp数组迭代更新中体现为&#xff1a; 01背包dp数组由左上面的数组更新而成&#xff1b; 完全背包do数组由包括本行在内的左边的数组更新而成。 在一维dp数…

VIBRO-METER VM600 IRC4 智能继电器卡

额外的继电器&#xff0c;由来自MPC4和/或AMC8卡的多达86个输入的方程驱动&#xff0c;用于需要2oo3表决等更复杂的逻辑时8个继电器&#xff0c;可配置为8个SPDT或4个DPDT使用IRC4配置器软件进行完全软件配置继电器可配置为正常通电(NE)或正常断电(NDE)&#xff0c;具有可配置的…

Go语言面试题--必会语法(2)

文章目录 1.函数执行时&#xff0c;如果由于 panic 导致了异常&#xff0c;则延迟函数不会执行。这一说法是否正确&#xff1f;2.下面代码输出什么&#xff1f;3.下面这段代码输出什么&#xff1f;请简要说明。4.下面代码输出什么&#xff1f; 1.函数执行时&#xff0c;如果由于…

MQTT Part 5 主题和最佳实践

主题 主题是一个UTF-8字符串&#xff0c;代理用它来过滤每个连接的客户端的消息。 主题由一个或多个主题级别组成。 每个主题级别之间由正斜杠&#xff08;主题级别分隔符&#xff09;分隔。 与消息队列相比&#xff0c;主题非常轻量级。 客户端不需要在发布或订阅之前创建所需…

linux启动后端服务

1、下载jdk wget --no-cookies --no-check-certificate --header "Cookie: gpw_e24http%3A%2F%2Fwww.oracle.com%2F; oraclelicenseaccept-securebackup-cookie" "http://download.oracle.com/otn-pub/java/jdk/8u141-b15/336fa29ff2bb4ef291e347e091f7f4a7/jd…