Faster R-CNN 笔记

Faster R-CNN的学习笔记,备忘。之后可以用Tensorflow来实现一个。主要内容

  1. Faster R-CNN结构简述
  2. 训练和实现细节
  3. PASCAL VOC数据集的处理

网络结构

Fast R-CNN速度的瓶颈在于region propose,Faster R-CNN将region propose的工作交由CNN来处理(Region Proposal Network, RPN),实现了端到端的区域推荐功能,而且RPN能够和网络共享权重,降低了训练难度。剩余的部分和Fast R-CNN一样,Faster R-CNN需要输入图像和region proposal进行分类,而且对于全图共用feature map来降低运行压力。所以Faster R-CNN中,RPN用于进行区域推荐,Fast R-CNN用于目标分类。

图像首先被等比例缩放至短边为600像素的大小,生成全图的fature map。对于RPN,在feature map上采用$n\times n$的滑窗,滑窗输出$256$维特征,每个滑窗中心对应$k$个anchor,滑窗输出连接到两个并列$1\times 1$的网络,负责输出对应的$k$个anchor是否是感兴趣的区域,并且回归出bounding box。对于分类任务,基本工作和Fast R-CNN相同。将RPN的输出和图像用于输入,进行分类。

Faster R-CNN is a single, unified network for object detection. The RPN module serves as the 'attention' of this unified network.

总之,Faster R-CNN是一个独立的统一的用于目标检测的网络。RPN模块为网络提供“注意力”的功能。

Anchor是什么?

Anchor是作者手工选取的指标,它是指以在特征图上滑窗中心为中心的,大小和比例人为选定的在原图上的矩形区域。在网络中首先选取与ground truth的标记框靠近(使用Intersection-over-Union, IoU来衡量)的anchor box,训练时,对这些正样本bounding box以anchor box为基准进行回归。具体的计算方法请参照下面的损失函数计算。对于所有的anchor box,根据他们与ground truth的IoU,对他们的正负(IoU高的算正样本,IoU低的是负样本)进行逻辑斯蒂回归。

RPN - Region Proposal Network

实际上的RPN中在特征图上的$n\times n$滑窗,连接相邻的两个全连接层用于输出分类和回归信息,这样的结构可以由$n\times n$的卷积核和$1\times 1$的卷积核自动实现。

在RPN中,anchor实际上比滑窗的感受野大。但是仍然可以用滑窗中的特征对bounding box进行回归,可以将这种操作看做是“根据火车中部猜整列火车的位置”。

训练和实现细节

RPN网络可以单独训练,损失函数可以记为

其中$L_{cls}$是log损失,$L_{reg}$使用的是Fast R-CNN中提到的robust log function。$t$是一个记录了框中心点位置和宽高参数的四维向量,$t_x$和$t_y$记录的是中心点坐标,$t_w$和$t_h$是宽高。$t$,$t_a$和$t^*$分别表示预测、anchor box和ground truth box。

对卷积层网络用ImageNet预训练参数进行初始化,自定的层用标准差为$0.01$,均值为$0$的高斯分布初始化。ZF网络对所有层进行fine-tune,VGG-16对conv3_1和以上的层进行fine-tune(为了节约内存)。对于前60000个mini-batch使用了0.001的学习率,对于后20000个mini-batch使用0.0001的学习率,动量为0.9,衰减为0.0005。

RPN和Fast-RCNN共用特征训练

单独训练RPN和Fast-RCNN会导致卷积层收敛的结果有所不同。因此需要一种技术在训练时共享卷积层。

论文中提出了“4步交替训练”方法:

  1. 首先使用在ImageNet预训练的权重初始化网络,单独训练RPN
  2. 使用ImageNet预训练权重初始化网络,和之前的RPN输出的区域推荐,训练Fast-RCNN
  3. 使用训练结束的Fast-RCNN的卷积层对RPN的卷积层初始化,并锁定,只(在之前训练好的基础上)训练RPN独有的权重,此时RPN和Fast-RCNN拥有相同的卷积层权重
  4. 固定Fast-RCNN的卷积层权重,使用上一步训练出的RPN提供的区域推荐,继续第二步训练Fast-RCNN卷积层之上的权重

