PyTorch: CRNN in Practice
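I've recently started digging into OCR. I had trained the open-source Keras-CRNN before, but it differs from the original paper. So today I implemented CRNN in PyTorch, using the Keras-CRNN code and the CRNN paper as references. I don't have a GPU, so I generated a little over 100 small images containing only digits, trained the model on them, and checked whether it converges.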
CRNN pipeline
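I won't go through the CRNN paper in detail here. Below is just an outline of the pipeline as described in the paper:
- The input image must be 32 pixels high. VGG extracts the features and reduces the height by a factor of 32, so the height dimension ends up as 1. The width can be chosen freely; in the paper it is downsampled by a factor of 4.
- The CNN features have shape (batch_size, w/4, 1, 512), or (batch_size, 512, w/4, 1) in PyTorch's channel-first layout. After the size-1 dimension is squeezed out, a bidirectional LSTM makes a prediction at each position using context from both directions.
- The final LSTM output has shape (batch_size, w/4, number of character classes). At every position, it predicts a probability for each character class. The code below keeps the time-major layout (w/4, batch_size, classes), which is what PyTorch's CTCLoss expects.
- Finally, the loss is computed with CTC loss.
The rest of this post follows these steps.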
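The sketch below makes the shape bookkeeping of these steps concrete. It starts from a random tensor that stands in for the CNN features (channel-first layout) and runs it through one bidirectional LSTM, a linear layer and log_softmax. The batch size of 2, the single LSTM layer and the hidden size of 256 are illustrative assumptions, not the exact model built later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N, T, C_FEAT, NUM_CLASSES = 2, 50, 512, 11  # illustrative sizes (assumed)

# stand-in for the CNN output: (batch, channels, width / 4, height = 1)
features = torch.randn(N, C_FEAT, T, 1)

x = features.squeeze(3)   # (N, 512, T): drop the height dimension
x = x.permute(2, 0, 1)    # (T, N, 512): time-major layout for nn.LSTM

lstm = nn.LSTM(C_FEAT, 256, bidirectional=True)
proj = nn.Linear(2 * 256, NUM_CLASSES)

out, _ = lstm(x)          # (T, N, 512): forward and backward states concatenated
logits = proj(out)        # (T, N, 11): one score per class at every time step
log_probs = F.log_softmax(logits, dim=2)

print(log_probs.shape)    # torch.Size([50, 2, 11])
```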
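Basic settings:
- Task: recognizing the digits 0-9 plus one blank, for 11 characters in total
- Image size: 100 × 32 (width × height) in the original CRNN; 200 × 32 in this experiment
- Maximum label length: 20 in this experiment
- Network output: T × 11. T = 50 is the number of time steps fed to the LSTM, i.e. the length of the sequence the CNN outputs (200 / 4). 11 is the number of distinct characters; in general this equals the size of your character set, blank included.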
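The sketch below sanity-checks T = 50 and the maximum length of 20. It tracks the feature-map size through the pooling layers used in this post. It also computes the shortest output length CTC can align with a given label: each character needs one time step, and two identical neighbouring characters need an extra blank between them. The helper names feature_size and min_ctc_length are hypothetical.

```python
# VGG16 cfg 'D' truncated after the 4th pooling layer, as in the network code of this post
CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M']
HEIGHT_ONLY_POOLS = {9, 13}  # cfg indices of the 3rd and 4th 'M' (kernel/stride (1, 2))


def feature_size(width, height):
    """Return (T, H) after the backbone plus the (3, 2) stage-5 conv."""
    for i, v in enumerate(CFG):
        if v == 'M':
            height //= 2
            if i not in HEIGHT_ONLY_POOLS:
                width //= 2
    # stage5: kernel (3, 2), padding (1, 0) keeps the width and maps height 2 -> 1
    return width, height - 2 + 1


def min_ctc_length(label):
    """Shortest T that CTC can align: one step per char plus a blank between repeats."""
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


print(feature_size(200, 32))     # (50, 1)
print(feature_size(100, 32))     # (25, 1)
print(min_ctc_length([3] * 20))  # 39: worst case for a 20-digit label, still below T = 50
```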
Network construction
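To use pretrained VGG weights, the VGG backbone is built the same way as torchvision's VGG; otherwise the weights cannot be loaded. Following the paper, the kernel size and stride of the third and fourth pooling layers are changed to (1, 2). The data loader below stores images width-first, as (C, W, H). Because of this, a (1, 2) pool halves only the height and leaves the width untouched.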
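Here is a quick check of what the rectangular pooling does, assuming the (N, C, W, H) layout produced by this post's data loader. A 2×2 pool halves both width and height. A (1, 2) pool halves only the last dimension (height), which is how the sequence keeps its 50 time steps. The tensor sizes are illustrative.

```python
import torch
import torch.nn as nn

# layout used in this post: (batch, channels, width, height)
x = torch.randn(1, 256, 50, 8)

square_pool = nn.MaxPool2d(kernel_size=2, stride=2)
rect_pool = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))

print(square_pool(x).shape)  # torch.Size([1, 256, 25, 4]): width and height halved
print(rect_pool(x).shape)    # torch.Size([1, 256, 50, 4]): only the height is halved
```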
```python
import torch.nn as nn
import torch.nn.functional as F


class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.features = make_layers(cfgs['D'])

    def forward(self, x):
        x = self.features(x)
        return x


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    # i is the position in cfg; the 3rd and 4th 'M' entries sit at indices 9 and 13
    for i, v in enumerate(cfg):
        if v == 'M':
            if i not in [9, 13]:
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # input is (N, C, W, H): pool only along the height, keep the width
                layers += [nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


# first four stages of VGG16 (cfg 'D' cut before stage 5), so the layer indices
# match torchvision's vgg16().features and the pretrained keys line up
cfgs = {
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M']
}


class BidirectionalLSTM(nn.Module):
    def __init__(self, inp, nHidden, oup):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(inp, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, oup)

    def forward(self, x):
        out, _ = self.rnn(x)           # (T, b, 2 * nHidden)
        T, b, h = out.size()
        t_rec = out.view(T * b, h)
        output = self.embedding(t_rec)
        output = output.view(T, b, -1)  # (T, b, oup)
        return output


class CRNN(nn.Module):
    def __init__(self, characters_classes, hidden=256, pretrain=True):
        super(CRNN, self).__init__()
        self.characters_class = characters_classes
        self.body = VGG()
        # The first conv of VGG stage 5 is defined separately: its kernel size is changed,
        # so pretrained weights cannot be loaded into it.
        # Kernel (3, 2) with padding (1, 0) keeps the width and maps the height 2 -> 1.
        self.stage5 = nn.Conv2d(512, 512, kernel_size=(3, 2), padding=(1, 0))
        self.hidden = hidden
        self.rnn = nn.Sequential(BidirectionalLSTM(512, self.hidden, self.hidden),
                                 BidirectionalLSTM(self.hidden, self.hidden, self.characters_class))
        self.pretrain = pretrain
        if self.pretrain:
            import torchvision.models.vgg as vgg
            # torchvision >= 0.13 deprecates pretrained=True in favour of the weights= argument
            pre_net = vgg.vgg16(pretrained=True)
            pretrained_dict = pre_net.state_dict()
            model_dict = self.body.state_dict()
            # keep only the weights whose names exist in the truncated backbone
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.body.load_state_dict(model_dict)
            # freeze the pretrained backbone
            for param in self.body.parameters():
                param.requires_grad = False

    def forward(self, x):
        # x: (N, 3, W, H), e.g. (N, 3, 200, 32)
        x = self.body(x)    # (N, 512, W / 4, 2)
        x = self.stage5(x)  # (N, 512, W / 4, 1)
        # squeeze out the height dimension
        x = x.squeeze(3)    # (N, 512, T)
        # convert to the (T, N, features) layout expected by nn.LSTM
        x = x.permute(2, 0, 1).contiguous()
        x = self.rnn(x)     # (T, N, classes)
        x = F.log_softmax(x, dim=2)
        return x
```
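The sketch below checks the two shape tricks in forward(). First, the (3, 2) stage-5 convolution collapses the remaining height of 2 into 1 without touching the width. Second, squeeze/permute produces the time-major sequence the LSTM expects. The random input only stands in for the backbone output of a 200 × 32 image. The last lines show that applying nn.Linear directly to a 3-D tensor gives the same result as the view() round trip in BidirectionalLSTM, because nn.Linear acts on the last dimension.

```python
import torch
import torch.nn as nn

# stand-in for the backbone output of a 200 x 32 input in the (N, C, W, H) layout
body_out = torch.randn(2, 512, 50, 2)

stage5 = nn.Conv2d(512, 512, kernel_size=(3, 2), padding=(1, 0))
x = stage5(body_out)
print(x.shape)    # torch.Size([2, 512, 50, 1])

seq = x.squeeze(3).permute(2, 0, 1)
print(seq.shape)  # torch.Size([50, 2, 512]): (T, N, features) for the LSTM

# nn.Linear works on the last dimension, so both forms give the same result
proj = nn.Linear(512, 11)
a = proj(seq.reshape(-1, 512)).view(50, 2, -1)
b = proj(seq)
print(torch.allclose(a, b, atol=1e-5))  # True
```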
Data loading
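The dataset follows the MJSynth format. Each line of the annotation file is "image_path lexicon_index", where lexicon_index is the line number of the label text in the lexicon file.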
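For reference, here is a sketch of the annotation layout RegDataSet reads. The file names annotation.txt and lexicon.txt, the sample image paths and both helper functions are hypothetical. They only mimic what parse_txt() and __getitem__ do with the two files.

```python
import os
import tempfile


def write_toy_annotations(root, samples):
    """Write [(image_path, text), ...] as annotation.txt + lexicon.txt (hypothetical names)."""
    lexicon, index = [], {}
    with open(os.path.join(root, 'annotation.txt'), 'w') as anno:
        for path, text in samples:
            if text not in index:
                index[text] = len(lexicon)
                lexicon.append(text)
            # one line per image: "<image path> <line number of its text in lexicon.txt>"
            anno.write(f'{path} {index[text]}\n')
    with open(os.path.join(root, 'lexicon.txt'), 'w') as lex:
        lex.write('\n'.join(lexicon) + '\n')


def read_annotations(root):
    """Resolve every annotation line to (image_path, text), as RegDataSet.__getitem__ does."""
    with open(os.path.join(root, 'annotation.txt')) as anno:
        lines = anno.readlines()
    with open(os.path.join(root, 'lexicon.txt')) as lex:
        lexicon = lex.readlines()
    pairs = []
    for line in lines:
        path, idx = line.split()
        pairs.append((path, lexicon[int(idx)].strip()))
    return pairs


with tempfile.TemporaryDirectory() as root:
    samples = [('imgs/0.jpg', '2048'), ('imgs/1.jpg', '15'), ('imgs/2.jpg', '2048')]
    write_toy_annotations(root, samples)
    pairs = read_annotations(root)

print(pairs)  # [('imgs/0.jpg', '2048'), ('imgs/1.jpg', '15'), ('imgs/2.jpg', '2048')]
```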
```python
import os
import cv2
import numpy as np
from torch.utils.data import Dataset


class RegDataSet(Dataset):
    def __init__(self, dataset_root, anno_txt_path, lexicon_path, target_size=(200, 32),
                 characters='-0123456789', transform=None):
        super(RegDataSet, self).__init__()
        self.dataset_root = dataset_root
        self.anno_txt_path = anno_txt_path
        self.lexicon_path = lexicon_path
        self.target_size = target_size
        self.height = self.target_size[1]
        self.width = self.target_size[0]
        # index 0 is the blank '-', so digit d gets label index d + 1
        self.characters = characters
        self.imgs = []
        self.lexicons = []
        self.parse_txt()
        # expected to turn the (W, H, 3) uint8 array into a (3, W, H) tensor,
        # e.g. torchvision.transforms.ToTensor()
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item):
        # annotation line: "image_path lexicon_index"
        img_path, lexicon_index = self.imgs[item].split()
        lexicon = self.lexicons[int(lexicon_index)].strip()
        img = cv2.imread(os.path.join(self.dataset_root, img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # width after scaling to the target height with the aspect ratio kept
        new_w = max(1, int(self.height / h * w))
        if new_w <= self.width:
            # resize, then zero-pad on the right up to the target width
            img_reshape = cv2.resize(img, (new_w, self.height))
            pad = np.zeros((self.height, self.width - new_w, 3), dtype=np.uint8)
            out_img = np.concatenate([img_reshape, pad], axis=1)
        else:
            # wider than the target aspect ratio: squeeze to the target size
            out_img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_CUBIC)
        # (H, W, C) -> (W, H, C), so the width becomes the sequence axis
        out_img = out_img.transpose([1, 0, 2])
        label = [self.characters.find(c) for c in lexicon]
        if self.transform:
            out_img = self.transform(out_img)
        return out_img, label

    def parse_txt(self):
        with open(os.path.join(self.dataset_root, self.anno_txt_path), 'r') as f:
            self.imgs = f.readlines()
        with open(os.path.join(self.dataset_root, self.lexicon_path), 'r') as f:
            self.lexicons = f.readlines()
```
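The two pure-Python helpers below isolate two details of the loader; plan_resize and encode are hypothetical names. First, the padding branch should be chosen by comparing the resized width with 200 (equivalently, the aspect ratio with 200 / 32 = 6.25). A larger threshold such as 6.4 lets through images whose resized width exceeds 200, which would require a negative pad. Second, the character set must be the plain string '-0123456789'. If the quoted expression "'-' + '0123456789'" is passed instead, str.find returns indices that exceed the 11 classes.

```python
def plan_resize(h, w, target_w=200, target_h=32):
    """Return (resized_width, pad_width) for an h x w image (mirrors the loader's branches)."""
    new_w = max(1, int(target_h / h * w))
    if new_w <= target_w:
        return new_w, target_w - new_w  # keep aspect ratio, zero-pad on the right
    return target_w, 0                  # too wide: squeeze to target_w


def encode(text, characters='-0123456789'):
    """Map each character to its index; index 0 is reserved for the blank '-'."""
    return [characters.find(c) for c in text]


print(plan_resize(64, 128))  # (64, 136)
print(plan_resize(10, 63))   # (200, 0): aspect ratio 6.3 is already too wide for 200 x 32
print(encode('2048'))        # [3, 1, 5, 9]
print(encode('2048', "'-' + '0123456789'"))  # [9, 7, 11, 15]: wrong, out-of-range indices
```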
CTC LOSS
```python
from torch.nn import CTCLoss

characters = "-0123456789"
ctc_loss = CTCLoss(blank=0, reduction='mean')
```
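- blank: the index of the placeholder '-', which is 0 in the example above
- reduction: how the per-sample losses are combined. With 'mean', each loss is divided by its target length and the results are averaged over the batch; 'sum' and 'none' are the other options.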
Computing the loss:
```python
ctc_loss(log_probs, targets, input_lengths, target_lengths)
```
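- log_probs: the network output, a tensor of shape (T, N, C). T is the number of time steps, N the batch size and C the total number of characters (including blank). In this post, with batch_size = 8, the network output is (50, 8, 11). The output must already have gone through log_softmax.
- targets: the target tensor, which can be given in two forms.
  - Form 1: shape (N, S), where N is the batch size and S is the maximum target length. The values are character indices, and the real labels must not contain the blank index. Labels in a batch differ in length while a tensor must be rectangular, so shorter labels are padded to length S. The "no blank" rule applies only to the real labels: only the first target_lengths[i] entries of row i are read. The padding value is therefore ignored, and padding with blank, as TensorFlow's CTC loss does, is harmless.
  - Form 2: concatenate the character indices of every image in the batch into a single 1-D tensor. It is split back into per-image labels according to the values in target_lengths. This post uses form 2.
- target_lengths: a tensor of shape (N) holding the number of characters in each image. For example, suppose N = 4 images contain 8, 10, 12 and 20 characters. Then target_lengths = (8, 10, 12, 20), and targets (in form 2) contains 8 + 10 + 12 + 20 = 50 values, taken in order according to target_lengths.
- input_lengths: a tensor of shape (N) holding the output sequence length of each image. All images have the same fixed width, so every value is T.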
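The sketch below settles the padding question. It feeds the same made-up labels to nn.CTCLoss in both target forms, using random log-probabilities. Only the first target_lengths[i] entries of each row are read, so padding with 0 gives exactly the same loss as the concatenated 1-D form.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C = 50, 3, 11  # time steps, batch size, classes including blank
log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)

labels = [[3, 1, 5, 9], [2, 6], [8, 8, 8]]  # class indices, never 0 (blank)
target_lengths = torch.tensor([len(lab) for lab in labels])

# form 2: all labels of the batch concatenated into one 1-D tensor
flat = torch.tensor([c for lab in labels for c in lab])

# form 1: (N, S) padded to the longest label; the padding value (0 here) is never read
S = max(len(lab) for lab in labels)
padded = torch.zeros(N, S, dtype=torch.long)
for i, lab in enumerate(labels):
    padded[i, :len(lab)] = torch.tensor(lab)

ctc = nn.CTCLoss(blank=0, reduction='mean')
loss_flat = ctc(log_probs, flat, input_lengths, target_lengths)
loss_padded = ctc(log_probs, padded, input_lengths, target_lengths)
print(torch.allclose(loss_flat, loss_padded))  # True
```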
My implementation:
```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def custom_collate_fn(batch, T=50):
    items = list(zip(*batch))
    # stack the images into one (N, 3, W, H) tensor
    items[0] = default_collate(items[0])
    labels = list(items[1])
    items[1] = []
    target_lengths = torch.zeros(len(batch), dtype=torch.int)
    input_lengths = torch.zeros(len(batch), dtype=torch.int)
    for idx, label in enumerate(labels):
        # number of characters in each image
        target_lengths[idx] = len(label)
        # concatenate the labels of the batch into one flat list
        items[1].extend(label)
        # input_lengths is always T (all images have the same width)
        input_lengths[idx] = T
    return items[0], torch.tensor(items[1], dtype=torch.long), target_lengths, input_lengths


# trainSet, net and args (batch_size, num_workers) come from the rest of the training script
batch_iterator = iter(DataLoader(trainSet, args.batch_size, shuffle=True,
                                 num_workers=args.num_workers, collate_fn=custom_collate_fn))
images, labels, target_lengths, input_lengths = next(batch_iterator)
out = net(images)  # (T, N, C) log-probabilities
loss = ctc_loss(log_probs=out, targets=labels,
                target_lengths=target_lengths, input_lengths=input_lengths)
```
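The same collation can be written more compactly with torch.stack and list comprehensions. This is a sketch of an alternative, not the code used above; ctc_collate is a hypothetical name. It assumes each dataset item is an (image tensor, label list) pair, as returned by RegDataSet with a ToTensor-style transform. If the image width changes, functools.partial(ctc_collate, T=...) can set a different T.

```python
import torch


def ctc_collate(batch, T=50):
    """Hypothetical compact collate: batch is a list of (image_tensor, label_list)."""
    images, labels = zip(*batch)
    images = torch.stack(images)  # (N, 3, W, H)
    targets = torch.tensor([c for lab in labels for c in lab], dtype=torch.long)
    target_lengths = torch.tensor([len(lab) for lab in labels], dtype=torch.long)
    input_lengths = torch.full((len(batch),), T, dtype=torch.long)
    return images, targets, target_lengths, input_lengths


batch = [(torch.zeros(3, 200, 32), [3, 1, 5, 9]), (torch.ones(3, 200, 32), [2, 6])]
images, targets, target_lengths, input_lengths = ctc_collate(batch)
print(images.shape)                                     # torch.Size([2, 3, 200, 32])
print(targets.tolist())                                 # [3, 1, 5, 9, 2, 6]
print(target_lengths.tolist(), input_lengths.tolist())  # [4, 2] [50, 50]
```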
Decoding
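For a single image, the network output is (50, 11). Decoding first takes the index of the most probable character at each position. When the indices are converted to a string, consecutive identical indices are merged into one.
Example: suppose the output is [1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 3, 0, 7, 7, 7, 0, 3, 3]. Only the first 18 positions are shown because the rest are all 0. Index 0 corresponds to the blank '-'. A run of identical non-zero indices counts as a single character, so the example becomes [1, 0, 0, 1, 0, 0, 2, 0, 3, 0, 7, 0, 3]. Removing the blanks then leaves the actual character indices [1, 1, 2, 3, 7, 3]. Both 1s survive because a blank separates them, which is how CTC represents a repeated character.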
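The worked example can be reproduced with itertools.groupby, which yields one key per run of equal values. This sketch merges runs first and drops blanks second. The helper name ctc_greedy_collapse is hypothetical, and the character set is the '-0123456789' used in this post.

```python
from itertools import groupby


def ctc_greedy_collapse(indices, blank=0):
    """Merge runs of identical indices, then drop the blanks."""
    return [k for k, _ in groupby(indices) if k != blank]


example = [1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 3, 0, 7, 7, 7, 0, 3, 3]
collapsed = ctc_greedy_collapse(example)
print(collapsed)                                     # [1, 1, 2, 3, 7, 3]
print(''.join('-0123456789'[i] for i in collapsed))  # 001262
```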
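```python
import torch


def decode_out(str_index, characters):
    # greedy CTC decoding: skip blanks (index 0) and repeats of the previous index
    char_list = []
    for i in range(len(str_index)):
        if str_index[i] != 0 and (not (i > 0 and str_index[i - 1] == str_index[i])):
            char_list.append(characters[str_index[i]])
    return ''.join(char_list)


net.eval()
with torch.no_grad():
    net_out = net(img)  # img: one preprocessed image, shape (1, 3, 200, 32) -> (T, 1, C)
_, preds = net_out.max(2)  # (T, 1): most probable index at each time step
# flatten to one length-T sequence; only valid for batch size 1
preds = preds.transpose(1, 0).contiguous().view(-1)
lab2str = decode_out(preds.tolist(), args.characters)
```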
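The snippet above flattens preds into one sequence, which is only correct for a batch size of 1. With several images, the end of one sequence would run into the start of the next. Below is a minimal sketch of batched greedy decoding that keeps each image separate. decode_batch is a hypothetical helper, and the fake network output is built by hand so the argmax path is known.

```python
import torch


def decode_batch(net_out, characters='-0123456789', blank=0):
    """Greedy CTC decoding of a (T, N, C) output, one string per image."""
    preds = net_out.argmax(dim=2).t().tolist()  # N lists of T indices
    results = []
    for seq in preds:
        chars, prev = [], blank
        for idx in seq:
            # emit only non-blank indices that differ from the previous step
            if idx != blank and idx != prev:
                chars.append(characters[idx])
            prev = idx
        results.append(''.join(chars))
    return results


# fake (T=6, N=2, C=11) output whose argmax paths are known
paths = [[3, 3, 0, 3, 1, 0], [5, 0, 5, 5, 0, 9]]
net_out = torch.full((6, 2, 11), -10.0)
for n, path in enumerate(paths):
    for t, idx in enumerate(path):
        net_out[t, n, idx] = 0.0

print(decode_batch(net_out))  # ['220', '448']
```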