libtorch c++复现cycle gan网络
创始人
2024-05-09 11:37:15
0

目录

1. 原论文论文:https://arxiv.org/abs/1703.10593

2. 代码

2.1 下采样

2.2 残差块

 2.3 上采样模块

2.4 生成器代码

3. 判别器

3.1 判别器组件

 3. 2 判别器

 4. 训练

4.1 输入数据 

4.2 生成器loss函数结构图 

 4.3 判别器loss结构图


1. 原论文
论文:https://arxiv.org/abs/1703.10593

pytorch源码:GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

论文框架:

(1)输入领域A图片real_A,经过生成网络G_AB,生成领域B图片fake_B;

(2)fake_B再输入G_BA生成网络,生成real_A,即G_BA(G_AB(real_A)) = real_A;

(3)reconstructed image 和 输入图片real_A直接求loss,得到生成器损失;

(4)fake_B和real_B之间求生成器loss。

下面将结合代码,深入理解整个过程。

2. 代码

这里参考pytorch版本:GitHub - eriklindernoren/PyTorch-GAN: PyTorch implementations of Generative Adversarial Networks.

 实现libtorch版本。

其中,生成器G_AB 和 G_BA是同一个网络,框架细节如下。

 是一个先下采样,再接残差块,再上采样的全卷积网络。

2.1 下采样

 下采样模块是由conv2d+InstanceNorm2d+Relu组成,其中conv2d使其scale/2,channels/2.

// Down sampling : 通过conv2d进行两次下采样,同时double channels
class DownSampleImpl : public torch::nn::Module {
public:DownSampleImpl(int in_channels, int out_channels);torch::Tensor forward(torch::Tensor x);
private:torch::nn::Conv2d conv1{ nullptr };torch::nn::InstanceNorm2d bn1{ nullptr };torch::nn::ReLU relu1{ nullptr };
};
TORCH_MODULE(DownSample);
DownSampleImpl::DownSampleImpl(int in_channels, int out_channels) {conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3).stride(2).padding(1));bn1 = torch::nn::InstanceNorm2d(out_channels);relu1 = torch::nn::ReLU(true);register_module("generator downsample pad1", conv1);register_module("generator downsample bn1", bn1);register_module("generator downsample relu1", relu1);
}
torch::Tensor DownSampleImpl::forward(torch::Tensor x) {x = conv1(x);x = bn1(x);x = relu1(x);return x;
}

2.2 残差块

每个残差块由 conv2d+InstanceNorm2d+Relu,再接conv2d+InstanceNorm2d组成。

输入到残差块的特征图shape: (b,3,256,256);

输出特征图的shape: (b,3,256,256). 即不改变维度。

// two conv2d+bn+relu. keep feature scale.
class ResidualBlockImpl : public torch::nn::Module {
public:ResidualBlockImpl(int in_channels);torch::Tensor forward(torch::Tensor x);
private:torch::nn::ReflectionPad2d pad1{ nullptr };torch::nn::Conv2d conv1{ nullptr };torch::nn::InstanceNorm2d bn1{ nullptr };torch::nn::ReLU relu1{ nullptr };torch::nn::ReflectionPad2d pad2{ nullptr };torch::nn::Conv2d conv2{ nullptr };torch::nn::InstanceNorm2d bn2{ nullptr };
};
TORCH_MODULE(ResidualBlock);
ResidualBlockImpl::ResidualBlockImpl(int in_channels) {pad1 = torch::nn::ReflectionPad2d(1);conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, in_channels, 3));bn1 = torch::nn::InstanceNorm2d(in_channels);relu1 = torch::nn::ReLU(true);pad2 = torch::nn::ReflectionPad2d(1);conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, in_channels, 3));bn2 = torch::nn::InstanceNorm2d(in_channels);register_module("block pad1", pad1);register_module("block conv1", conv1);register_module("block bn1", bn1);register_module("block pad2", pad2);register_module("block conv2", conv2);register_module("block bn2", bn2);
}
torch::Tensor ResidualBlockImpl::forward(torch::Tensor x) {x = pad1(x);x = conv1(x);x = bn1(x);x = relu1(x);x = pad2(x);x = conv2(x);x = bn2(x);return x;
}

 2.3 上采样模块

上采样模块由UpSample+Conv2d+InstanceNorm2d+ReLU组成。