实现细节

  • 所有输入图像都被缩放到短边600像素的大小
  • anchor的选取没有针对特定数据集进行挑选,而是任意指定的
  • RPN正样本的选取规则是:与任意ground truth box的IoU大于0.7或者对于某个ground truth box,与它IoU最大的那个anchor box。引入第二种情况是为了防止很小的ground truth box导致没有与它相关的anchor被标记为正
  • 有超出图像边界部分的anchor在训练时被忽略,但是在测试时,这些anchor不被忽略
  • 某些RPN推荐结果相互之间重叠,为了简化结果,采用了非极大抑制算法
  • 训练RPN时采样:对一张图随机采样256个anchor用于计算mini-batch的损失。正负样本拥有1:1的数量比。当正样本少于128个时,使用负样本填充mini-batch

PASCAL VOC数据集的处理

在ubuntu 16.04下,使用python处理。需要opencv-python的支持。首先VOC文件夹中,我们需要的标注信息仅包含图像本身和目标框标注。图像在VOC下的./JPEGImages文件夹,框标注在./Annotations文件夹中所有的xml文件里。这些xml文件还包含别的信息,但是除了图像文件名,框位置和类别,其它我们都不需要。xml解析使用了xml.sax,相关资料可以在搜索引擎上方便地找到。

能够读取标注信息之后,就可以对网络进行训练了,下面的代码用于参考,如何处理VOC数据集的annotation。更多关于VOC数据集,请参阅The PASCAL Visual Object Classes Homepage

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Process annotations in VOC dataset
# Copyright (C) 2018 njzwj
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import cv2
import os
import re
import random
import numpy as np
import pandas as pd
import xml.sax
import matplotlib.pyplot as plt
VOC_DIR = '../VOC2012' # VOC location, you should configure it before using
# generate image file list
# images of pascal voc is in ./JPEGImages folder
# annotations are in ./Annotations folder
JPEG_DIR = os.path.join(VOC_DIR, 'JPEGImages')
ANNOTATION_DIR = os.path.join(VOC_DIR, 'Annotations')
annotation_list = os.listdir(ANNOTATION_DIR)
# process annotations
# image annotation for faster rcnn
# - filename
# - size (width, height)
# - object
# - name
# - bndbox (xmin, ymin, xmax, ymax)
class VocXmlHandler( xml.sax.ContentHandler ):
def __init__(self):
self.CurrentData = ''
self.CurrentObjName = ''
self.CurrentObjBndbox = {}
self.filename = ''
self.size = [0, 0]
self.objects = []
self.stack = []
self.data = {}
def startElement(self, tag, attributes):
self.CurrentData = tag
self.stack.append(tag)
def endElement(self, tag):
if tag == 'object':
self.objects.append({
'name': self.CurrentObjName,
'bndbox': (
self.CurrentObjBndbox['xmin'],
self.CurrentObjBndbox['ymin'],
self.CurrentObjBndbox['xmax'],
self.CurrentObjBndbox['ymax']
)
})
elif tag == 'annotation':
self.data = {
'filename': self.filename,
'size': tuple(self.size),
'object': self.objects
}
self.stack.pop()
def characters(self, content):
content = re.sub('[\r\n\t\ ]', '', content)
if content == '':
return
if self.CurrentData == 'filename':
self.filename = content
elif self.CurrentData == 'width':
self.size[0] = int(content)
elif self.CurrentData == 'height':
self.size[1] = int(content)
elif self.CurrentData == 'name':
self.CurrentObjName = content
elif self.CurrentData in ['xmin', 'xmax', 'ymin', 'ymax'] and \
self.stack[-3] == 'object':
self.CurrentObjBndbox[self.CurrentData] = int(content)
def viewAnnotation( xmlDir ):
parser = xml.sax.make_parser()
parser.setFeature(xml.sax.handler.feature_namespaces, 0)
Handler = VocXmlHandler()
parser.setContentHandler( Handler )
parser.parse( xmlDir )
img = cv2.imread( os.path.join(JPEG_DIR, Handler.data['filename']) )
img1 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
for _ in Handler.data['object']:
bndbox = _['bndbox']
cv2.rectangle(img1, (bndbox[0], bndbox[1]), \
(bndbox[2], bndbox[3]), (255, 0, 0), 3)
plt.imshow(img1)
plt.show()
for _ in range(10):
idx = random.randint(0, 1000)
viewAnnotation( os.path.join(ANNOTATION_DIR, annotation_list[idx]) )

请参阅

  1. The PASCAL Visual Object Classes Homepage
  2. Ren S, Girshick R, Girshick R, et al. Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks[J]. IEEE Transactions on Pattern Analysis & Machine Intelligence, 2017, 39(6):1137.
  3. Girshick R. Fast R-CNN[J]. Computer Science, 2015.