top-k|涵盖18+ SOTA GAN实现,这个图像生成领域的PyTorch库火了
机器之心报道
作者:杜伟、陈萍
GAN 自从被提出后 , 便迅速受到广泛关注 。 我们可以将 GAN 分为两类 , 一类是无条件下的生成;另一类是基于条件信息的生成 。 近日 , 来自韩国浦项科技大学的硕士生在 GitHub 上开源了一个项目 , 提供了条件 / 无条件图像生成的代表性生成对抗网络(GAN)的实现 。

文章图片
近日 , 机器之心在 GitHub 上看到了一个非常有意义的项目 PyTorch-StudioGAN , 它是一个 PyTorch 库 , 提供了条件 / 无条件图像生成的代表性生成对抗网络(GAN)的实现 。 据主页介绍 , 该项目旨在提供一个统一的现代 GAN 平台 , 这样机器学习领域的研究者可以快速地比较和分析新思路和新方法等 。
该项目的作者为韩国浦项科技大学的硕士生 , 他的研究兴趣主要包括深度学习、机器学习和计算机视觉 。

文章图片
项目地址:https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
具体而言 , 该项目具有以下几个显著特征:
- 提供了大量 PyTorch 框架的 GAN 实现;
- 基于 CIFAR 10、Tiny ImageNet 和 ImageNet 数据集的 GAN 基准;
- 相较原始实现的更好的性能和更低的内存消耗;
- 提供完全最新 PyTorch 环境的预训练模型;
- 支持多 GPU(DP、DDP 和多节点 DDP)、混合精度、同步批归一化、LARS、Tensorboard 可视化和其他分析方法 。
【top-k|涵盖18+ SOTA GAN实现,这个图像生成领域的PyTorch库火了】

文章图片
此外 , 有网友询问是否可以将该项目用于图像之外的其他领域 。 作者表示可以 , 即使无法使用一些稳定器(如 diffaug、ada 等) , 依然可以通过调整 dataLoader 来训练自己的模型 。

文章图片
18+ SOTA GAN 实现
如下图所示 , 项目作者提供了 18 + 个 SOTA GAN 的实现 , 包括 DCGAN、LSGAN、GGAN、WGAN-WC、WGAN-GP、WGAN-DRA、ACGAN、ProjGAN、SNGAN、SAGAN、BigGAN、BigGAN-Deep、CRGAN、ICRGAN、LOGAN、DiffAugGAN、ADAGAN、ContraGAN 和 FreezeD 。

文章图片
cBN:条件批归一化;AC:辅助分类器;PD:Projection 判别器;CL:对比学习 。
其中 , 需要注意以下几点:
- G/D_type 表示将标签信息注入生成器或判别式的方式;
- EMA 表示生成器中应用更新后的指数移动平均线;
- Tiny ImageNet 数据集上的实验使用的是 ResNet 架构而不是 CNN 。

文章图片
环境要求
- Anaconda
- Python >= 3.6
- 6.0.0 <= Pillow <= 7.0.0
- scipy == 1.1.0
- sklearn
- seaborn
- h5py
- tqdm
- torch >= 1.6.0
- torchvision >= 0.7.0
- tensorboard
- 5.4.0 <= gcc <= 7.4.0
- torchlars
conda env create -f environment.yml -n studiogan
在 docker 中还可以采用以下方式:
docker pull mgkang/studiogan:latest
以下是创建名字为「studioGAN」容器的命令 , 同样也可以使用端口号为 6006 来连接 tensoreboard 。
docker run -it --gpus all --shm-size 128g -p 6006:6006 --name studioGAN -v /home/USER:/root/code --workdir /root/code mgkang/studiogan:latest /bin/bash
使用方法
使用 GPU 0 的情况下 , 在 CONFIG_PATH 中对于模型的训练「-t」和评估「-e」进行了定义:
CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH
在使用 GPU (0, 1, 2, 3) 和 DataParallel 情况下 , 在 CONFIG_PATH 中对于模型的训练「-t」和评估「-e」进行了定义:
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH
在 python3 src/main.py 程序中查看可用选项 , 通过 Tensorboard 可以监控 IS、FID、F_beta、Authenticity Accuracies 以及最大奇异值:
~ PyTorch-StudioGAN/logs/RUN_NAME>>> tensorboard --logdir=./ --port PORT
可视化以及分析生成图像
StudioGAN 支持图像可视化、k 最近邻分析、线性差值以及频率分析 。 所有的结果保存在「./figures/RUN_NAME/*.png」中 。
图像可视化的代码和示例如下:
CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -iv -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

文章图片
k 最近邻分析 , 这里固定 K=7 , 第一列中是生成的图像:
CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -knn -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

文章图片
线性插值(仅适用于有条件的 Big ResNet 模型 )的代码和示例如下:
CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -itp -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

文章图片
参考链接:https://www.reddit.com/r/MachineLearning/comments/lu9gen/p_pytorch_gan_library_that_provides/
推荐阅读
- IT|丰田15款全新电动车正式发布 涵盖轿车、SUV、皮卡、超跑
- OriginOS|vivo新系统OriginOS体验!全新UI设计,12月10号开始升级,涵盖47款机型
- IT|宝马启动年内第三次大规模OTA升级 涵盖约200万辆车
- 存储器|中微公司:在3D NAND芯片制造环节,公司的电容性等离子体刻蚀设备可应用于64层和128层的量产,同时公司根据存储器厂商的需求正在开发新一代能够
- 功能|涵盖不同年龄段十大区域 首个小鼠大脑代谢物图谱发表
- 视点·观察|这份“公司作息表”火到被举报:涵盖1300+公司 作息精准到部门
- 公司|金雷股份:目前,公司主轴产品已涵盖1.5MW-8MW多种主流机型,且能够生产的最大兆瓦数不断提升
- 功能|直真科技:5G垂直行业专网运维支撑平台主要为电信运营商以及行业客户提供5G专网端到端的运维解决方案,内容涵盖5G专网开通、故障监控、性能预警
- the|华盛顿特区扩大对亚马逊的诉讼范围 涵盖更多商品
- Tencent|腾讯把这游戏搞成了18+ 但我赌他不会做18+的内容