Content Loss and Style Loss

In vision tasks such as style transfer, super-resolution, and image inpainting, content loss (also known as perceptual loss) is used frequently. This post records it, together with style loss.

Content loss: extraction and computation

Content loss is easy to understand: it measures the difference in content between two images. In deep learning, a pretrained network serves as the feature extractor, and the loss is simply the MSE between the corresponding feature maps of the two images.

# Extract features with a pretrained VGG in TF; don't take the layers too deep
source_content = [vgg16_source.conv1_1,
                  vgg16_source.conv2_1]
target_content = [vgg16_target.conv1_1,
                  vgg16_target.conv2_1]

# MSE between corresponding feature maps of the two images
content_loss = tf.zeros(shape=1, dtype=tf.float32)
for source, target in zip(source_content, target_content):
    # feature maps are [N, H, W, C]
    content_loss += tf.reduce_mean(tf.square(source - target), axis=[1, 2, 3])
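Note that reducing over axis=[1, 2, 3] leaves a per-sample vector of shape [N]; reduce it once more, e.g. with tf.reduce_mean(content_loss), to get a scalar before passing it to an optimizer.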

Style loss: extraction and computation

Style is expressed with the Gram matrix of the features. The Gram matrix can be viewed as capturing the correlations between the responses of different filters, while discarding the spatial (content) information.
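Concretely, if a feature map of shape [H, W, C] is flattened into a matrix F of shape [H*W, C], the normalized Gram matrix is

    G = F^T F / (H*W*C),    i.e.    G[i][j] = sum_k F[k][i] * F[k][j] / (H*W*C)

so each entry G[i][j] measures how strongly channels i and j co-activate across all spatial positions, which is exactly why spatial layout (content) is discarded.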

# Extract with VGG in TF; deeper layers work better for style
source_style = [vgg16_source.conv4_3,
                vgg16_source.conv5_3]
target_style = [vgg16_target.conv4_3,
                vgg16_target.conv5_3]

def gram_matrix(x):
    # x: [N, H, W, C] feature map
    batch_size, h, w, c = x.get_shape().as_list()
    features = tf.reshape(x, shape=[-1, h * w, c])
    # [N, c, c]: channel-wise inner products over all spatial positions,
    # normalized by the feature map size (the original indexed features[0],
    # which only covered the first sample in the batch)
    gram = tf.matmul(features, features, transpose_a=True) / tf.constant(h * w * c, tf.float32)
    return gram

source_gram = [gram_matrix(feature) for feature in source_style]
target_gram = [gram_matrix(feature) for feature in target_style]

style_loss = tf.zeros(shape=1, dtype=tf.float32)
for source, target in zip(source_gram, target_gram):
    # gram matrices are [N, c, c]
    style_loss += tf.reduce_mean(tf.square(source - target), axis=[1, 2])
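The two losses are usually combined into a single objective with scalar weights, as in the original neural style transfer formulation. A minimal sketch; the weights alpha and beta here are hypothetical and need tuning per task:

# Hypothetical weights; style weight is typically much larger than content weight
alpha, beta = 1.0, 1e3
total_loss = tf.reduce_mean(alpha * content_loss + beta * style_loss)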

Recommended reference code: PyTorch
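As a minimal PyTorch sketch of the same idea (the layer indices into torchvision's vgg16.features and the choice of content/style layers are assumptions, not a fixed recipe):

import torch
import torch.nn.functional as F
from torchvision import models

# Frozen pretrained VGG16 as the feature extractor
vgg = models.vgg16(pretrained=True).features.eval()
for p in vgg.parameters():
    p.requires_grad_(False)

# Assumed indices into vgg16.features: shallow ReLUs for content,
# deeper ReLUs for style (roughly relu1_1/relu2_1 and relu4_3/relu5_3)
CONTENT_LAYERS = {1, 6}
STYLE_LAYERS = {22, 29}

def extract(x, layers):
    # Run x through the network, collecting the requested activations
    feats = []
    for i, layer in enumerate(vgg):
        x = layer(x)
        if i in layers:
            feats.append(x)
    return feats

def gram(x):
    # x: [N, C, H, W] -> normalized Gram matrix [N, C, C]
    n, c, h, w = x.shape
    f = x.reshape(n, c, h * w)
    return f @ f.transpose(1, 2) / (c * h * w)

def perceptual_and_style_loss(source, target):
    # Content: MSE between corresponding feature maps
    loss_c = sum(F.mse_loss(s, t) for s, t in
                 zip(extract(source, CONTENT_LAYERS), extract(target, CONTENT_LAYERS)))
    # Style: MSE between corresponding Gram matrices
    loss_s = sum(F.mse_loss(gram(s), gram(t)) for s, t in
                 zip(extract(source, STYLE_LAYERS), extract(target, STYLE_LAYERS)))
    return loss_c, loss_s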