Reproducing the CycleGAN network in C++ with libtorch
Author: Founder
2024-05-09 11:37:15

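Contents

1. Original paper: https://arxiv.org/abs/1703.10593

2. Code

2.1 Down-sampling

2.2 Residual block

2.3 Up-sampling module

2.4 Generator code

3. Discriminator

3.1 Discriminator block

3.2 Discriminator

4. Training

4.1 Input data

4.2 Generator loss structure

4.3 Discriminator loss structure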

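1. Original paper
Paper: https://arxiv.org/abs/1703.10593

PyTorch source: GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

Framework of the paper:

(1) An image real_A from domain A is passed through the generator G_AB to produce fake_B, an image in domain B;

(2) fake_B is then passed through the generator G_BA to reconstruct the input, so that G_BA(G_AB(real_A)) ≈ real_A;

(3) The L1 loss between the reconstructed image and the input real_A is the cycle-consistency part of the generator loss;

(4) The adversarial part of the generator loss comes from the discriminator D_B scoring fake_B: the generator is rewarded when D_B rates fake_B as real, while D_B itself is trained to tell real_B from fake_B. The two domains are unpaired, so fake_B is never compared with real_B pixel by pixel.

The same steps are applied in the opposite direction, starting from real_B.

The sections below walk through the whole process together with the code.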
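Steps (2) and (3) can be written as the cycle-consistency term of the CycleGAN paper. The notation is an assumption of this sketch: a and b stand for images drawn from domains A and B (real_A and real_B in the code), and the norm is the L1 norm.

```latex
\mathcal{L}_{\text{cyc}}(G_{AB}, G_{BA})
  = \mathbb{E}_{a \sim A}\big[\lVert G_{BA}(G_{AB}(a)) - a \rVert_1\big]
  + \mathbb{E}_{b \sim B}\big[\lVert G_{AB}(G_{BA}(b)) - b \rVert_1\big]
```

The training code in section 4.2 averages the two terms instead of summing them, which only rescales the term by a constant factor of 1/2.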

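2. Code

The libtorch version here follows the PyTorch implementation: GitHub - eriklindernoren/PyTorch-GAN: PyTorch implementations of Generative Adversarial Networks.

The generators G_AB and G_BA share the same network architecture (they are two separate instances, each with its own weights); the details are as follows.

The generator is a fully convolutional network: down-sampling first, then residual blocks, then up-sampling.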

2.1 Down-sampling

The down-sampling module consists of conv2d + InstanceNorm2d + ReLU; the stride-2 conv2d halves the spatial size (H and W) and doubles the number of channels.

#include <torch/torch.h>

// Down-sampling block: one stride-2 conv2d halves H and W.
// The generator uses two of these blocks and doubles the channels in each.
class DownSampleImpl : public torch::nn::Module {
public:
    DownSampleImpl(int in_channels, int out_channels);
    torch::Tensor forward(torch::Tensor x);

private:
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::InstanceNorm2d bn1{ nullptr };  // named bn1, but it is instance normalization
    torch::nn::ReLU relu1{ nullptr };
};
TORCH_MODULE(DownSample);

DownSampleImpl::DownSampleImpl(int in_channels, int out_channels) {
    // kernel 3, stride 2, padding 1: (H + 2*1 - 3)/2 + 1 = H/2 for even H
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3).stride(2).padding(1));
    bn1 = torch::nn::InstanceNorm2d(out_channels);
    relu1 = torch::nn::ReLU(true);
    register_module("generator downsample pad1", conv1);
    register_module("generator downsample bn1", bn1);
    register_module("generator downsample relu1", relu1);
}

torch::Tensor DownSampleImpl::forward(torch::Tensor x) {
    x = conv1(x);
    x = bn1(x);
    x = relu1(x);
    return x;
}
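The halving of the spatial size follows from the standard convolution output-size formula, (in + 2*padding - kernel) / stride + 1 with integer division. A minimal stdlib-only sketch is given here; the helper name conv_out_size is hypothetical and not part of libtorch, and dilation is assumed to be 1.

```cpp
#include <cassert>

// Output length of a convolution along one spatial axis (dilation = 1).
// Integer division reproduces the floor in the usual formula.
int conv_out_size(int in, int kernel, int stride, int padding) {
    assert(in > 0 && kernel > 0 && stride > 0 && padding >= 0);
    return (in + 2 * padding - kernel) / stride + 1;
}
```

With the settings of DownSample (kernel 3, stride 2, padding 1), conv_out_size(256, 3, 2, 1) gives 128 and conv_out_size(128, 3, 2, 1) gives 64, which matches the two down-sampling steps of the generator.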

2.2 Residual block

Each residual block consists of conv2d + InstanceNorm2d + ReLU followed by conv2d + InstanceNorm2d, and the result is added back onto the block input (skip connection).

Shape of the feature map entering the residual block: (b,256,64,64);

Shape of the output feature map: (b,256,64,64). The dimensions are unchanged.

// (pad + conv2d + InstanceNorm2d + ReLU) then (pad + conv2d + InstanceNorm2d), plus a skip connection.
// Keeps the spatial size and the channel count.
class ResidualBlockImpl : public torch::nn::Module {
public:
    ResidualBlockImpl(int in_channels);
    torch::Tensor forward(torch::Tensor x);

private:
    torch::nn::ReflectionPad2d pad1{ nullptr };
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::InstanceNorm2d bn1{ nullptr };
    torch::nn::ReLU relu1{ nullptr };
    torch::nn::ReflectionPad2d pad2{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::InstanceNorm2d bn2{ nullptr };
};
TORCH_MODULE(ResidualBlock);

ResidualBlockImpl::ResidualBlockImpl(int in_channels) {
    // reflection pad 1 + 3x3 conv without padding keeps H and W
    pad1 = torch::nn::ReflectionPad2d(1);
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, in_channels, 3));
    bn1 = torch::nn::InstanceNorm2d(in_channels);
    relu1 = torch::nn::ReLU(true);
    pad2 = torch::nn::ReflectionPad2d(1);
    conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, in_channels, 3));
    bn2 = torch::nn::InstanceNorm2d(in_channels);
    register_module("block pad1", pad1);
    register_module("block conv1", conv1);
    register_module("block bn1", bn1);
    register_module("block pad2", pad2);
    register_module("block conv2", conv2);
    register_module("block bn2", bn2);
}

torch::Tensor ResidualBlockImpl::forward(torch::Tensor x) {
    torch::Tensor identity = x;  // keep the input for the skip connection
    x = pad1(x);
    x = conv1(x);
    x = bn1(x);
    x = relu1(x);
    x = pad2(x);
    x = conv2(x);
    x = bn2(x);
    return identity + x;  // output = input + F(input)
}
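The residual block keeps the spatial size because each 3x3 convolution is preceded by a reflection padding of 1. As an illustration only, here is a one-dimensional version of reflection padding in plain C++; the function name reflect_pad_1d is hypothetical, and the 2-D layer applies the same rule along both H and W.

```cpp
#include <cassert>
#include <vector>

// Reflection padding along one axis: the border value itself is not repeated,
// so {1,2,3,4} padded by 1 becomes {2,1,2,3,4,3}.
std::vector<int> reflect_pad_1d(const std::vector<int>& x, int pad) {
    const int n = static_cast<int>(x.size());
    assert(pad >= 0 && pad < n);  // reflection needs pad < length
    std::vector<int> out;
    out.reserve(static_cast<std::size_t>(n + 2 * pad));
    for (int i = pad; i >= 1; --i) out.push_back(x[i]);           // left: x[pad], ..., x[1]
    out.insert(out.end(), x.begin(), x.end());                    // original values
    for (int i = 1; i <= pad; ++i) out.push_back(x[n - 1 - i]);   // right: x[n-2], ...
    return out;
}
```

Padding by 1 grows a length-n axis to n + 2, and a kernel of 3 with stride 1 and no extra padding shrinks it back to n, which is why the block's input and output shapes agree.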

2.3 Up-sampling module

The up-sampling module consists of Upsample + Conv2d + InstanceNorm2d + ReLU.

The module is used twice; the shape changes as (b,256,64,64)->(b,128,128,128)->(b,64,256,256)

///
/// Used twice: (b,256,64,64)->(b,128,128,128)->(b,64,256,256)
///
class UpSampleBlockImpl : public torch::nn::Module {
public:
    UpSampleBlockImpl(int in_channels, int out_channels);
    torch::Tensor forward(torch::Tensor x);

private:
    torch::nn::Upsample up{ nullptr };
    torch::nn::Conv2d conv{ nullptr };
    torch::nn::InstanceNorm2d bn{ nullptr };
    torch::nn::ReLU relu{ nullptr };
};
TORCH_MODULE(UpSampleBlock);

UpSampleBlockImpl::UpSampleBlockImpl(int in_channels, int out_channels) {
    // nearest-neighbor upsampling by a factor of 2 along H and W
    up = torch::nn::Upsample(torch::nn::UpsampleOptions().scale_factor(std::vector<double>({ 2, 2 })).mode(torch::kNearest));
    // 3x3 conv with padding 1 keeps H and W and changes the channel count
    conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3).padding(1));
    bn = torch::nn::InstanceNorm2d(out_channels);
    relu = torch::nn::ReLU(true);
    register_module("generator UpSampleBlock upsample", up);
    register_module("generator UpSampleBlock conv", conv);
    register_module("generator UpSampleBlock bn", bn);
    register_module("generator UpSampleBlock relu", relu);
}

torch::Tensor UpSampleBlockImpl::forward(torch::Tensor x) {
    x = up(x);
    x = conv(x);
    x = bn(x);
    x = relu(x);
    return x;
}
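To make the Upsample step concrete, the following stdlib-only sketch performs nearest-neighbor up-sampling on a small 2-D grid. It assumes an integer scale factor and nearest-neighbor mode; the names Grid and upsample_nearest are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Grid = std::vector<std::vector<int>>;

// Nearest-neighbor up-sampling by an integer factor:
// output pixel (y, x) copies input pixel (y / scale, x / scale).
Grid upsample_nearest(const Grid& in, int scale) {
    assert(scale >= 1 && !in.empty() && !in[0].empty());
    const std::size_t s = static_cast<std::size_t>(scale);
    const std::size_t h = in.size();
    const std::size_t w = in[0].size();
    Grid out(h * s, std::vector<int>(w * s));
    for (std::size_t y = 0; y < h * s; ++y) {
        for (std::size_t x = 0; x < w * s; ++x) {
            out[y][x] = in[y / s][x / s];
        }
    }
    return out;
}
```

Each input pixel becomes a 2x2 block when scale is 2, so a 64x64 feature map grows to 128x128; the convolution that follows in the block then changes only the channel count.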

Finally, one more conv2d reduces the channel count to 3 and outputs the generated image.

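2.4 Generator code

The forward function can be read directly; it has 5 steps:

(1) A pre-processing module of pad + conv + InstanceNorm + ReLU, size: (b,3,256,256) -> (b,64,256,256);

(2) Two down-sampling steps to extract features, size: (b,64,256,256) -> (b,128,128,128) -> (b,256,64,64);

(3) Several residual blocks to extract features, size: (b,256,64,64) -> (b,256,64,64);

(4) Two up-sampling steps, size: (b,256,64,64)->(b,128,128,128)->(b,64,256,256);

(5) An output layer of pad + conv2d + tanh, size: (b,64,256,256) -> (b,3,256,256);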
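The five size changes above can be checked with a small shape trace in plain C++. This is a sketch under stated assumptions: the names Shape, conv_dim and trace_generator are hypothetical, the kernel, stride and padding values are the ones used by the modules in sections 2.1 to 2.3, and no tensors are involved.

```cpp
#include <cassert>
#include <vector>

struct Shape { int c, h, w; };

// Convolution output length along one axis (dilation = 1).
int conv_dim(int in, int kernel, int stride, int padding) {
    assert(in + 2 * padding >= kernel && stride > 0);
    return (in + 2 * padding - kernel) / stride + 1;
}

// Returns the feature-map shape after each of the five generator stages
// (two entries each for the down- and up-sampling stages).
std::vector<Shape> trace_generator(Shape in, int base_channels = 64) {
    std::vector<Shape> shapes;
    // (1) reflection pad 3 + 7x7 conv, stride 1
    Shape x = Shape{base_channels, conv_dim(in.h + 6, 7, 1, 0), conv_dim(in.w + 6, 7, 1, 0)};
    shapes.push_back(x);
    // (2) two 3x3 convs with stride 2 and padding 1, doubling the channels
    for (int i = 0; i < 2; ++i) {
        x = Shape{x.c * 2, conv_dim(x.h, 3, 2, 1), conv_dim(x.w, 3, 2, 1)};
        shapes.push_back(x);
    }
    // (3) residual blocks leave the shape unchanged
    shapes.push_back(x);
    // (4) two x2 up-samplings, each followed by a 3x3 conv with padding 1, halving the channels
    for (int i = 0; i < 2; ++i) {
        x = Shape{x.c / 2, conv_dim(x.h * 2, 3, 1, 1), conv_dim(x.w * 2, 3, 1, 1)};
        shapes.push_back(x);
    }
    // (5) reflection pad 3 + 7x7 conv back to the input channel count
    x = Shape{in.c, conv_dim(x.h + 6, 7, 1, 0), conv_dim(x.w + 6, 7, 1, 0)};
    shapes.push_back(x);
    return shapes;
}
```

For a 3x256x256 input the trace has seven entries, with (256,64,64) at the bottleneck and (3,256,256) at the output, in line with the list above.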

///
/// Down-sampling, res_blocks, up-sampling, output layer.
///
class GeneratorResNetImpl : public torch::nn::Module {
public:
    GeneratorResNetImpl(std::vector<int> input_shape, int num_residual_blocks);
    torch::Tensor forward(torch::Tensor x);

private:
    torch::nn::Sequential _make_layer(int in_channels, int blocks);

    torch::nn::ReflectionPad2d pad1{ nullptr };
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::InstanceNorm2d bn1{ nullptr };
    torch::nn::ReLU relu1{ nullptr };
    // down
    DownSample down1{ nullptr };
    DownSample down2{ nullptr };
    // res
    torch::nn::Sequential res_blocks = torch::nn::Sequential();
    // up
    UpSampleBlock up1{ nullptr };
    UpSampleBlock up2{ nullptr };
    // output layer
    torch::nn::ReflectionPad2d pad2{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::Tanh tanh2{ nullptr };
};
TORCH_MODULE(GeneratorResNet);

torch::nn::Sequential GeneratorResNetImpl::_make_layer(int in_channels, int blocks) {
    torch::nn::Sequential layers;
    for (int i = 0; i < blocks; i++) {
        layers->push_back(ResidualBlock(in_channels));
    }
    return layers;
}

GeneratorResNetImpl::GeneratorResNetImpl(std::vector<int> input_shape, int num_residual_blocks) {
    int channels = input_shape[0];  // 3
    int out_channels = 64;

    // 1, pad+conv+InstanceNorm+relu. (256+6-7+2*0)/1+1 = 256
    pad1 = torch::nn::ReflectionPad2d(3);  // 7x7 kernel needs a padding of 3 to keep H and W
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(channels, out_channels, 7));
    bn1 = torch::nn::InstanceNorm2d(out_channels);
    relu1 = torch::nn::ReLU(true);
    int in_channels = out_channels;

    // 2, Down sampling: two stride-2 conv2d blocks, each halving H and W and doubling the channels
    down1 = DownSample(in_channels, out_channels * 2);
    down2 = DownSample(out_channels * 2, out_channels * 4);
    in_channels = out_channels * 4;  // 256 = 64*4

    // 3, Residual blocks: keep the feature size and channels unchanged.
    res_blocks = _make_layer(in_channels, num_residual_blocks);  // (b,256,64,64)

    // 4, Up sampling: up+conv+InstanceNorm+relu. Halves the channels and doubles H and W.
    up1 = UpSampleBlock(in_channels, in_channels / 2);  // (b,128,128,128)
    up2 = UpSampleBlock(in_channels / 2, in_channels / 4);  // (b,64,256,256)
    in_channels = in_channels / 4;  // 64

    // 5, output layer: pad+conv+tanh. Changes the channels and keeps H and W.
    pad2 = torch::nn::ReflectionPad2d(3);
    conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, channels, 7));  // (b,64,256,256)->(b,3,256,256)
    tanh2 = torch::nn::Tanh();

    register_module("generator pad1", pad1);
    register_module("generator conv1", conv1);  // must be registered, otherwise the layer is not moved to CUDA with the model
    register_module("generator bn1", bn1);
    register_module("generator relu1", relu1);
    register_module("generator down1", down1);
    register_module("generator down2", down2);
    register_module("generator res_blocks", res_blocks);
    register_module("generator up1", up1);
    register_module("generator up2", up2);
    register_module("generator pad2", pad2);
    register_module("generator conv2", conv2);
    register_module("generator tanh2", tanh2);
}

torch::Tensor GeneratorResNetImpl::forward(torch::Tensor x) {  // (b,3,256,256)
    // 1, pad+conv+InstanceNorm+relu. (256+6-7+2*0)/1+1 = 256
    x = pad1(x);
    x = conv1(x);
    x = bn1(x);
    x = relu1(x);  // (b,64,256,256)

    // 2, Down sampling: two stride-2 conv2d blocks, each halving H and W and doubling the channels
    x = down1(x);  // (b,128,128,128)
    x = down2(x);  // (b,256,64,64)

    // 3, Residual blocks: keep the feature size and channels unchanged.
    x = res_blocks->forward(x);  // (b,256,64,64)

    // 4, Up sampling: up+conv+InstanceNorm+relu. Halves the channels and doubles H and W.
    x = up1(x);  // (b,128,128,128)
    x = up2(x);  // (b,64,256,256)

    // 5, output layer: pad+conv+tanh. Changes the channels and keeps H and W.
    x = pad2(x);
    x = conv2(x);
    x = tanh2(x);  // (b, 3, 256, 256)
    std::cout << x.sizes() << std::endl;  // debug: print the output shape
    return x;
}

3. Discriminator

The input is an image of size (b,3,256,256), either real or generated; after 5 convolutions the output is a map of patch scores of size (b,1,16,16).

3.1 Discriminator block

Each discriminator block consists of conv2d + InstanceNorm2d + LeakyReLU (the first block skips the normalization) and changes both the channel count and the spatial size.

///
/// Conv2d + InstanceNorm2d (optional) + LeakyReLU
/// kernel_size is set to 4, which is related to PatchGAN.
///
class DiscriminatorBlockImpl : public torch::nn::Module {
public:
    DiscriminatorBlockImpl(int in_channels, int out_channels, bool normalize = true);
    torch::Tensor forward(torch::Tensor x);

private:
    bool normalize = true;
    torch::nn::Conv2d conv{ nullptr };
    torch::nn::InstanceNorm2d bn{ nullptr };  // only constructed when normalize is true
    torch::nn::LeakyReLU relu{ nullptr };
};
TORCH_MODULE(DiscriminatorBlock);

DiscriminatorBlockImpl::DiscriminatorBlockImpl(int in_channels, int out_channels, bool normalize) {
    this->normalize = normalize;
    // kernel 4, stride 2, padding 1: halves H and W
    conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 4).stride(2).padding(1));
    if (normalize) bn = torch::nn::InstanceNorm2d(out_channels);
    relu = torch::nn::LeakyReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true));
    register_module("DiscriminatorBlock conv", conv);
    if (normalize) register_module("DiscriminatorBlock bn", bn);
    register_module("DiscriminatorBlock relu", relu);
}

torch::Tensor DiscriminatorBlockImpl::forward(torch::Tensor x) {
    x = conv(x);
    if (this->normalize)
        x = bn(x);
    x = relu(x);
    return x;
}
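The link between the 4x4 kernels and PatchGAN is that every output score only depends on a limited patch of the input image. The patch size can be computed with the usual receptive-field recurrence, sketched below in plain C++; the names ConvLayer and receptive_field are hypothetical, and padding is ignored because it does not change the patch size.

```cpp
#include <cassert>
#include <vector>

struct ConvLayer { int kernel; int stride; };

// Receptive field of one output value, walking through the layers from input to output.
// "jump" is the distance, in input pixels, between neighbouring features of the current layer.
int receptive_field(const std::vector<ConvLayer>& layers) {
    int rf = 1;
    int jump = 1;
    for (const ConvLayer& layer : layers) {
        assert(layer.kernel >= 1 && layer.stride >= 1);
        rf += (layer.kernel - 1) * jump;
        jump *= layer.stride;
    }
    return rf;
}
```

For the layer stack used in this post (four 4x4 convolutions with stride 2 followed by one 4x4 convolution with stride 1) the function returns 94, so each score covers a 94x94 input patch. A stack of three stride-2 layers followed by two stride-1 layers gives 70, the patch size quoted for PatchGAN in the CycleGAN paper.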

3.2 Discriminator

// (b,3,256,256)->(b,512,16,16)
torch::nn::Sequential _make_discriminator_blocks(int in_channels, int out_channels) {
    torch::nn::Sequential layers;
    layers->push_back(DiscriminatorBlock(in_channels, out_channels, false));  // no normalization in the first block
    layers->push_back(DiscriminatorBlock(out_channels, out_channels * 2, true));
    layers->push_back(DiscriminatorBlock(out_channels * 2, out_channels * 4, true));
    layers->push_back(DiscriminatorBlock(out_channels * 4, out_channels * 8, true));
    return layers;
}

class DiscriminatorImpl : public torch::nn::Module {
public:
    DiscriminatorImpl(std::vector<int> input_shape);
    torch::Tensor forward(torch::Tensor x);

public:
    std::vector<int> output_shape_hw;  // {1, H/16, W/16}, read by the training code

private:
    torch::nn::Sequential discriminator_blocks{ nullptr };
    torch::nn::ZeroPad2d pad{ nullptr };
    torch::nn::Conv2d conv{ nullptr };
};
TORCH_MODULE(Discriminator);

DiscriminatorImpl::DiscriminatorImpl(std::vector<int> input_shape) {
    int channels = input_shape[0], height = input_shape[1], width = input_shape[2];

    // Calculate output shape of image discriminator (PatchGAN): four stride-2 blocks divide H and W by 2^4 = 16
    this->output_shape_hw = { 1, height / 16, width / 16 };  // used by the caller

    // 1, dis blocks
    discriminator_blocks = _make_discriminator_blocks(channels, 64);  // (b,512,16,16)
    // 2, zeropad
    pad = torch::nn::ZeroPad2d(torch::nn::ZeroPad2dOptions({ 1, 0, 1, 0 }));  // left, right, top, bottom
    // 3, conv
    conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 1, 4).padding(1));

    register_module("Discriminator discriminator_blocks", discriminator_blocks);
    register_module("Discriminator pad", pad);
    register_module("Discriminator conv", conv);
}

torch::Tensor DiscriminatorImpl::forward(torch::Tensor x) {  // (b,3,256,256)
    x = discriminator_blocks->forward(x);  // (b,3,256,256)->(b,512,16,16)
    x = pad(x);  // (b,512,17,17)
    x = conv(x);  // (17+2-4)/1+1 = 16, so (b,1,16,16)
    std::cout << x.sizes() << std::endl;  // debug: print the output shape
    return x;
}
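The (b,1,16,16) output and the H/16, W/16 rule stored in output_shape_hw can be verified without tensors. The sketch below traces only the shapes, using the kernel, stride and padding values from the discriminator code; the names PatchShape, DiscriminatorTrace, conv_len and trace_discriminator are hypothetical.

```cpp
#include <cassert>

struct PatchShape { int c, h, w; };

struct DiscriminatorTrace {
    PatchShape after_blocks;  // shape after the four DiscriminatorBlock layers
    PatchShape output;        // shape of the final score map
};

// Convolution output length along one axis (dilation = 1).
int conv_len(int in, int kernel, int stride, int padding) {
    assert(in + 2 * padding >= kernel && stride > 0);
    return (in + 2 * padding - kernel) / stride + 1;
}

DiscriminatorTrace trace_discriminator(int height, int width) {
    int c = 64, h = height, w = width;
    // four blocks: 4x4 conv, stride 2, padding 1; channels 64 -> 128 -> 256 -> 512
    for (int block = 0; block < 4; ++block) {
        if (block > 0) c *= 2;
        h = conv_len(h, 4, 2, 1);
        w = conv_len(w, 4, 2, 1);
    }
    DiscriminatorTrace trace;
    trace.after_blocks = PatchShape{c, h, w};
    // zero padding of one column on the left and one row on top
    h += 1;
    w += 1;
    // final 4x4 conv, stride 1, padding 1, one output channel
    trace.output = PatchShape{1, conv_len(h, 4, 1, 1), conv_len(w, 4, 1, 1)};
    return trace;
}
```

The asymmetric zero padding adds one row and one column, and the final convolution removes exactly one again, so the score map keeps the H/16 by W/16 size produced by the four blocks.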

4. Training

4.1 Input data

real_A and real_B are images from domain A and domain B; valid and fake are all-ones and all-zeros tensors used as the targets for the discriminator scores.

real_A and real_B size: (b,3,256,256);

valid and fake size: (b,1,16,16).

// Set model input:
torch::Tensor real_A = batch.data.toType(torch::kF32).to(torch::kCUDA);  // (b,3,256,256)
torch::Tensor real_B = batch.target.toType(torch::kF32).to(torch::kCUDA);  // (b,3,256,256)
torch::Tensor valid = torch::ones({ real_A.size(0), D_A->output_shape_hw.at(0), D_A->output_shape_hw.at(1), D_A->output_shape_hw.at(2) }, torch::kF32).to(torch::kCUDA);  // (b,1,16,16)
torch::Tensor fake = torch::zeros({ real_A.size(0), D_A->output_shape_hw.at(0), D_A->output_shape_hw.at(1), D_A->output_shape_hw.at(2) }, torch::kF32).to(torch::kCUDA);  // (b,1,16,16)

4.2 Generator loss structure

/*
----------------------
Train Generators
----------------------
*/
// 1, Identity loss: optional in CycleGAN, but adding it gives better results.
// A generator G produces y-style images, so feeding y into G should still yield y, G(y) = y; only then is G guaranteed to be able to produce the y style.
// Without this loss the generator is free to shift the color tone, which changes the overall colors of the image.
torch::Tensor loss_id_A = l1_loss_identity(G_BA(real_A), real_A);  // G_BA(A) = A, keeps the generated A close to A
torch::Tensor loss_id_B = l1_loss_identity(G_AB(real_B), real_B);  // G_AB(B) = B, keeps the generated B close to B
torch::Tensor loss_identity = (loss_id_A + loss_id_B) / 2;

// 2, GAN loss: makes the generated images look realistic, but does not by itself guarantee that they are the images we want.
torch::Tensor fake_B = G_AB(real_A);
torch::Tensor loss_GAN_AB = mse_loss_gan(D_B(fake_B), valid);  // A -> B: the higher the D_B score the better, D_B(G_AB(A)) = 1
torch::Tensor fake_A = G_BA(real_B);
torch::Tensor loss_GAN_BA = mse_loss_gan(D_A(fake_A), valid);  // B -> A: the higher the D_A score the better, D_A(G_BA(B)) = 1
torch::Tensor loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2;

// 3, Cycle loss: ensures the generator output differs from its input only in style while the content stays the same.
// fake_B and fake_A are not detached here, so the cycle loss trains both generators.
torch::Tensor loss_cycle_A = l1_loss_cycle(G_BA(fake_B), real_A);  // G_BA(G_AB(A)) = A
torch::Tensor loss_cycle_B = l1_loss_cycle(G_AB(fake_A), real_B);  // G_AB(G_BA(B)) = B
torch::Tensor loss_cycle = (loss_cycle_A + loss_cycle_B) / 2;

// total g loss: loss_gan + 10*loss_cycle + 5*loss_identity
torch::Tensor loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity;
loss_G.backward();
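How the three terms combine can be illustrated with plain vectors instead of tensors. The sketch below assumes mean reduction for both criteria and takes the weights 10 and 5 from the comment in the code above; the names l1_loss, mse_loss, GeneratorLossTerms and generator_loss are hypothetical stand-ins, not the functions used in the post.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Mean absolute error (used for the identity and cycle terms).
double l1_loss(const std::vector<double>& pred, const std::vector<double>& target) {
    assert(!pred.empty() && pred.size() == target.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) sum += std::fabs(pred[i] - target[i]);
    return sum / static_cast<double>(pred.size());
}

// Mean squared error (used for the GAN term, as in LSGAN).
double mse_loss(const std::vector<double>& pred, const std::vector<double>& target) {
    assert(!pred.empty() && pred.size() == target.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < pred.size(); ++i) {
        const double d = pred[i] - target[i];
        sum += d * d;
    }
    return sum / static_cast<double>(pred.size());
}

struct GeneratorLossTerms { double gan, cycle, identity; };

// Weighted sum of the three generator terms.
double generator_loss(const GeneratorLossTerms& t, double lambda_cyc = 10.0, double lambda_id = 5.0) {
    return t.gan + lambda_cyc * t.cycle + lambda_id * t.identity;
}
```

With these weights a cycle error of a given size counts ten times as much as the same amount of adversarial error, so content preservation dominates the generator update.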

4.3 Discriminator loss structure
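The post does not show the discriminator loss. As an assumption, mirroring the MSE criterion used for the generator's GAN term in section 4.2, a discriminator would be pushed towards 1 on real images and towards 0 on generated ones. The sketch below uses plain vectors of scores; the names mse_to_constant and discriminator_loss are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mean squared error between a list of scores and a constant target.
double mse_to_constant(const std::vector<double>& scores, double target) {
    assert(!scores.empty());
    double sum = 0.0;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        const double d = scores[i] - target;
        sum += d * d;
    }
    return sum / static_cast<double>(scores.size());
}

// Assumed discriminator objective: real scores towards 1 (valid), fake scores towards 0 (fake),
// averaged over the two halves.
double discriminator_loss(const std::vector<double>& scores_real,
                          const std::vector<double>& scores_fake) {
    const double loss_real = mse_to_constant(scores_real, 1.0);
    const double loss_fake = mse_to_constant(scores_fake, 0.0);
    return (loss_real + loss_fake) / 2.0;
}
```

In a tensor implementation the generated image would normally be detached before it is scored for this loss, so that the discriminator update does not propagate into the generator.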


To be continued...
