基于内容的图像检索: pytorch
看了 Adrian Rosebrock 博士 (PyImageSearch) 的 Keras CBIR demo, 自己也动手用 PyTorch 写了一个.
CBIR
CBIR 即基于内容的图像检索 (Content-Based Image Retrieval), 用于在图像数据库中检索视觉上相似的图像. 其思路主要受文本检索启发: 先对图像提取特征, 检索时再根据提取的特征进行匹配. 特征可以使用传统的 SIFT 等特征描述子, 本文使用卷积自编码器 (autoencoder) 编码得到的特征 (注意: 这里是普通的自编码器, 并不是 VAE —— 网络中没有采样和 KL 散度项).
使用 MNIST 数据集做测试, 编码得到的 latent 向量为 16 维.
构建网络
构造简单的编码器-解码器
# network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicCBR(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels (also the BatchNorm width).
        kernel: convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: convolution stride, forwarded to ``nn.Conv2d``.
        padding: convolution padding, forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channel, out_channel, kernel, stride, padding):
        super().__init__()
        # Attribute names are part of the state_dict keys ("conv.*", "bn.*"),
        # so saved checkpoints depend on them.
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized)
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a flat latent vector.

    Two stride-2 Conv-BN-ReLU blocks (32 then 64 channels) are followed by a
    linear projection to ``latent_dim`` features.

    Args:
        img_shape: ``(height, width, channels)`` of the input images. Note the
            channels-last order of this tuple; ``forward`` itself takes NCHW
            tensors as usual for ``nn.Conv2d``.
        latent_dim: size of the output feature vector.
    """

    def __init__(self, img_shape=(28, 28, 1), latent_dim=16):
        super(Encoder, self).__init__()
        self.img_shape = img_shape
        self.latent_dim = latent_dim
        self.conv1 = BasicCBR(in_channel=self.img_shape[2], out_channel=32, kernel=3, stride=2, padding=3 // 2)
        self.conv2 = BasicCBR(32, 64, 3, 2, 3 // 2)
        # Derive the flattened size from img_shape instead of hard-coding 7*7*64:
        # a kernel-3, stride-2, padding-1 conv maps a side of length n to
        # (n - 1) // 2 + 1, so the default 28 -> 14 -> 7 still gives 7*7*64.
        height, width = img_shape[0], img_shape[1]
        for _ in range(2):
            height, width = (height - 1) // 2 + 1, (width - 1) // 2 + 1
        self.latent = nn.Linear(in_features=height * width * 64, out_features=latent_dim)

    def forward(self, x):
        """Encode an ``(N, C, H, W)`` batch into an ``(N, latent_dim)`` tensor."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = torch.flatten(x, 1)
        return self.latent(x)
class Decoder(nn.Module):
    """Decode a latent vector back into a single-channel image.

    The latent is projected linearly to a feature map described by
    ``in_feat_shape`` = ``(height, width, channels)``. It then goes through two
    Conv-BN-ReLU blocks, each followed by 2x upsampling, and is mapped to one
    output channel squashed by a sigmoid. With the defaults the output is
    ``(N, 1, 28, 28)`` with values in ``(0, 1)``.

    Args:
        in_feat_shape: ``(height, width, channels)`` of the projected map.
        latent: size of the incoming latent vector.
    """

    def __init__(self, in_feat_shape=(7, 7, 64), latent=16):
        super(Decoder, self).__init__()
        self.in_feat_shape = in_feat_shape
        rows, cols, chans = in_feat_shape[0], in_feat_shape[1], in_feat_shape[2]
        # Submodules are created in the same order as before so parameter
        # initialisation and state_dict layout are unchanged.
        self.linear = nn.Linear(latent, rows * cols * chans)
        self.conv1 = BasicCBR(chans, 64, 3, 1, 1)
        self.conv2 = BasicCBR(64, 32, 3, 1, 1)
        self.out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        rows, cols, chans = self.in_feat_shape[0], self.in_feat_shape[1], self.in_feat_shape[2]
        x = self.linear(x).view(-1, chans, rows, cols)
        for block in (self.conv1, self.conv2):
            x = F.interpolate(block(x), scale_factor=(2, 2))
        return torch.sigmoid(self.out(x))