用到两次上采样模块,维度变化(b,256,64,64)->(b,128,128,128)->(b,64,256,256)

/// 
/// 两次上采样,(b,256,64,64)->(b,128,128,128)->(b,64,256,256)
/// 
class UpSampleBlockImpl : public torch::nn::Module {
public:UpSampleBlockImpl(int in_channels, int out_channels);torch::Tensor forward(torch::Tensor x);
private:torch::nn::Upsample up{ nullptr };torch::nn::Conv2d conv{ nullptr };torch::nn::InstanceNorm2d bn{ nullptr };torch::nn::ReLU relu{ nullptr };
};
TORCH_MODULE(UpSampleBlock);
UpSampleBlockImpl::UpSampleBlockImpl(int in_channels, int out_channels) {up = torch::nn::Upsample(upsample_options(std::vector({2, 2})));  conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3).padding(1));bn = torch::nn::InstanceNorm2d(out_channels);relu = torch::nn::ReLU(true);register_module("generator UpSampleBlock upsample", up);register_module("generator UpSampleBlock conv", conv);register_module("generator UpSampleBlock bn", bn);register_module("generator UpSampleBlock relu", relu);
}
torch::Tensor UpSampleBlockImpl::forward(torch::Tensor x) {x = up(x);x = conv(x);x = bn(x);x = relu(x);return x;
}

 最后再接一个conv2d,将通道数变成3即可输出生成的图像。

2.4 生成器代码

 

可以直接看forward函数,有5个步骤

(1)先是一个conv+bn+relu,预处理模块, size: (b,3,256,256) ->(b,64,256,256);

(2)然后是两次下采样,提取特征,size: (b,64,256,256) - > (b,128,128,128) -> (b,256,64,64);

(3)再接多个残差块,提取特征, size: (b,256,64,64) -> (b,256,64,64);

(4)上采样,size: (b,256,64,64)->(b,128,128,128)->(b,64,256,256);

(5)最后接一个输出层,即conv2d+bn+relu,size: (b,64,256,256) -> (b,3,256,256);

/// 
/// 下采样,res_blocks,上采样,output layer.
/// 
class GeneratorResNetImpl : public torch::nn::Module {
public:GeneratorResNetImpl(std::vector input_shape, int num_residual_blocks);torch::Tensor forward(torch::Tensor x);
private:torch::nn::Sequential _make_layer(int in_channels, int blocks);torch::nn::ReflectionPad2d pad1{ nullptr };torch::nn::Conv2d conv1{ nullptr };torch::nn::InstanceNorm2d bn1{ nullptr };torch::nn::ReLU relu1{ nullptr };// down DownSample down1{ nullptr };DownSample down2{ nullptr };// restorch::nn::Sequential res_blocks = torch::nn::Sequential();// upUpSampleBlock up1{ nullptr };UpSampleBlock up2{ nullptr };// output layertorch::nn::ReflectionPad2d pad2{ nullptr };torch::nn::Conv2d conv2{ nullptr };torch::nn::Tanh tanh2{ nullptr };
};
TORCH_MODULE(GeneratorResNet);torch::nn::Sequential GeneratorResNetImpl::_make_layer(int in_channels, int blocks)
{torch::nn::Sequential layers;for (int i = 0; i < blocks; i++) {layers->push_back(ResidualBlock(in_channels));}return layers;
}
GeneratorResNetImpl::GeneratorResNetImpl(std::vector input_shape, int num_residual_blocks)
{int channels = input_shape[0];  // 3int out_channels = 64;// 1, conv+bn+relu. (256+6-7+2*0)/1+1 = 256pad1 = torch::nn::ReflectionPad2d(channels);conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(channels, out_channels, 7));bn1 = torch::nn::InstanceNorm2d(out_channels);relu1 = torch::nn::ReLU(true);int in_channels = out_channels;// 2, Down sampling: 通过conv2d两次下采样,并且double channelsdown1 = DownSample(in_channels, out_channels*2);down2 = DownSample(out_channels * 2, out_channels*4);in_channels = out_channels * 4;  // 256 = 64*4// 3, Residual blocks: keep feature scale and channel unchange.res_blocks = _make_layer(in_channels, num_residual_blocks);  // (b,256,64,64)// 4, Up sampling: up+conv+bn+relu. halve channels and keep feature scale unchange.up1 = UpSampleBlock(in_channels, in_channels/2);  // (b,128,128,128)up2 = UpSampleBlock(in_channels / 2, in_channels / 4);  // (b,64,256,256)in_channels = in_channels / 4;  // 64// 5, output layer: pad+conv+tanh. change channels and keep feature scale unchange.pad2 = torch::nn::ReflectionPad2d(channels);  // 3conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, channels, 7));  // (b,64,256,256)->(b,3,256,256)tanh2 = torch::nn::Tanh();register_module("generator pad1", pad1);register_module("generator conv1", conv1);  // 一定要注册,不然不会使用cudaregister_module("generator bn1", bn1);register_module("generator relu1", relu1);register_module("generator down1", down1);register_module("generator down2", down2);register_module("generator res_blocks", res_blocks);register_module("generator up1", up1);register_module("generator up2", up2);register_module("generator pad2", pad2);register_module("generator conv2", conv2);register_module("generator tanh2", tanh2);
}
torch::Tensor GeneratorResNetImpl::forward(torch::Tensor x) {  // (b,3,256,256)// 1, conv+bn+relu. (256+6-7+2*0)/1+1 = 256x = pad1(x);x = conv1(x);x = bn1(x);x = relu1(x);   // (b,64,256,256)// 2, Down sampling: 通过conv2d两次下采样,并且double channelsx = down1(x);  // (b,128,128,128)x = down2(x);  // (b,256,64,64)// 3, Residual blocks: keep feature scale and channel unchange.x = res_blocks->forward(x);  // (b,256,64,64)// 4, Up sampling: up+conv+bn+relu. halve channels and keep feature scale unchange.x = up1(x);  // (b,128,128,128)x = up2(x);  // (b,64,256,256)// 5, output layer: pad+conv+tanh. change channels and keep feature scale unchange.x = pad2(x);x = conv2(x);x = tanh2(x);  // (b, 3, 256, 256)std::cout << x.sizes() << std::endl;return x;
}

