CycleGAN论文阅读总结及实现

本站访问次数:

在cyclegan之前,对于两个域的图像进行转化,比如图像风格转换,它们的训练集图像都是成对的.而cyclegan则解决了训练图像必须成对的问题。使生成器的学习过程比image2image更像是两个图像域之间图像“翻译”。

  • 训练图片说明 下图分别是成对图像训练集与非成对图像训练集例子,成对图像训练时需要一一对应。

  • cyclegan

cyclegan的网络设计思想本身不复杂。其中包含两个生成器,一个由图像域A生成图像域B,另一个由图像域B生成图像域A。两个判别器分别针对一个图像域。训练时,随机选一张域A的图像,由域A生成域B的图像,再将生成的图像由域B转换回域A,得到重构的输入图像,形成一个cycle。输入域B的图像过程与上相同。对每个生成器而言,误差即为重构误差(不采用identity loss的情况下)。

cyclegan损失函数由对抗损失与循环一致性损失构成(作者另外加上了identity loss)

对抗损失(判别器):

循环一致性损失(生成器重构误差):

总损失:

  • 网络实现。对于生成器,采用在imagenet上预训练的vgg16作为基础。
class CycleGan:

    def __init__(self, height, weight, channels=3):
        self.height = height
        self.weight = weight
        self.channels = channels
        self.img_shape = (self.height, self.weight, self.channels)

    def build_generator(self):
        # U-net like based on vgg16
        input_img = Input(name='input_img',
                          shape=(self.height,
                                 self.weight,
                                 self.channels),
                          dtype='float32')
        vgg16 = VGG16(input_tensor=input_img,
                      weights='imagenet',
                      include_top=False)
        vgg_pools = [vgg16.get_layer('block%d_pool' % i).output
                     for i in range(1, 6)]

        def decoder(layer_input, skip_input, channel, last_block=False):
            if not last_block:
                concat = Concatenate(axis=-1)([layer_input, skip_input])
                bn1 = InstanceNormalization()(concat)
            else:
                bn1 = InstanceNormalization()(layer_input)
            conv_1 = Conv2D(channel, 1,
                            activation='relu', padding='same')(bn1)
            bn2 = InstanceNormalization()(conv_1)
            conv_2 = Conv2D(channel, 3,
                            activation='relu', padding='same')(bn2)
            return conv_2

        d1 = decoder(UpSampling2D((2, 2))(vgg_pools[4]), vgg_pools[3], 256)
        d2 = decoder(UpSampling2D((2, 2))(d1), vgg_pools[2], 128)
        d3 = decoder(UpSampling2D((2, 2))(d2), vgg_pools[1], 64)
        d4 = decoder(UpSampling2D((2, 2))(d3), vgg_pools[0], 32)
        d5 = decoder(UpSampling2D((2, 2))(d4), None, 32, True)

        output = Conv2D(3, 3, activation='tanh', padding='same')(d5)
        model = Model(inputs=input_img, outputs=output)
        # model.summary()
        return model

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        image = Input(shape=self.img_shape)

        d1 = d_layer(image, 64, normalization=False)
        d2 = d_layer(d1, 128)
        d3 = d_layer(d2, 256)
        d4 = d_layer(d3, 512)

        patch_out = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        discriminator = Model(image, patch_out)
        optimizer = Adam(0.0002, 0.5)
        discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        # discriminator.summary()
        return discriminator

    def cycle_gan(self, gen_a2b, gen_b2a, dis_a, dis_b):
        image_a = Input(shape=self.img_shape)
        image_b = Input(shape=self.img_shape)

        fake_b = gen_a2b(image_a)
        fake_a = gen_b2a(image_b)

        reconstr_a = gen_b2a(fake_b)
        reconstr_b = gen_a2b(fake_a)

        img_a_identity = gen_b2a(image_a)
        img_b_identity = gen_a2b(image_b)

        dis_a.trainable = False
        dis_b.trainable = False

        patch_out_a = dis_a(fake_a)
        patch_out_b = dis_b(fake_b)

        cycle_model = Model(inputs=[image_a, image_b],
                            outputs=[patch_out_a, patch_out_b,
                                     reconstr_a, reconstr_b,
                                     img_a_identity, img_b_identity])
        optimizer = Adam(0.0002, 0.5)
        lambda_cycle = 10.0  # Cycle-consistency loss
        lambda_id = 0.1 * lambda_cycle  # Identity loss
        cycle_model.compile(loss=['mse', 'mse',
                                  'mae', 'mae',
                                  'mae', 'mae'],
                            loss_weights=[1, 1,
                                          lambda_cycle, lambda_cycle,
                                          lambda_id, lambda_id],
                            optimizer=optimizer)
        # cycle_model.summary()
        return cycle_model
  • 复现结果,只复现了maps与monet2photo