COCO评估输出指定某类AP或者输出每个类别AP结果

news/2024/7/10 0:17:55 标签: 深度学习, 目标检测, python

一 输出单类AP(不需要修改pycocotools)

coco_eval.py源代码

python">"""
COCO-Style Evaluations

"""

import json
import os

import argparse
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def eval(coco_gt, image_ids, pred_json_path):
    # load results in COCO evaluation tool
    coco_pred = coco_gt.loadRes(pred_json_path)

    # run COCO evaluation
    print('BBox')
    coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('--gt-json', type=str, default='instances.json', help='coco val2017 annotations json files')
    ap.add_argument('--pred-json', type=str, default='coco_results.json', help='pred coco val2017 annotations json files')
    args = ap.parse_args()
    print(args)

    pred_json_path = args.pred_json

    MAX_IMAGES = 1000
    coco_gt = COCO(args.gt_json)
    image_ids = coco_gt.getImgIds()[:MAX_IMAGES]

    eval(coco_gt, image_ids, pred_json_path)

以上代码评估后的输出

 

修改后的代码,可以指定输出某类AP值,只修改eval函数:

python">def eval(coco_gt, image_ids, pred_json_path):
    # load results in COCO evaluation tool
    coco_pred = coco_gt.loadRes(pred_json_path)

    # run COCO evaluation
    print('BBox')
    coco_eval = COCOeval(coco_gt, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.params.catIds = [2] # 你可以根据需要增减类别
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

修改后输出,比如只输出第2类

 

二 输出每个类别AP结果(需要修改pycocotools下的coco.py和cocoeval.py)

首先修改coco.py的类COCO的初始化为,在84行下添加代码

python">    def __init__(self, annotation_file=None):
        """
        Constructor of Microsoft COCO helper class for reading and visualizing annotations.
        :param annotation_file (str): location of annotation file
        :param image_folder (str): location to the folder that hosts images.
        :return:
        """
        # load dataset
        self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if not annotation_file == None:
            print('loading annotations into memory...')
            tic = time.time()
            with open(annotation_file, 'r') as f:
                dataset = json.load(f)
            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time()- tic))
            print(
                "category names: {}".format([e["name"] for e in sorted(dataset["categories"], key=lambda x: x["id"])]))
            self.dataset = dataset
            self.createIndex()

修改cocoeval.py,在第456行下添加代码,修改summarize函数

python">def summarize(self):
        '''
        Compute and display summary metrics for evaluation results.
        Note this functin can *only* be applied on the default parameter setting
        '''
        def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
            p = self.params
            iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
            titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
            typeStr = '(AP)' if ap==1 else '(AR)'
            iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
                if iouThr is None else '{:0.2f}'.format(iouThr)

            aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
            mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
            if ap == 1:
                # dimension of precision: [TxRxKxAxM]
                s = self.eval['precision']
                # IoU
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:,:,:,aind,mind]
            else:
                # dimension of recall: [TxKxAxM]
                s = self.eval['recall']
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:,:,aind,mind]
            if len(s[s>-1])==0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s>-1])
            #print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
            category_dimension = 1 + int(ap)
            if s.shape[category_dimension] > 1:

                iStr += ", per category = {}"
                mean_axis = (0,)
                if ap == 1:
                    mean_axis = (0, 1)
                per_category_mean_s = np.mean(s, axis=mean_axis).flatten()
                with np.printoptions(precision=3, suppress=True, sign=" ", floatmode="fixed"):
                    print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s, per_category_mean_s))
            else:
                print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s, ""))
            return mean_s

输出结果为(我的是6类模型):

 


http://www.niftyadmin.cn/n/1186016.html

相关文章

SDUT-2120 数据结构实验之链表五:单链表的拆分

数据结构实验之链表五&#xff1a;单链表的拆分Time Limit: 1000MS Memory Limit: 65536KBSubmit StatisticProblem Description输入N个整数顺序建立一个单链表&#xff0c;将该单链表拆分成两个子链表&#xff0c;第一个子链表存放了所有的偶数&#xff0c;第二个子链表存放了…

基于PaddlePaddle的工业表计数环境搭建

为了验证百度PaddlePaddle发布的工业表计数工程模型的准确性以及效果&#xff0c;分别在PC端和jetson端搭建了环境&#xff0c;亲测实际效果 工业表计数&#xff1a;工业表计读数 — PaddleX 文档 链接中jetson nano c部署链接失效&#xff0c;可以参考这个&#xff1a; Jets…

SDUT-2131 数据结构实验之栈与队列一:进制转换

数据结构实验之栈一&#xff1a;进制转换Time Limit: 1000MS Memory Limit: 65536KBSubmit StatisticProblem Description输入一个十进制非负整数&#xff0c;将其转换成对应的 R (2 < R < 9) 进制数&#xff0c;并输出。Input第一行输入需要转换的十进制非负整数&#x…

瑞芯微RV1126/1109开发流程之MPP部署

本文根据RKNN交流群提供的MPP开源代码&#xff0c;在RV1126上部署MPP demo&#xff0c;MPP的GitHub地址为&#xff1a; GitHub - rockchip-linux/mpp: Media Process Platform (MPP) module 本github下载下来的并不是只针对rv1126的&#xff0c;所以某些参数需要更改 1、更改…

SDUT-2134 数据结构实验之栈与队列四:括号匹配

数据结构实验之栈与队列四&#xff1a;括号匹配Time Limit: 1000MS Memory Limit: 65536KBSubmit StatisticProblem Description给你一串字符&#xff0c;不超过50个字符&#xff0c;可能包括括号、数字、字母、标点符号、空格&#xff0c;你的任务是检查这一串字符中的( ) ,[ …

SDUT-3343 数据结构实验之二叉树四:(先序中序)还原二叉树

数据结构实验之二叉树四&#xff1a;还原二叉树Time Limit: 1000MS Memory Limit: 65536KBSubmit Statistic DiscussProblem Description给定一棵二叉树的先序遍历序列和中序遍历序列&#xff0c;要求计算该二叉树的高度。Input输入数据有多组&#xff0c;每组数据第一行输入1个…

SDUT-2482 二叉排序树

二叉排序树Time Limit: 1000MS Memory Limit: 65536KBSubmit Statistic DiscussProblem Description二叉排序树的定义是&#xff1a;或者是一棵空树&#xff0c;或者是具有下列性质的二叉树&#xff1a; 若它的左子树不空&#xff0c;则左子树上所有结点的值均小于它的根结点的…

SDUT-2127 树-堆结构练习——合并果子之哈夫曼树

树-堆结构练习——合并果子之哈夫曼树Time Limit: 1000MS Memory Limit: 65536KBSubmit Statistic DiscussProblem Description在一个果园里&#xff0c;多多已经将所有的果子打了下来&#xff0c;而且按果子的不同种类分成了不同的堆。多多决定把所有的果子合成一堆。每一次合…