3. 判别器

输入的是生成图图片(b,3,256,256),经过5次卷积,输出的是判别分数(b,1,16,16).

3.1 判别器组件

判别器组件是由conv2d+InstanceNorm2d+relu组成. 改变通道和scale.

/// 
/// Conv2d + bn + relu
/// 其中kernel_size设置成4,跟patchGan有关。
/// 
class DiscriminatorBlockImpl : public torch::nn::Module {
public:DiscriminatorBlockImpl(int in_channels, int out_channels, bool normalize = true);torch::Tensor forward(torch::Tensor x);
private:bool normalize = true;torch::nn::Conv2d conv{ nullptr };torch::nn::InstanceNorm2d bn{ nullptr };torch::nn::LeakyReLU relu{ nullptr };
};
TORCH_MODULE(DiscriminatorBlock);
DiscriminatorBlockImpl::DiscriminatorBlockImpl(int in_channels, int out_channels, bool normalize) {this->normalize = normalize;conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 4).stride(2).padding(1));if (normalize) bn = torch::nn::InstanceNorm2d(out_channels);relu = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));register_module("DiscriminatorBlock conv", conv);if (normalize) register_module("DiscriminatorBlock bn", bn);register_module("DiscriminatorBlock relu", relu);
}
torch::Tensor DiscriminatorBlockImpl::forward(torch::Tensor x) {x = conv(x);if (this->normalize)x = bn(x);x = relu(x);return x;
}

 3. 2 判别器

