附录2-tensorflow目标检测

news/2024/6/3 17:55:19 标签: 目标检测, tensorflow, 深度学习

源码来自作者Bubbliiiing,我对参考链接的代码略有修改,网盘地址

链接:百度网盘 请输入提取码 提取码:dvb1

目录

1  参考链接

2  环境

3  数据集准备

3.1  VOCdevkit/VOC2007

3.2  model_data/voc_classes.txt

3.3  voc_annotation.py

4  训练 train.py

5  训练结果

6  预测

7  其他

7.1  多线程训练

7.2  二次训练

7.3  学习速率


1  参考链接

源码地址 GitHub - bubbliiiing/yolo3-tf2: 这是一个yolo3-tf2的源码,可以用于训练自己的模型。

博客地址 睿智的目标检测51——Tensorflow2搭建yolo3目标检测平台_Bubbliiiing的博客-CSDN博客_yolo3

视频地址 睿智的目标检测51——Tensorflow2搭建yolo3目标检测平台_Bubbliiiing的博客-CSDN博客_yolo3

2  环境

  • 系统 Linux
  • 显卡 NVIDIA GeForce RTX 3060
  • CUDA 11.1
  • CUDNN 无 (cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2与cat /usr/local/cuda/include/cudnn.h | grep CUDNN_MAJOR -A 2都查不到)

python版本3.6,环境如下

我直接用这个whl装的,tensorflow_gpu-2.6.0-cp36-cp36m-manylinux2010_x86_64.whl

装完之后将keras降到了2.6.0

训练时默认使用GPU资源进行训练

项目放在home下,项目命名为tensorflow_object_detection

3  数据集准备

数据集为877张图像,4分类,其中speedlimit 705个框,crosswalk 174个框,traffclight 154个框,stop 88个框

3.1  VOCdevkit/VOC2007

在项目路径下的VOCdevkit/VOC2007中,将Annotations放入标注的XML文件,JPEGImages放入标注的图片文件(必须是jpg格式的图像,其他格式的不行)

进入ImageSets/Main,删除其中的所有内容

删除项目路径下的 2007_train.txt与2007_val.txt

3.2  model_data/voc_classes.txt

打开项目路径下model_data中的voc_classes.txt

将里面的内容改为自己要训练的类别,顺序无所谓

3.3  voc_annotation.py

不需要改动代码直接运行 voc_annotation.py

运行后会生成这些文件

4  训练 train.py

根据需要修改这里的epoch

然后直接运行就好了,一些warning可以无视掉

在训练开始的时候会给一些提示,可根据这里的提示修改上面的epoch,比如我现在就将epoch设置为569

  • 训练会持续很长事件

5  训练结果

训练结束后会在logs中出现一些文件,我们预测的时候使用 best_epoch_weights.h5 就可以了

我们可以在训练过程中,或者在训练好的loss文件中,查看loss情况

在epoch_loss.txt中可以查看具体的数值

  • 看下面这两个哪个都行

6  预测

我简单改了一下源代码中yolo.py的detect_image方法,目的是拿到预测的信息,而不是直接得到图像

  • 文件名改为了Suyu_yolo.py,下面的predict.py中会进行调用

然后改了一下源码中的predict.py(文件名我改为了Suyu_predict.py)

import time
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
from Suyu_yolo import YOLO
from utils.utils import get_classes

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

yolo = YOLO()

class_names,num_classes = get_classes('model_data/voc_classes.txt')
img = './img/road344.jpg'
image = Image.open(img)
out_boxes, out_scores, out_classes = yolo.detect_image(image)

result_img = cv2.imread(img)
for i, c in list(enumerate(out_classes)):
    predicted_class = class_names[int(c)]
    box = out_boxes[i]
    score = out_scores[i]

    top, left, bottom, right = box

    top = max(0, np.floor(top).astype('int32'))
    left = max(0, np.floor(left).astype('int32'))
    bottom = min(image.size[1], np.floor(bottom).astype('int32'))
    right = min(image.size[0], np.floor(right).astype('int32'))

    label = '{} {:.2f}'.format(predicted_class, score)
    print(label)

    cv2.rectangle(result_img,(left,top),(right,bottom),(0,255,0),2)
    cv2.putText(result_img,label,(left,top+5),cv2.FONT_HERSHEY_SIMPLEX,1,(255,0,0),2)