class Net(nn.Module):
    """Convolutional autoencoder: an ``Encoder`` followed by a ``Decoder``.

    ``forward`` returns the reconstruction. Retrieval code uses
    ``self.encoder`` on its own to obtain the latent features.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.encoder = Encoder(img_shape=(28, 28, 1))
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.encoder(x))
训练
训练自编码器: 先用编码器对图片进行编码, 再用解码器把编码重构回图片, 以重构图与原图之间的 MSE 作为损失. 训练完成后, 检索阶段只需要用到编码器.
import torch
import torch.optim as optim
import torch.utils.data as Data
import time
import datetime
from torchvision import datasets, transforms
from network import Net
import os

BATCH_SIZE = 32
train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor())
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
gpu = torch.cuda.is_available()
model = Net()
num_epochs = 20
# len(train_loader) also counts a final partial batch, so the progress
# counters stay right even when the dataset size is not a multiple of BATCH_SIZE.
epoch_size = len(train_loader)
max_iter = epoch_size * num_epochs
# Reconstruction loss: the autoencoder is trained to reproduce its own input.
loss = torch.nn.MSELoss()
if gpu:
    model = model.cuda()
    loss = loss.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# Create the output directory before training, so that a missing ./data does
# not make torch.save fail only after every epoch has already run.
os.makedirs("./data", exist_ok=True)
iteration = 0
for epoch in range(num_epochs):
    model.train()
    for i, (img, _) in enumerate(train_loader):
        # Times the device copy plus one optimisation step; fetching the batch
        # from the DataLoader happens before this point and is not included.
        t0 = time.time()
        if gpu:
            img = img.cuda()
        optimizer.zero_grad()
        out = model(img)
        loss_v = loss(out, img)
        loss_v.backward()
        optimizer.step()
        batch_time = time.time() - t0
        # Rough ETA: assumes every remaining iteration takes as long as this one.
        eta = int(batch_time * (max_iter - iteration))
        # All counters are printed 1-based.
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || Loss: {:.4f} || Batchtime: {:.4f} s ||ETA: {}'.format(
            epoch + 1, num_epochs, i + 1, epoch_size, iteration + 1, max_iter, loss_v.item(), batch_time,
            str(datetime.timedelta(seconds=eta))))
        iteration += 1
torch.save(model.state_dict(), "./data/model.pth")
生成特征字典
用编码器对训练集图片逐一编码提取特征, 并保存为字典: "indexes" 为图片在训练集中的下标, "features" 为对应的编码特征 (每张图一行 16 维向量, 行号与下标一一对应).
from network import Net
import torch
import pickle
import torch.utils.data as Data
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor())
# shuffle=False keeps feature rows aligned with dataset indices.
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=32, shuffle=False)
model = Net()
# The encoder runs on the CPU below, so load the weights there directly.
model.load_state_dict(torch.load("./data/model.pth", map_location="cpu"))
# eval(): BatchNorm must use its learned running statistics. In train mode the
# features would depend on which images happen to share a batch, and every
# forward pass would overwrite the stored running statistics.
model.eval()
batches = []
# no_grad(): no autograd graph is kept for the whole training set, and the
# saved tensor carries no gradient history.
with torch.no_grad():
    for img, _ in train_loader:
        batches.append(model.encoder(img))
# Concatenate once at the end; calling torch.cat inside the loop copies the
# growing tensor on every batch (quadratic in the number of batches).
features = torch.cat(batches, 0)
# Row i of `features` belongs to train_dataset[i].
indexes = list(range(len(train_dataset)))
data = {"indexes": indexes, "features": features}
with open("./data/features_dict.pickle", "wb") as f:
    pickle.dump(data, f)
检索相似图片
加载测试图片并用编码器编码, 然后计算它与训练集每张图片特征之间的欧氏距离, 距离最小的 top_k (默认 100) 张图片即为检索结果.
import numpy as np
import pickle
import cv2
import torch
import torch.utils.data as Data
from torchvision import datasets, transforms
from imutils import build_montages
from network import Net
# Retrieval is inference only: turn autograd off globally so encoder outputs
# carry no gradient graph.
torch.autograd.set_grad_enabled(False)
def euclidean(a, b):
    """Return the Euclidean (L2) distance between ``a`` and ``b``.

    ``a - b`` is computed first, so broadcasting applies; a multi-dimensional
    difference is measured with ``numpy.linalg.norm``'s default (Frobenius) norm.
    """
    difference = a - b
    return np.linalg.norm(difference)
def search(query_feature, features_dict, top_k=100):
    """Return the ``top_k`` stored features closest to ``query_feature``.

    Args:
        query_feature: feature of one query image (torch tensor or array-like).
            It is flattened to a single vector before comparison.
        features_dict: dict whose ``"features"`` entry holds one feature vector
            per indexed image; row ``i`` belongs to image ``i``.
        top_k: maximum number of matches to return.

    Returns:
        A list of ``(distance, index)`` tuples sorted by ascending Euclidean
        distance, with ties broken by the smaller index.
    """
    feats = features_dict["features"]
    if len(feats) == 0:
        return []
    # Tensors that still require grad cannot be converted to numpy directly.
    if torch.is_tensor(feats):
        feats = feats.detach().cpu().numpy()
    if torch.is_tensor(query_feature):
        query_feature = query_feature.detach().cpu().numpy()
    feats = np.asarray(feats).reshape(len(feats), -1)
    query = np.asarray(query_feature).reshape(-1)
    # One vectorised pass replaces a Python-level norm per stored feature.
    dists = np.linalg.norm(feats - query, axis=1)
    # A stable argsort orders equal distances by index, like sorted((d, i)).
    order = np.argsort(dists, kind="stable")[:top_k]
    return [(dists[j], int(j)) for j in order]
import os

# Number of random test images to query (the demo only looks at the first one).
NUM_QUERIES = 1

train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
test_loader = Data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
# NOTE: pickle can execute arbitrary code when loading; only open feature files
# produced by the feature-extraction step, never untrusted ones.
with open("./data/features_dict.pickle", "rb") as f:
    feature_dict = pickle.load(f)
model = Net()
# The encoder runs on the CPU here, so load the weights there directly.
model.load_state_dict(torch.load("./data/model.pth", map_location="cpu"))
# eval(): BatchNorm uses its running statistics instead of per-batch ones.
model.eval()
encoder = model.encoder
# cv2.imwrite does not raise when the target directory is missing; it only
# returns False. Create the directory and check every write below.
os.makedirs("./result", exist_ok=True)
for i, (img, _) in enumerate(test_loader):
    query_feature = encoder(img)
    res = search(query_feature, feature_dict)
    # train_dataset.data holds the raw uint8 training images; stack each
    # grayscale image three times to get an H x W x 3 picture for the montage.
    images = [np.dstack([train_dataset.data[j].numpy()] * 3) for (_, j) in res]
    # img is (1, C, H, W) with values in [0, 1] (ToTensor); convert it to an
    # H x W x 3 uint8 image for cv2.
    query = img.squeeze(0).permute(1, 2, 0).numpy()
    query = np.clip(np.rint(np.dstack([query] * 3) * 255), 0, 255).astype(np.uint8)
    query_path = "./result/{}_query.jpg".format(i)
    if not cv2.imwrite(query_path, query):
        raise IOError("failed to write {}".format(query_path))
    montage = build_montages(images, (28, 28), (10, 10))[0]
    montage_path = "./result/{}_query_results.jpg".format(i)
    if not cv2.imwrite(montage_path, montage):
        raise IOError("failed to write {}".format(montage_path))
    if i + 1 >= NUM_QUERIES:
        break
结果展示
16维的编码结果, 效果还不错
ref
autoencoders-for-content-based-image-retrieval-with-keras-and-tensorflow/