// (b,3,256,256)->(b,512,16,16)
torch::nn::Sequential _make_discriminator_blocks(int in_channels, int out_channels) {torch::nn::Sequential layers;layers->push_back(DiscriminatorBlock(in_channels, out_channels, false));layers->push_back(DiscriminatorBlock(out_channels, out_channels*2, true));layers->push_back(DiscriminatorBlock(out_channels * 2, out_channels * 4, true));layers->push_back(DiscriminatorBlock(out_channels * 4, out_channels * 8, true));return layers;
}
class DiscriminatorImpl : public torch::nn::Module {
public:DiscriminatorImpl(std::vector input_shape);torch::Tensor forward(torch::Tensor x);
public:std::vector output_shape_hw;//std::vector output_shape;
private:torch::nn::Sequential discriminator_blocks{ nullptr };torch::nn::ZeroPad2d pad{ nullptr };torch::nn::Conv2d conv{ nullptr };
};
TORCH_MODULE(Discriminator);
DiscriminatorImpl::DiscriminatorImpl(std::vector input_shape) {int channels = input_shape[0], height = input_shape[1], width = input_shape[2];// Calculate output shape of image discriminator (PatchGAN)this->output_shape_hw = { 1, height / int(pow(2,4)), width / int(pow(2,4)) };  // 外部调用,//this->output_shape = std::vector({ 1, height / int(pow(2,4)), width / int(pow(2,4)) });// 1, dis blocksdiscriminator_blocks = _make_discriminator_blocks(channels, 64);  // (b,512,16,16)// 2, zeropadpad = torch::nn::ZeroPad2d(torch::nn::ZeroPad2dOptions({ 1, 0, 1, 0 }));  // left,right,up,down// 3, convconv = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 1, 4).padding(1));register_module("Discriminator discriminator_blocks", discriminator_blocks);register_module("Discriminator pad", pad);register_module("Discriminator conv", conv);
}
torch::Tensor DiscriminatorImpl::forward(torch::Tensor x) {  // (b,3,256,256)x = discriminator_blocks->forward(x);  // (b,3,256,256)->(b,512,16,16)x = pad(x);  // (b,512,17,17)x = conv(x);  // (b,1,16,16)std::cout << x.sizes() << std::endl;return x;
}

 4. 训练

4.1 输入数据 

real_A和real_B分别是领域A和领域B图片,valid和fake分别是全1和全0矩阵。

real_A和real_B size: (b,3,256,256);

valid和fake size: (b,1,16,16).

// Set model input:
torch::Tensor real_A = batch.data.toType(torch::kF32).to(torch::kCUDA);  // (b,3,256,256)
torch::Tensor real_B = batch.target.toType(torch::kF32).to(torch::kCUDA);  // (b,3,256,256)
torch::Tensor valid = torch::ones({ real_A.size(0), D_A->output_shape_hw.at(0), D_A->output_shape_hw.at(1), D_A->output_shape_hw.at(2) }, torch::kF32).to(torch::kCUDA);  // (32,1,16,16). 
torch::Tensor fake = torch::zeros({ real_A.size(0), D_A->output_shape_hw.at(0), D_A->output_shape_hw.at(1), D_A->output_shape_hw.at(2) }, torch::kF32).to(torch::kCUDA);  // (32,1,16,16). 

4.2 生成器loss函数结构图 

/*
----------------------Train Generators
----------------------
*/
// 1, Identity loss: cycGan可加可不加,加上identity loss生成的效果更好。
// 生成器G用来生成y风格图像,那么把y送入G,应该仍然生成y,G(y) = y,只有这样才能保证具有生成y风格的能力。
// 如果不加该loss,那么生成器可能会自主地修改图像的色调,使得整体的颜色产生变化。
torch::Tensor loss_id_A = l1_loss_identity(G_BA(real_A), real_A);  // G_BA(A) = A, 保证生成的A接近A
torch::Tensor loss_id_B = l1_loss_identity(G_AB(real_B), real_B);  // G_AB(B) = B, 保证生成的B接近B
torch::Tensor loss_identity = (loss_id_A + loss_id_B) / 2;// 2, Gan loss: 让生成的图像更能称之为图像,也就是生成的图像更真实。但它不保证能生成到我们想要的图像。
torch::Tensor fake_B = G_AB(real_A);
torch::Tensor loss_GAN_AB = mse_loss_gan(D_B(fake_B), valid);  // 由A生成B, D_B分数越高越好,D_AB(G_AB(A)) = 1
torch::Tensor fake_A = G_BA(real_B);
torch::Tensor loss_GAN_BA = mse_loss_gan(D_A(fake_A), valid);  // # 由B生成A, D_A分数越高越好,D_BA(G_BA(B)) = 1
torch::Tensor loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2;// 3, Cycle loss: 保证生成器的输出图片与输入图片只是风格不同,而内容相同
torch::Tensor loss_cycle_A = l1_loss_cycle(G_BA(fake_B.detach()), real_A);  // G_BA(G_AB(A)) = A
torch::Tensor loss_cycle_B = l1_loss_cycle(G_AB(fake_A), real_B);  // G_BA(G_AB(A)) = A
torch::Tensor loss_cycle = (loss_cycle_A + loss_cycle_B) / 2;// total g loss: loss_gan + 10*loss_cycle + 5*loss_identity
torch::Tensor loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity;
loss_G.backward();

 4.3 判别器loss结构图


