TensorFlow Object Detection API: Training MobileNet SSD Tutorial

2020-03-12 07:44:51   Computer Vision

Install Anaconda

Install dependencies

sudo apt-get install libgl1-mesa-glx libegl1-mesa libxrandr2 libxss1 libxcursor1 libxcomposite1 libasound2 libxi6 libxtst6

Download the installer

Download page (use the Python 3.7 version)

Install

chmod +x Anaconda3-2019.10-Linux-x86_64.sh
./Anaconda3-2019.10-Linux-x86_64.sh

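The installer starts with a long license agreement. Keep pressing Enter, then type yes:

Please answer 'yes' or 'no':
>>> yes

Next, choose the install location. The default is fine, so just press Enter and wait for the installation to finish.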

Create a new environment

conda create --name obj_detection python=3.6

Activate the environment

conda activate obj_detection

Download TensorFlow models

git clone https://github.com/tensorflow/models.git

Install TF requirements

pip install Cython contextlib2 pillow lxml jupyter matplotlib absl-py tensorflow==1.14

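Note: pip installs the latest TensorFlow 2 by default. Keras is the preferred high-level API in TF2, so slim can no longer be found, the package layout has changed drastically, and tf.contrib is gone. In short, TF2 and TF1 are effectively two incompatible frameworks, and this project was written for TF1, so we pin a TF1 version.

If your machine has a GPU and you want to train on it, install the GPU build of TensorFlow instead: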

pip install tensorflow-gpu==1.14
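
Whichever package you install, it is worth confirming that the environment really ended up on a 1.x release before continuing. The following is a minimal sketch: the helper name `is_tf1` is hypothetical (it is not part of TensorFlow) and it only inspects a version string such as the one returned by `tf.__version__`.

```python
def is_tf1(version):
    """Return True if a TensorFlow version string belongs to the 1.x series."""
    major = version.split(".")[0]
    return major == "1"


# Usage inside the obj_detection environment:
#   import tensorflow as tf
#   assert is_tf1(tf.__version__), tf.__version__
```

If the assertion in the usage comment fails, reinstall with the pinned version shown above.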

protoc

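Download page

After extracting the archive, rename the folder to protoc and copy it into the research directory of the models repository you just cloned. Then run the following from the research directory: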

./protoc/bin/protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python setup.py build
python setup.py install

Run the following Python file to check that the environment is set up correctly:

python object_detection/builders/model_builder_test.py

On success the output looks like this:

[some output omitted...]
Ran 17 tests in 0.228s

OK (skipped=1)
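
If the test instead fails while importing a `*_pb2` module, the protoc step most likely did not produce all of the generated files. As a rough check, the sketch below lists every `.proto` file in a directory that has no matching `_pb2.py` next to it. The function name `missing_pb2` is hypothetical, and the check assumes protoc was run with `--python_out=.` as shown earlier, so that the generated files sit beside the `.proto` sources.

```python
import os


def missing_pb2(protos_dir):
    """List .proto files in protos_dir that have no compiled *_pb2.py next to them."""
    missing = []
    for name in sorted(os.listdir(protos_dir)):
        if name.endswith(".proto"):
            compiled = name[:-len(".proto")] + "_pb2.py"
            if not os.path.exists(os.path.join(protos_dir, compiled)):
                missing.append(name)
    return missing


# Usage from the research directory:
#   print(missing_pb2("object_detection/protos"))
```

An empty list means every proto file has been compiled.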

Prepare dataset

  1. Split the dataset into training, validation and test sets

    import os
    import random
    import time
    import shutil
    
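    # Directory containing all annotation XML files
    xmlfilepath = r'/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/Annotations'
    # Output directory for the split XML files; train, validation and test
    # subdirectories are created under this path
    saveBasePath = r"/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/VienDataset/Annotations"

    # Split ratios: 90% train+validation, of which 85% is train
    trainval_percent = 0.9
    train_percent = 0.85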
    
    total_xml = os.listdir(xmlfilepath)
    num = len(total_xml)
    indices = range(num)  # renamed from `list` to avoid shadowing the built-in
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(indices, tv)
    train = set(random.sample(trainval, tr))
    trainval = set(trainval)  # sets make the membership tests below O(1)
    print("train and val size", tv)
    print("train size", tr)
    start = time.time()
    test_num = 0
    val_num = 0
    train_num = 0

    for i in indices:
       name = total_xml[i]
       if i in train:
           directory = "train"
           train_num += 1
       elif i in trainval:
           directory = "validation"
           val_num += 1
       else:
           directory = "test"
           test_num += 1
       xml_path = os.path.join(saveBasePath, directory)
       # makedirs also creates saveBasePath itself if it does not exist yet
       os.makedirs(xml_path, exist_ok=True)
       shutil.copyfile(os.path.join(xmlfilepath, name), os.path.join(xml_path, name))
    
    # End time
    end = time.time()
    seconds = end - start
    print("train total : " + str(train_num))
    print("validation total : " + str(val_num))
    print("test total : " + str(test_num))
    total_num = train_num + val_num + test_num
    print("total number : " + str(total_num))
    print("Time taken : {0} seconds".format(seconds))
    
  2. Convert XML to CSV

    import os
    import glob
    import pandas as pd
    import xml.etree.ElementTree as ET
    
    def xml_to_csv(path):
       xml_list = []
       for xml_file in glob.glob(path + '/*.xml'):
           print(xml_file)
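           tree = ET.parse(xml_file)
           root = tree.getroot()  # was missing: `root` is used below
           size = root.find('size')

           for member in root.findall('object'):
               bndbox = member.find('bndbox')
               # Look fields up by tag name instead of position;
               # int(float(...)) tolerates coordinates written as "12.0"
               value = (root.find('filename').text,
                        int(size.find('width').text),
                        int(size.find('height').text),
                        member.find('name').text,
                        int(float(bndbox.find('xmin').text)),
                        int(float(bndbox.find('ymin').text)),
                        int(float(bndbox.find('xmax').text)),
                        int(float(bndbox.find('ymax').text))
                        )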
               xml_list.append(value)
       column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
       xml_df = pd.DataFrame(xml_list, columns=column_name)
       return xml_df
    
    def main():
       # Directory where the three converted CSV files are written
       csv_root = r"/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/VienDataset"
       # Directory holding the train/validation/test XML folders from the previous step
       annotation_root = r"/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/VienDataset/Annotations"
       for directory in ['train', 'test', 'validation']:
           xml_path = os.path.join(annotation_root, directory)
           xml_df = xml_to_csv(xml_path)
           xml_df.to_csv(csv_root + '/ball_{}_labels.csv'.format(directory), index=None)
           print('Successfully converted xml to csv.')
    
    main()

  3. Convert the image and label data to TFRecord format: generate_tfrecord.py

    from __future__ import division
    from __future__ import print_function
    from __future__ import absolute_import
    
    import os
    import io
    import pandas as pd
    import tensorflow as tf
    
    from PIL import Image
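    # Absolute import, so the script works from any directory once
    # object_detection has been installed with setup.py
    from object_detection.utils import dataset_util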
    from collections import namedtuple, OrderedDict
    
    flags = tf.app.flags
    flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
    flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
    FLAGS = flags.FLAGS
    
    # Map each class name to its label id. For more classes, add an elif
    # branch per class in the same form, e.g.
    # elif row_label == 'other_class': return 2
    def class_text_to_int(row_label, filename):
       if row_label == 'ball':
           return 1
       else:
           print("------------------nonetype:", filename)
           return None
    
    def split(df, group):
       data = namedtuple('data', ['filename', 'object'])
       gb = df.groupby(group)
       return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
    
    def create_tf_example(group, path):
       with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
           encoded_jpg = fid.read()
       encoded_jpg_io = io.BytesIO(encoded_jpg)
       image = Image.open(encoded_jpg_io)
       width, height = image.size
    
       filename = group.filename.encode('utf8')
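       # Take the format from the decoded file (b'jpeg' or b'png') instead of hard-coding it
       image_format = image.format.lower().encode('utf8')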
       xmins = []
       xmaxs = []
       ymins = []
       ymaxs = []
       classes_text = []
       classes = []
    
       for index, row in group.object.iterrows():
           xmins.append(row['xmin'] / width)
           xmaxs.append(row['xmax'] / width)
           ymins.append(row['ymin'] / height)
           ymaxs.append(row['ymax'] / height)
           classes_text.append(row['class'].encode('utf8'))
           classes.append(class_text_to_int(row['class'], group.filename))
    
       tf_example = tf.train.Example(features=tf.train.Features(feature={
           'image/height': dataset_util.int64_feature(height),
           'image/width': dataset_util.int64_feature(width),
           'image/filename': dataset_util.bytes_feature(filename),
           'image/source_id': dataset_util.bytes_feature(filename),
           'image/encoded': dataset_util.bytes_feature(encoded_jpg),
           'image/format': dataset_util.bytes_feature(image_format),
           'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
           'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
           'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
           'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
           'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
           'image/object/class/label': dataset_util.int64_list_feature(classes),
       }))
       return tf_example
    
    def main(_):
       writer = tf.io.TFRecordWriter(FLAGS.output_path)
       # Directory containing the images referenced by the CSV
       path = '/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/JPEGImages'
       examples = pd.read_csv(FLAGS.csv_input)
       grouped = split(examples, 'filename')
       num = 0
       for group in grouped:
           num += 1
           tf_example = create_tf_example(group, path)
           writer.write(tf_example.SerializeToString())
           if (num % 100 == 0):  # print progress every 100 converted examples
               print(num)
    
       writer.close()
       output_path = os.path.join(os.getcwd(), FLAGS.output_path)
       print('Successfully created the TFRecords: {}'.format(output_path))
    
    if __name__ == '__main__':
       tf.compat.v1.app.run()
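
The script above divides pixel coordinates by the image width and height, so every stored box coordinate is expected to lie between 0 and 1. A value outside that range is usually the sign of a bad annotation. The sketch below shows one way to catch such rows before they are written; `normalize_box` is a hypothetical helper, not part of the Object Detection API, and rejecting (rather than clipping) bad boxes is an assumption made here for illustration.

```python
def normalize_box(xmin, ymin, xmax, ymax, width, height):
    """Scale a pixel-space box to [0, 1] and reject boxes that do not fit the image."""
    if width <= 0 or height <= 0:
        raise ValueError("image size must be positive")
    if not (0 <= xmin < xmax <= width and 0 <= ymin < ymax <= height):
        raise ValueError("box ({}, {}, {}, {}) does not fit a {}x{} image".format(
            xmin, ymin, xmax, ymax, width, height))
    return xmin / width, ymin / height, xmax / width, ymax / height


# Usage: a 200x400 image with a box from (50, 100) to (150, 300)
print(normalize_box(50, 100, 150, 300, 200, 400))
```

Inside create_tf_example, a call like this could replace the four separate divisions, with the ValueError logged together with the file name.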

Run the script:

python generate_tfrecord.py --csv_input=data/ball_train_labels.csv --output_path=data/ball_train.tfrecord

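Here csv_input is the path to one of the three CSV files generated earlier and output_path is the path of the TFRecord to write. Run the script once each for train, test and validation.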
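
Since the same command has to be repeated for each split, it can be convenient to generate the three invocations programmatically. This is only a sketch: the function name `tfrecord_commands` and its default arguments are hypothetical, and the file naming simply follows the `ball_{split}_labels.csv` pattern produced by the CSV step.

```python
import sys


def tfrecord_commands(data_dir="data", prefix="ball",
                      splits=("train", "validation", "test")):
    """Build one generate_tfrecord.py command line per dataset split."""
    commands = []
    for split in splits:
        commands.append([
            sys.executable, "generate_tfrecord.py",
            "--csv_input={}/{}_{}_labels.csv".format(data_dir, prefix, split),
            "--output_path={}/{}_{}.tfrecord".format(data_dir, prefix, split),
        ])
    return commands


# Usage (runs the script three times, stopping at the first failure):
#   import subprocess
#   for cmd in tfrecord_commands():
#       subprocess.run(cmd, check=True)
```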

Config

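Create a directory in the project to hold the configuration files, for example vien_data, and in it create the label map file ball_label_map.pbtxt. To detect several classes, add one item per class, increasing id by 1 each time.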

item {
  id: 1
  name: 'ball'
}
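
For more than a handful of classes, writing the label map by hand and keeping it in step with class_text_to_int gets error-prone. The sketch below generates both from a single list of names. The helpers `build_label_map` and `class_ids` are hypothetical and assume ids are assigned in list order starting from 1, as in the example above.

```python
def build_label_map(class_names):
    """Return label map text with ids assigned 1, 2, 3, ... in list order."""
    items = []
    for class_id, name in enumerate(class_names, start=1):
        items.append("item {{\n  id: {}\n  name: '{}'\n}}\n".format(class_id, name))
    return "\n".join(items)


def class_ids(class_names):
    """Map each class name to the same id used in the label map."""
    return {name: i for i, name in enumerate(class_names, start=1)}


# Usage: print the label map for a single-class dataset
print(build_label_map(["ball"]))
```

The dictionary returned by class_ids could stand in for the if/elif chain in generate_tfrecord.py, so that both files always agree.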

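Copy the template config models/research/object_detection/samples/configs/ssd_mobilenet_v1_pets.config from the project into vien_data, name it ssd_mobilenet_v1_ball.config, and then edit it.

If you have no pretrained model checkpoint, set fine_tune_checkpoint in the config to an empty string. Also set num_classes in the model section to the number of classes in your label map (1 in this example).

Next, update the training and validation dataset paths at the end of the file: set input_path to the TFRecord paths of your training and validation sets, and label_map_path to the path of the label map file created above.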

train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/VienDataset/ball_train.tfrecord"
  }
  label_map_path: "/home/zheshi/tensorflow/models/research/object_detection/vien_data/ball_label_map.pbtxt"
}

eval_config: {
  metrics_set: "coco_detection_metrics"
  num_examples: 1100
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/home/zheshi/h4tv/datasets/VOCdevkit/VOCBall2k/VienDataset/ball_validation.tfrecord"
  }
  label_map_path: "/home/zheshi/tensorflow/models/research/object_detection/vien_data/ball_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}
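
A wrong path in this file only surfaces once training starts, so a quick check beforehand can save time. The sketch below pulls every quoted input_path and label_map_path value out of the config text with a regular expression. The function name `config_paths` is hypothetical, and the regex is a simplification that assumes each value is a single double-quoted string on one line, as in the snippet above; it does not parse the protobuf text format in general.

```python
import re


def config_paths(config_text):
    """Return (key, value) pairs for each quoted input_path / label_map_path entry."""
    pattern = re.compile(r'\b(input_path|label_map_path)\s*:\s*"([^"]*)"')
    return [(key, value) for key, value in pattern.findall(config_text)]


# Usage:
#   import os
#   with open("ssd_mobilenet_v1_ball.config") as f:
#       for key, path in config_paths(f.read()):
#           print(key, path, "OK" if os.path.exists(path) else "MISSING")
```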

Train

Run the following from the project's research directory (train_dir is where the training output is stored; pipeline_config_path is the path to the ssd_mobilenet_v1_ball.config file copied and edited above):

python object_detection/legacy/train.py --logtostderr \
    --train_dir=/home/zheshi/tensorflow/models/research/object_detection/vien_train_gpu_models \
    --pipeline_config_path=/home/zheshi/tensorflow/models/research/object_detection/vien_data/ssd_mobilenet_v1_ball.config

If you get ModuleNotFoundError: No module named 'object_detection', run the following in the research directory:

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python setup.py build
python setup.py install

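If everything is fine, training starts once the configuration has been loaded. Intermediate files and model checkpoints are all written to the vien_train_gpu_models directory passed to the training script.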
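
To see how far training has progressed without opening any tooling, you can look at the checkpoint files in that directory. The sketch below assumes the usual TF1 naming convention, in which each saved checkpoint includes a file called `model.ckpt-<step>.index`. The function name `latest_checkpoint_step` is hypothetical.

```python
import os
import re


def latest_checkpoint_step(train_dir):
    """Return the highest step among model.ckpt-<step>.index files, or None."""
    pattern = re.compile(r"^model\.ckpt-(\d+)\.index$")
    steps = []
    for name in os.listdir(train_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None


# Usage:
#   print(latest_checkpoint_step("object_detection/vien_train_gpu_models"))
```

A result of None simply means no checkpoint has been written yet.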

You can monitor training graphically with TensorBoard (replace the path after --logdir=training: with the vien_train_gpu_models directory passed to the training script):

tensorboard --logdir=training:/home/zheshi/tensorflow/models/research/object_detection/vien_train_gpu_models --host=127.0.0.1

Then open http://127.0.0.1:6006 in a browser.

Copyright viencoding.com. Reposting is permitted, provided you credit the source and link to the original article: https://viencoding.com/article/260