Content Loss and Style Loss
Content loss, also called perceptual loss, is widely used in vision tasks such as style transfer, super-resolution, and image inpainting. This note records how to compute it, along with style loss.
Content Loss: Extraction and Computation
Content loss is easy to understand: it measures the difference between the content of two images. In deep learning, a pretrained network serves as a feature extractor, and the loss is simply the MSE between the corresponding feature maps of the two images.
```python
# TF1: extract features with a VGG network; for content, do not go too deep.
source_content = [vgg16_source.conv1_1,
                  vgg16_source.conv2_1]
target_content = [vgg16_target.conv1_1,
                  vgg16_target.conv2_1]

content_loss = tf.zeros(shape=[], dtype=tf.float32)
for source, target in zip(source_content, target_content):
    # Feature maps are [N, H, W, C]; MSE over all dimensions gives a scalar.
    content_loss += tf.reduce_mean(tf.square(source - target))
```
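The same computation can be sketched framework-free in NumPy; here the "feature maps" are random toy tensors standing in for VGG activations, just to show the MSE accumulation over a list of layers:

```python
import numpy as np

def content_loss(source_feats, target_feats):
    """Perceptual/content loss: mean squared error between corresponding
    feature maps, summed over layers (illustrative NumPy sketch)."""
    loss = 0.0
    for s, t in zip(source_feats, target_feats):
        loss += np.mean((s - t) ** 2)  # MSE over N, H, W, C
    return loss

# Toy "feature maps" in NHWC layout, two layers each.
rng = np.random.default_rng(0)
a = [rng.standard_normal((1, 8, 8, 4)) for _ in range(2)]
b = [rng.standard_normal((1, 8, 8, 4)) for _ in range(2)]

print(content_loss(a, a))      # identical features -> 0.0
print(content_loss(a, b) > 0)  # differing features -> True
```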
Style Loss: Extraction and Computation
Style is represented by the Gram matrix of the features. The Gram matrix can be seen as capturing the correlations between the responses of different filters, while discarding the spatial (content) information.
```python
# TF1: extract features with VGG; for style, deeper layers work better.
source_style = [vgg16_source.conv4_3,
                vgg16_source.conv5_3]
target_style = [vgg16_target.conv4_3,
                vgg16_target.conv5_3]

def gram_matrix(x):
    # x: [N, H, W, C] -> flatten spatial dims to [N, H*W, C]
    batch_size, h, w, c = x.get_shape().as_list()
    features = tf.reshape(x, shape=[batch_size, h * w, c])
    # Batched [N, C, H*W] @ [N, H*W, C] -> [N, C, C], normalized by h*w*c.
    # (The unbatched features[0] version would only use the first image.)
    gram = tf.matmul(features, features, transpose_a=True) / tf.constant(h * w * c, tf.float32)
    return gram

source_gram = [gram_matrix(feature) for feature in source_style]
target_gram = [gram_matrix(feature) for feature in target_style]

style_loss = tf.zeros(shape=[], dtype=tf.float32)
for source, target in zip(source_gram, target_gram):  # compare Gram to Gram, not to raw features
    style_loss += tf.reduce_mean(tf.square(source - target))
```