【深度学习进阶之路】---- GAN 代码复现之 DCLGAN

发布于:2023-01-21 ⋅ 阅读:(532) ⋅ 点赞:(0)

——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
论文题目:Dual Contrastive Learning for Unsupervised Image-to-Image Translation
论文地址:https://arxiv.org/abs/2104.07689
在这里插入图片描述
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——

关键词

深度学习;图像转换;代码复现;DCLGAN;Dual Contrasive Learning
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——

一、工作准备

1. 源码下载

Github源代码:https://github.com/JunlinHan/DCLGAN

选择 Code → Download ZIP 下载即可。
在这里插入图片描述
之后对文件进行解压,解压后包含如下文件:
在这里插入图片描述
☆☆☆----解压后在根目录中分别创建如下两个文件,分别是:
checkpoints文件(用于存放训练好的模型权重文件)
results文件(用于存放测试后结果)

2. 数据集下载

本文将利用作者提供的 maps 数据集进行模型训练和测试。
数据集下载链接:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
在这里插入图片描述
将数据集下载后放入datasets文件并解压,数据集内具体内容如下:
在这里插入图片描述
其中,A表示遥感影像,B表示对应的二维地图,如下:
在这里插入图片描述

3. 配置必要的环境

可以在environment.yml文件中查看具体所需的环境配置,如下:
在这里插入图片描述
其中,Python的版本选用3.6及以上版本都可以,本文所用的Python版本为3.9.12。
对于Conda用户,可以利用下列代码进行环境创建:

conda env create -f 文件存放路径/environment.yml

在这里插入图片描述
之后激活此环境

conda activate your_env_name (虚拟环境名称)

本篇Blog将调用远程服务器对代码进行调试和复现,关于本地Windows系统上Pycharm连接远程Linux服务器的相关操作可参考上一篇博客:【深度学习进阶之路】----Pycharm连接远程服务器进行代码调试与开发
**插播:**为了方便在Pycharm中实现远程服务器文件目录的可视化,可以在Pycharm中选择Tools–>Deployment–>Browse Remote Host
在这里插入图片描述
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——

二、训练数据集

1. 配置训练文件

① 在Pycharm中,点击Run–>Edit Configurations,进行配置训练文件。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
其中,--dataroot ./datasets/maps表示数据集存放位置;--name maps_DCL表示将在checkpoints文件夹中新建一个名为maps_DCL的文件夹,用于存放训练好的权重;--model dcl表示使用的模型为dcl。

2. 相关参数的修改

① 在train_options.py中修改训练epoch及学习率,注意epoch要修改两处,二者之和便为总的训练epoch。
在这里插入图片描述
② 在base_options.py中修改batch size及图片尺寸
在这里插入图片描述
当我们进行了上述操作后,便可美美哒运行train.py啦。
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
**插播:**当博主运行train.py时,出现以下两条警告信息,虽然不影响模型训练,但还是不想看到红色字体。
在这里插入图片描述

UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.

出现此条UserWarning,是因为pytorch不同版本进行更新迭代时引起的警告,某些参数被取代了,解决方案:

self.criterionSim = torch.nn.L1Loss('sum').to(self.device)
改为:
self.criterionSim = torch.nn.L1Loss(reduction='sum').to(self.device)
UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  "Argument interpolation should be of type InterpolationMode instead of int. "

出现此条UserWarning,是torchvision和pillow不兼容导致的,我的环境里torchvision=0.11.3 and pillow=6.1.0,即使我把pillow升级到8.3.1,依然有warning。那只能降低torchvision了,但是torchvision的版本号一般都是和pytorch绑定好的,我们需要不依赖torch来更改torchvison的版本,这可以通过以下指令实现:

self.criterionSim = torch.nn.L1Loss('sum').to(self.device)
改为:
self.criterionSim = torch.nn.L1Loss(reduction='sum').to(self.device)

——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
OK,准备就绪,代码已经可以完美的 run了。
在这里插入图片描述
大概经过 339s×200epoches≈18.83h 的训练,模型已基本训练完成,如下:
在这里插入图片描述
Oh Yeah, Process finished with exit code 0 !!!
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——

三、测试训练好的模型

训练好的模型,存放在根目录文件夹 checkpoints 中,如下:
在这里插入图片描述

1. 配置测试文件

之后配置测试文件(同“二、2”),便可运行test.py文件,如下:
在这里插入图片描述
其中,–dataroot ./datasets/maps表示数据集存放位置;–name maps_DCL表示将在results文件夹中新建一个名为maps_DCL的文件夹,用于存放测试结果。
在这里插入图片描述
Oh Yeah, Process finished with exit code 0 !!!

2. 结果展示

在results文件夹中,点击index.html即可在线查看模型测试结果,如下:
在这里插入图片描述
放几张图,效果嘛,自行体会吧~~~~~~
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——

参考

本篇博客在代码复现过程中,参考了以下几位大神的文章,在此拜谢。

  1. 使用CycleGAN训练自己制作的数据集,通俗教程,快速上手
  2. Cycle GAN(复现)—小白笔记
本文含有隐藏内容,请 开通VIP 后查看

网站公告

今日签到

点亮在社区的每一天
去签到