cv2.imshow('result_img',result_img)
cv2.waitKey(0)
cv2.destroyAllWindows()

之后我们将一张图像放在文件夹img中

之后运行predict.py就可以得到结果了

7  其他

7.1  多线程训练

将train.py中的num_workers置为0可以进行多线程训练

7.2  二次训练

每一次都从0开始训练耗费时间太多,所以我们需要对训练好的模型进行二次训练

首先读取一次训练,训练好的模型

将其更改为一次训练的epoch数

将其更改为最终的轮数,我上面初始写的500,这里写的1000,就表明再训练500轮

二次训练的初始loss值是根据你之前训练好的模型来的,所以初始的loss值不会像没训练过一样高(20多)

7.3  学习速率

训练结束后,如果我们发现loss值没有走低的趋势的时候(或训练过程中,我们可以停止训练,然后使用最近一次的h5文件进行二次训练二次训练),我们可以尝试降低学习率


http://www.niftyadmin.cn/n/62988.html

相关文章

如何明确同花顺l2接口的功能?

做同花顺l2接口首先是要明确做的这个api的功能,在设计出来之前就必须将同花顺l2接口的功能详细的整理出来,并且将各个调用模块分析划分出来,其次是同花顺l2接口的代码必须逻辑清晰,代码要有整洁性和便捷性,接口不能过于…

华为手表开发:WATCH 3 Pro(5)点击按钮弹窗

华为手表开发:WATCH 3 Pro(5)点击按钮弹窗初环境与设备创建项目认识目录结构修改首页 -> 新建按钮 “ 按钮 ”文件名:**index.hml**引用包:system.prompt点击结果初 鸿蒙可穿戴开发 希望能写一些简单的教程和案例…

公司来了个卷王,我愿称之为王中王,让人崩溃

前几天我们公司一下子也来了几个新人,这些年前人是真能熬啊,本来我们几个老油子都是每天稍微加会班就打算走了,这几个新人一直不走,搞得我们也不好走。2023年春招就要开始了,最近内卷严重,各种跳槽裁员&…

各种音频接口比较

时间 参考:https://www.bilibili.com/video/BV1SL4y1q7GZ/?spm_id_from333.337.search-card.all.click&vd_source00bd76f9d6dc090461cddd9f0deb2d51, https://blog.csdn.net/weixin_43794311/article/details/128941346 接口名字时间公司支持格式…

java 里的try - catch里面加了return后,finally还会执行吗?

1.try中添加return,其他地方不添加public class Demo {publicstatic void main(String[] args) {System.out.println(test());}publicstatic String test() {Stringstr "test";int i 5;Useruser new User();try {str "try";user.setUserName("张三&…

均衡负载集群(LBC)-1

均衡负载集群(LBC) 客户–>通过Internet—>负载调度器—>n台真实服务器 负载调度器: 软件:LVS;Nginx;Haproxy硬件:F5; LVS架构: 使用到C/S(B/S…

day30 泛型 Map Set Hash

文章目录相关概念codeGenericTest01 泛型机制GenericTest02GenericTest03 自定义泛型ForeachTest01ForEachTest02HashSetTest01TreeSetTest01MapTest01 常用方法MapTest02 遍历Map集合HashMapTest01HashMapTest02 同时重写hashCode和equals相关概念 day30课堂笔记 1、掌握Map…

从零实现高并发WebRTC服务器(六):OpenSSL协议,DTLS协议,RTP协议和SRTP协议

文章目录一、SSL协议二、OpenSSL三、TLS和DTLS四、DTLS的通信的步骤图五、RTP协议和SRTP协议5.1 详解RTP协议5.2 详解RTCP协议5.3 RTP && RTCP的协议的关键技术六、DTLS-SRTP协议一、SSL协议 SSL的全名叫做secure socket layer(安全套接字层),最开始是由一…