Content-Based Image Retrieval: PyTorch

After reading Adrian Rosebrock's Keras CBIR demo, I decided to write one myself in PyTorch.

CBIR

CBIR stands for content-based image retrieval: searching an image database for images that are visually similar to a query. Largely inspired by text mining, the approach extracts features from each image and performs retrieval by comparing those features. The features can be traditional descriptors such as SIFT; this post uses features produced by the encoder of a convolutional autoencoder.

We test on the MNIST dataset and encode each image into a 16-dimensional latent vector.
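
Before building anything, the whole retrieval principle fits in a few lines. A minimal sketch on random vectors (the shapes and array names here are illustrative only): rank the database features by Euclidean distance to the query feature and keep the closest matches.

import numpy as np

# toy "database" of 1000 precomputed 16-dim features plus one query feature
db_features = np.random.rand(1000, 16)
query = np.random.rand(16)

# distance from the query to every stored feature, then the 5 nearest indices
dists = np.linalg.norm(db_features - query, axis=1)
print(np.argsort(dists)[:5])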

Building the network

Construct a simple encoder-decoder:

# network.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicCBR(nn.Module):
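    # Conv -> BatchNorm -> ReLU building block (the CBR in the name)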
    def __init__(self, in_channel, out_channel, kernel, stride, padding):
        super(BasicCBR, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x)


class Encoder(nn.Module):
    def __init__(self, img_shape=(28, 28, 1), latent_dim=16):
        super(Encoder, self).__init__()
        self.img_shape = img_shape
        self.latent_dim = latent_dim
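        # two stride-2 convolutions downsample 28x28 -> 14x14 -> 7x7, then a linear layer maps the 7*7*64 features to the latent vector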
        self.conv1 = BasicCBR(in_channel=self.img_shape[2], out_channel=32, kernel=3, stride=2, padding=3//2)
        self.conv2 = BasicCBR(32, 64, 3, 2, 3//2)
        self.latent = nn.Linear(in_features=7 * 7 * 64, out_features=latent_dim)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # (N, 64, 7, 7) -> (N, 3136)
        x = self.latent(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_feat_shape=(7, 7, 64), latent=16):
        super(Decoder, self).__init__()
        self.in_feat_shape = in_feat_shape
        self.linear = nn.Linear(in_features=latent, out_features=in_feat_shape[0] * in_feat_shape[1] * in_feat_shape[2])
        self.conv1 = BasicCBR(in_channel=in_feat_shape[2], out_channel=64, kernel=3, stride=1, padding=3//2)
        self.conv2 = BasicCBR(64, 32, kernel=3, stride=1, padding=3//2)
        self.out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=3//2)

    def forward(self, x):
        x = self.linear(x)
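        # reshape the flat features back to a (N, 64, 7, 7) map, then upsample twice back to 28x28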
        x = x.view(-1, self.in_feat_shape[2], self.in_feat_shape[0], self.in_feat_shape[1])
        x = self.conv1(x)
        x = F.interpolate(x, scale_factor=(2, 2))
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=(2, 2))
        x = torch.sigmoid(self.out(x))
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.encoder = Encoder(img_shape=(28, 28, 1))
        self.decoder = Decoder()

    def forward(self, x):
        latent = self.encoder(x)
        out = self.decoder(latent)
        return out
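
A quick shape check (a throwaway snippet, not part of network.py) confirms that a 28x28 input is compressed to 16 dimensions and reconstructed back to 28x28:

import torch
from network import Net

net = Net()
x = torch.randn(4, 1, 28, 28)  # dummy batch of 4 grayscale images
print(net.encoder(x).shape)    # torch.Size([4, 16])
print(net(x).shape)            # torch.Size([4, 1, 28, 28])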

Training

To train the autoencoder, the encoder first compresses each image and the decoder then reconstructs it; after training, only the encoder needs to be kept.

import datetime
import os
import time

import torch
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets, transforms

from network import Net


BATCH_SIZE = 32

train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor(), download=True)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

gpu = torch.cuda.is_available()
model = Net()
num_epochs = 20
epoch_size = len(train_dataset) // BATCH_SIZE
max_iter = epoch_size * num_epochs
loss = torch.nn.MSELoss()
if gpu:
    model = model.cuda()
    loss = loss.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)


os.makedirs("./data", exist_ok=True)  # make sure the checkpoint directory exists
iteration = 0
for epoch in range(num_epochs):
    model.train()
    for i, (img, label) in enumerate(train_loader):
        optimizer.zero_grad()
        load_t0 = time.time()
        if gpu:
            img = img.cuda()
        out = model(img)
        loss_v = loss(out, img)
        loss_v.backward()
        optimizer.step()
        load_t1 = time.time()
        batch_time = load_t1 - load_t0
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch: {}/{} || Epoch iter: {}/{} || Iter: {}/{} || Loss: {:.4f} || Batch time: {:.4f} s || ETA: {}'.format(
            epoch + 1, num_epochs, (iteration % epoch_size) + 1, epoch_size, iteration + 1, max_iter,
            loss_v.item(), batch_time, str(datetime.timedelta(seconds=eta))))
        iteration += 1
    torch.save(model.state_dict(), "./data/model.pth")
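
To eyeball whether the autoencoder has actually learned to reconstruct digits, a small check like the following helps (a sketch assuming the paths used above; recon_check.png is a name I made up):

import torch
from torchvision import datasets, transforms, utils
from network import Net

model = Net()
model.load_state_dict(torch.load("./data/model.pth", map_location="cpu"))
model.eval()

test_dataset = datasets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor(), download=True)
with torch.no_grad():
    imgs = torch.stack([test_dataset[i][0] for i in range(8)])  # first 8 test digits
    recons = model(imgs)
# originals in the top row, reconstructions in the bottom row
utils.save_image(torch.cat([imgs, recons]), "./data/recon_check.png", nrow=8)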

Building the feature dictionary

Encode every training image with the encoder and save the features as a dictionary: the key is the image index, the value is the feature vector produced by the encoder.

from network import Net
import torch
import pickle
import torch.utils.data as Data
from torchvision import datasets, transforms


train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor())
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=False)
model = Net()
model.load_state_dict(torch.load("./data/model.pth", map_location="cpu"))
model.eval()  # BatchNorm must use its running statistics during feature extraction

features = []
with torch.no_grad():  # no gradients needed, keeps memory usage flat
    for img, lab in train_loader:
        features.append(model.encoder(img))
features = torch.cat(features, 0)

indexes = list(range(len(train_dataset)))
data = {"indexes": indexes, "features": features}
with open("./data/features_dict.pickle", "wb") as f:
    pickle.dump(data, f)
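
A quick sanity check on the saved file (illustrative only) should report 60000 indexes and a feature tensor of shape (60000, 16):

import pickle

with open("./data/features_dict.pickle", "rb") as f:
    data = pickle.load(f)
print(len(data["indexes"]), data["features"].shape)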

Retrieving similar images

Load a test image, encode it with the encoder, then compute its similarity against the features extracted from the training set.

import os
import pickle

import cv2
import numpy as np
import torch
import torch.utils.data as Data
from torchvision import datasets, transforms
from imutils import build_montages

from network import Net

torch.set_grad_enabled(False)  # inference only, no gradients needed


def euclidean(a, b):
    return np.linalg.norm(a - b)


def search(query_feature, features_dict, top_k=100):
    results = []
    for i in range(0, len(features_dict["features"])):
        d = euclidean(query_feature, features_dict["features"][i])
        results.append((d, i))
    results = sorted(results)[:top_k]
    return results


train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_loader = Data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
feature_dict = pickle.loads(open("./data/features_dict.pickle", "rb").read())
model = Net()
model.load_state_dict(torch.load("./data/model.pth", map_location="cpu"))
model.eval()  # eval mode after loading the weights
encoder = model.encoder
os.makedirs("./result", exist_ok=True)  # cv2.imwrite fails silently if the directory is missing


for i, (img, lab) in enumerate(test_loader):
    query_feature = encoder(img)
    res = search(query_feature, feature_dict)
    images = []
    for (d, j) in res:
        image = train_dataset.data[j].numpy()  # .train_data is deprecated in torchvision
        image = np.dstack([image] * 3)  # replicate the grayscale channel three times for cv2
        images.append(image)
    query = img.squeeze(0).permute(1, 2, 0).numpy()  # (1, 28, 28) tensor -> (28, 28, 1) array
    query = np.dstack([query] * 3) * 255
    cv2.imwrite("./result/{}_query.jpg".format(i), query)
    montage = build_montages(images, (28, 28), (10, 10))[0]
    cv2.imwrite("./result/{}_query_results.jpg".format(i), montage)
    break
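
One remark on performance: search computes the 60000 distances one at a time in a Python loop, which gets slow as the database grows. A vectorized variant (a sketch assuming the features are stored as a single torch tensor, as in the script above; search_vectorized is not from the original post) does the same ranking in one NumPy call:

import numpy as np  # already imported in the script above

def search_vectorized(query_feature, features_dict, top_k=100):
    feats = features_dict["features"].numpy()      # (N, 16) stored features
    query = query_feature.numpy().reshape(1, -1)   # (1, 16) query feature
    dists = np.linalg.norm(feats - query, axis=1)  # all N distances at once
    order = np.argsort(dists)[:top_k]
    return [(float(dists[j]), int(j)) for j in order]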

Results

With the 16-dimensional encoding, the retrieved images look quite good.

ref

Adrian Rosebrock, "Autoencoders for Content-Based Image Retrieval with Keras and TensorFlow", PyImageSearch: autoencoders-for-content-based-image-retrieval-with-keras-and-tensorflow/