待续。。。

相关内容

热门资讯

安卓系统计划软件推荐,精选计划... 你有没有发现,手机里的安卓系统越来越智能了?这不,最近我可是挖到了一些超棒的安卓计划软件,它们不仅能...
收钱吧安卓系统插件,便捷支付新... 你有没有发现,现在的生活越来越离不开手机了?手机里装满了各种应用,而今天我要跟你聊聊一个特别实用的工...
鸿蒙系统是否还属于安卓,独立于... 你有没有想过,那个在我们手机上默默无闻的鸿蒙系统,它到底是不是安卓的“亲戚”呢?这个问题,估计不少手...
安卓系统手机用什么钱包,轻松管... 你有没有想过,你的安卓系统手机里装了那么多应用,但最离不开的,可能就是那个小小的钱包了。没错,就是那...
安卓系统能玩部落冲突吗,部落冲... 你有没有想过,安卓系统上的手机,是不是也能玩那款风靡全球的《部落冲突》呢?这款游戏自从推出以来,就吸...
智能机器人安卓系统,引领未来智... 你知道吗?在科技飞速发展的今天,智能机器人已经不再是科幻电影里的专属了。它们正悄悄地走进我们的生活,...
华为win10系统改装安卓系统... 你有没有想过,你的华为笔记本电脑里的Windows 10系统,能不能来个华丽变身,变成安卓系统呢?这...
旧电脑上安什么安卓系统,适配不... 你那台旧电脑是不是已经闲置好久了?别让它默默无闻地躺在角落里,给它来个华丽变身吧!今天,就让我来告诉...
安卓app语言跟随系统,随系统... 你知道吗?在手机世界里,有一个神奇的小功能,它就像你的贴身翻译官,无论你走到哪里,都能帮你轻松应对各...
惠城安卓系统降级在哪,揭秘降级... 你有没有遇到过手机系统升级后,发现新系统让你头疼不已,想回到那个熟悉的安卓系统呢?别急,今天就来告诉...
阿里云系统转安卓,揭秘安卓平台... 你知道吗?最近有个大动作在互联网圈里引起了不小的波澜,那就是阿里云系统竟然要转战安卓阵营了!这可不是...
安卓系统有最美壁纸么,探寻最美... 哦,亲爱的安卓用户,你是否曾在某个午后,百无聊赖地翻看着手机,突然被那一张张壁纸惊艳了眼眸?是的,我...
安卓系统采用Linux操作系统... 你知道吗?安卓系统,这个在我们手机上无处不在的小家伙,它的心脏竟然是Linux操作系统内核!是不是觉...
安卓原生平板通用系统,探索安卓... 你有没有发现,现在市面上平板电脑的品牌和型号真是五花八门,让人挑花了眼?不过,你知道吗?在众多安卓平...
小米1系统是安卓几,搭载安卓几... 你有没有想过,你的小米手机里那个熟悉的系统,其实是基于安卓的哦!没错,就是那个全球最流行的手机操作系...
可以安装安卓系统的相机,智能摄... 你有没有想过,一台相机不仅能拍出美美的照片,还能像智能手机一样,玩转各种应用?没错,现在市面上就有这...
安卓系统gps定位不准,安卓G... 你是不是也遇到过这种情况?手机里的安卓系统GPS定位总是不准,让人头疼不已。有时候,你明明就在家附近...
电信机顶盒装安卓系统,开启智能... 你有没有想过,家里的电信机顶盒其实也可以装上安卓系统呢?听起来是不是有点不可思议?别急,让我带你一步...
安卓系统可以做苹果桌面,打造个... 你知道吗?现在科技的发展真是让人眼花缭乱,竟然有人想出了安卓系统可以做苹果桌面的神奇想法!是不是觉得...
安卓系统自带的网页,功能与特色... 你有没有发现,每次打开安卓手机,那熟悉的系统界面里总有一个默默无闻的小家伙——安卓系统自带的网页浏览...