DETR 目标检测

news/2024/7/10 1:45:10 标签: 目标检测, 人工智能, 计算机视觉

DETR 目标检测

根据DETR官方源代码,写一个打框可可视化脚本(适用于NWPU-VHR-10数据集)

注意:
1、如果是自己的数据集,修改num_classes参数值为自己的数据种类类别 + 1
2、定义CLASSES和COLORS,每个类别对应一个颜色即可。
3、修改代码中的路径为自己的路径

可参考文章
https://blog.csdn.net/qq_45836365/article/details/128252220

import glob
import math
import argparse
import numpy as np
from models.detr import DETR
from models.backbone import Backbone, build_backbone
from models.transformer import build_transformer
from PIL import Image
import cv2
import requests
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
import torchvision.models as models

torch.set_grad_enabled(False)
import os


def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float, help='gradient clipping max norm')
    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")  # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str, help="Name of the convolutional backbone to use")
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int, help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')
    # * Segmentation
    parser.add_argument('--masks', action='store_true', help="Train segmentation head if the flag is provided")
    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")  # * Matcher
    parser.add_argument('--set_cost_class', default=1, type=float, help="Class coefficient in the matching cost")
    parser.add_argument('--set_cost_bbox', default=5, type=float, help="L1 box coefficient in the matching cost")
    parser.add_argument('--set_cost_giou', default=2, type=float,
                        help="giou box coefficient in the matching cost")  # * Loss coefficients
    parser.add_argument('--mask_loss_coef', default=1, type=float)
    parser.add_argument('--dice_loss_coef', default=1, type=float)
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    parser.add_argument('--eos_coef', default=0.1, type=float,
                        help="Relative classification weight of the no-object class")
    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')
    parser.add_argument('--coco_path', type=str)
    parser.add_argument('--coco_panoptic_path', type=str)
    parser.add_argument('--remove_difficult', action='store_true')
    parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


CLASSES = ['airplane', 'ship', 'storage tank', 'baseball diamond', 'tennis court', 'basketball court',
           'ground track field', 'harbor', 'bridge', 'vehicle']
COLORS = [(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50),

          (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255),

          (230, 230, 230), (4, 250, 7),
          ]
transform_input = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def plot_results(pil_img, prob, boxes, save_path):
    lw = max(round(sum(pil_img.shape) / 2 * 0.003), 2)
    tf = max(lw - 1, 1)
    colors = COLORS
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        c1 = p.argmax()
        text = f'{CLASSES[c1 - 1]}:{p[c1]:0.2f}'
        cv2.rectangle(pil_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), colors[c1 - 1], thickness=lw,
                      lineType=cv2.LINE_AA)
        if text:
            tf = max(lw - 1, 1)
            w, h = cv2.getTextSize(text, 0, fontScale=lw / 3, thickness=tf)[0]
            cv2.rectangle(pil_img, (int(xmin), int(ymin)), (int(xmin) + w, int(ymin) - h - 3), colors[c1 - 1], -1,
                          cv2.LINE_AA)
            cv2.putText(pil_img, text, (int(xmin), int(ymin) - 2), 0, lw / 3, (255, 255, 255), thickness=tf,
                        lineType=cv2.LINE_AA)
    Image.fromarray(ori_img).save(save_path)


parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
backbone = build_backbone(args)
transform = build_transformer(args)
model = DETR(backbone=backbone, transformer=transform, num_classes=11, num_queries=100)
model_path = '/home/admin1/pywork/xuebing_pywork/detr-main/outs/checkpoint0299.pth'  # 保存的预训练好的模型pth文件,用于验证
model_data = torch.load(model_path)['model']
model.load_state_dict(model_data)
model.eval()

paths = os.listdir('/home/admin1/pywork/xuebing_pywork/mmdetection-main/data/coco/val2017')  # 待验证的图片路径
for path in paths:
    if os.path.splitext(path)[1] == ".png":
        im = cv2.imread(path)
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    else:
        im = Image.open('/home/admin1/pywork/xuebing_pywork/mmdetection-main/data/coco/val2017' + '/' + path)
        # mean-std normalize the input image (batch-size: 1)
        img = transform_input(im).unsqueeze(0)
    # propagate through the model
    outputs = model(img)
    # keep only predictions with 0.9+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    # 保存验证结果地址
    img_save_path = '/home/admin1/pywork/xuebing_pywork/detr-main/infer_results/' + \
                    os.path.splitext(os.path.split(path)[1])[0] + '.jpg'
    ori_img = np.array(im)
    plot_results(ori_img, probas[keep], bboxes_scaled, img_save_path)


http://www.niftyadmin.cn/n/5248915.html

相关文章

RHEL网络服务器

目录 1.时间同步的重要性 2.配置时间服务器 (1)指定所使用的上层时间服务器。 (2)指定允许访问的客户端 (3)把local stratum 前的注释符#去掉。 3.配置chrony客户端 (1)修改pool那行,指定要从哪台时间…

计算机存储结构分析(寄存器,内存,缓存,硬盘)

https://blog.csdn.net/bemodesty/article/details/81476906 前言 一个计算机包含多种存储器比如:寄存器、高速缓存、内存、硬盘、光盘等,为啥有这么多种存储方式,对于不太了解的人,总是觉得云里雾里的,搞不明白原因…

SQL命令---创建数据库

介绍 使用sql命令创建数据库。 命令 create database 数据库名称;

maven工程的pom.xml文件中增加了依赖,但偶尔没有下载到本地仓库

maven工程pom.xml文件中的个别依赖没有下载到本地maven仓库。以前没有遇到这种情况,今天就遇到了这个问题,把解决过程记录下来。 我在eclipse中编辑maven工程的pom.xml文件,增加对mybatis的依赖,但保存文件后,依赖的j…

C# WPF上位机开发(简易图像处理软件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 图像处理是工业生产重要的环节。不管是定位、测量、检测还是识别,图像处理在工业生产中扮演重要的角色。而c#由于自身快速开发的特点&a…

基于ssm vue的风景文化管理平台源码和论文

摘 要 随着信息化时代的到来,管理系统都趋向于智能化、系统化,基于vue的木里风景文化管理平台也不例外,但目前国内的市场仍都使用人工管理,市场规模越来越大,同时信息量也越来越庞大,人工管理显然已无法应对…

AI:90-基于深度学习的自然灾害损害评估

🚀 本文选自专栏:人工智能领域200例教程专栏 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带有在本地跑过的核心代码,详细讲解供大家学习,希望可以帮到大家。欢迎订阅支持,正在不断更新…

详解单链表OJ题

链表OJ经典题目 一.删除链表中等于给定值 val 的所有结点leetcode链接 二.给定一个带有头结点 head 的非空单链表,返回链表的中间结点。如果有两个中间结点,则返回第二个中间结点leetcode链接 三.反转一个单链表leetcode链接 四.输入一个链表&#xff0c…