WGAN-GP
Training commands:
# CIFAR10 experiment
python wgangp_cifar10.py
# MNIST experiment
python wgangp_mnist.py
# Toy dataset experiment
python wgangp_toy.py
Evaluation commands:
# CIFAR10 experiment
python wgangp_cifar10.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_cifar10.pdparams # set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_cifar10.pdparams
# MNIST experiment
python wgangp_mnist.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_mnist.pdparams # set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_mnist.pdparams
# Toy dataset experiment
python wgangp_toy.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_toy_8gaussians.pdparams # set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_toy_8gaussians.pdparams
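The commands above pass settings as `key=value` overrides, where a dotted key such as `EVAL.pretrained_gen_model_path` addresses a nested configuration entry. The sketch below is not Hydra's implementation; `apply_overrides` is a hypothetical helper, and the starting configuration and local path are made up, purely to illustrate how dotted overrides map onto a nested dictionary.

```python
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' style overrides to a nested dict (illustration only)."""
    for item in overrides:
        dotted_key, value = item.split("=", 1)  # split on the first '=' only
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})  # descend, creating levels as needed
        node[leaf] = value
    return config


# Hypothetical starting configuration and command-line arguments.
cfg = {"mode": "train", "EVAL": {"pretrained_gen_model_path": None}}
args = ["mode=eval", "EVAL.pretrained_gen_model_path=./model_generator_cifar10.pdparams"]
cfg = apply_overrides(cfg, args)
```

After the call, `cfg["mode"]` holds `"eval"` and the nested generator path holds the string given on the command line.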
| Pretrained model | Metric |
|---|---|
| wgangp_cifar10_gen_pretrained.pdparams, wgangp_cifar10_dis_pretrained.pdparams | IS: 5.2 |
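1. Background
In digital image processing and machine learning, generative adversarial networks (GANs) have attracted wide attention for their image generation capability. However, conventional GAN architectures can be unstable during training, especially when generating high-resolution images or complex scenes. To address this, researchers proposed the Wasserstein GAN with gradient penalty (WGAN-GP), which both stabilizes training and improves the quality of the generated images.
WGAN-GP modifies the loss function to minimize the discrepancy between the real data distribution and the generated data distribution, and adds a gradient penalty to keep training smooth and stable. This mitigates the mode collapse common in conventional GANs and supports more efficient training and more realistic image generation.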
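2. Model Principle
WGAN-GP proposes an alternative to weight clipping: penalizing the norm of the critic's gradient with respect to its input. This stabilizes training for a variety of GAN architectures with almost no hyperparameter tuning.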
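2.1 Model Structure
In this case, WGAN-GP is an adversarial network consisting of a noise-to-image generator and a CNN discriminator. Its two components are:
- The Generator is a convolutional neural network.
- The Discriminator is a model built from convolutional blocks. It takes an image as input and outputs a realness score for that image.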
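2.2 Loss Function
The discriminator loss combines the Wasserstein loss with a gradient penalty:
\[L_D = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}\left[D(\tilde{x})\right] - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}\left[D(x)\right] + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]\]
where \(\mathbb{P}_g\) is the generator distribution, \(\mathbb{P}_r\) is the real data distribution, \(\mathbb{P}_{\hat{x}}\) is the distribution of samples interpolated between \(\mathbb{P}_g\) and \(\mathbb{P}_r\), and \(\lambda\) is the gradient penalty coefficient.
The generator loss is the adversarial loss:
\[L_G = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}\left[D(\tilde{x})\right]\]
where \(\mathbb{P}_g\) is the generator distribution.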
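To make the gradient penalty term concrete, the sketch below evaluates a WGAN-GP style critic loss for a hypothetical linear critic D(x) = w·x + b. Its input gradient is simply w, so no automatic differentiation is needed. The function names, the sample values, and λ = 10 are illustrative assumptions and are not taken from this case's implementation.

```python
import math
import random
from typing import List, Sequence

Vector = Sequence[float]


def critic(x: Vector, w: Vector, b: float) -> float:
    """Hypothetical linear critic D(x) = w . x + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def critic_input_gradient(x: Vector, w: Vector) -> List[float]:
    """Gradient of the linear critic with respect to its input: w for every x."""
    return list(w)


def critic_loss(reals, fakes, w, b, lam=10.0, seed=0) -> float:
    """E[D(fake)] - E[D(real)] + lam * E[(||grad D(x_hat)|| - 1)^2]."""
    rng = random.Random(seed)
    n = len(reals)
    wasserstein = (
        sum(critic(f, w, b) for f in fakes) - sum(critic(r, w, b) for r in reals)
    ) / n
    penalty = 0.0
    for real, fake in zip(reals, fakes):
        eps = rng.random()  # eps ~ U[0, 1]
        # x_hat lies on the segment between a real and a generated sample.
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]
        grad = critic_input_gradient(x_hat, w)
        grad_norm = math.sqrt(sum(g * g for g in grad))
        penalty += (grad_norm - 1.0) ** 2
    return wasserstein + lam * penalty / n


reals = [[1.0, 0.0], [0.0, 1.0]]
fakes = [[0.0, 0.0], [0.0, 0.0]]
loss_unit_gradient = critic_loss(reals, fakes, w=[0.6, 0.8], b=0.0)
loss_steep_gradient = critic_loss(reals, fakes, w=[3.0, 4.0], b=0.0)
```

With w = (0.6, 0.8) the gradient norm is 1 and the penalty vanishes; with w = (3, 4) the norm is 5 and the penalty dominates the loss. A neural critic would compute the gradient at each interpolated point with automatic differentiation instead.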
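3. Model Construction
The following explains how to implement WGAN-GP with the PaddleScience framework. Only the key steps are covered here; for other details, see the API documentation.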
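3.1 Datasets
The experiments use the CIFAR-10 dataset, MNIST, and toy datasets (swissroll/8gaussians/25gaussians).
CIFAR-10 contains 60000 32x32 color images in 10 classes, with 6000 images per class.
The dataset is distributed in three versions: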
| Version | Size | md5sum |
|---|---|---|
| CIFAR-100 python | 161 MB | eb9058c3a382ffc7106e4002c42a8d85 |
| CIFAR-100 Matlab | 175 MB | 6a4bfa1dcd5c9453dda6bb54194911f4 |
| CIFAR-100 binary | 161 MB | 03b5dce01913d631647c71ecec9e9cb8 |
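This implementation uses the python version.
The MNIST training set contains 60000 28x28 grayscale images of handwritten digits in 10 classes, with roughly 6000 images per class.
Toy datasets:
- Swissroll: a three-dimensional nonlinear manifold dataset with a continuously curled spiral structure.
- 8gaussians: a two-dimensional synthetic dataset of eight symmetrically placed Gaussian clusters whose centers are evenly spaced on a circle.
- 25gaussians: a dense Gaussian mixture dataset of 25 regularly arranged two-dimensional Gaussians with tight spacing between clusters.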
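The two Gaussian toy datasets can be sketched in a few lines. The code below is only an illustration of the descriptions above: the radius, grid spacing, standard deviation, and function names are assumptions, and this is not the case's own generation code.

```python
import math
import random
from typing import List, Tuple

Point = Tuple[float, float]


def ring_centers(num_modes: int = 8, radius: float = 2.0) -> List[Point]:
    """Centers spaced evenly on a circle (the 8gaussians layout)."""
    return [
        (
            radius * math.cos(2.0 * math.pi * k / num_modes),
            radius * math.sin(2.0 * math.pi * k / num_modes),
        )
        for k in range(num_modes)
    ]


def grid_centers(side: int = 5, spacing: float = 2.0) -> List[Point]:
    """side x side centers on a regular grid around the origin (25 for side=5)."""
    offset = (side - 1) / 2.0
    return [
        ((i - offset) * spacing, (j - offset) * spacing)
        for i in range(side)
        for j in range(side)
    ]


def sample_mixture(centers: List[Point], n: int, std: float = 0.05, seed: int = 0) -> List[Point]:
    """Draw n points: pick a center uniformly, then add isotropic Gaussian noise."""
    rng = random.Random(seed)
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points


eight = sample_mixture(ring_centers(), n=256)
twenty_five = sample_mixture(grid_centers(), n=256)
```

Each sample is a 2D point, so a batch can be fed to the toy discriminator as an array of shape (batch, 2).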
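3.2 Building the dataset API
Because the CIFAR-10 training data is split across 5 data files, PaddleScience's built-in dataset API cannot read it directly. All of the data is therefore read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.
CIFAR-10 is read by load_cifar10 from the functions module, where data_path is the path to the CIFAR-10 data.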
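As a rough picture of what reading the batch files involves: in the python version of CIFAR-10, each batch file is a pickled dictionary whose `data` entry holds one flat 3072-value row per image (1024 red, then 1024 green, then 1024 blue values) and whose `labels` entry holds the class indices. The helpers `load_batches` and `row_to_chw` below are hypothetical and are not the case's `load_cifar10`. The demo writes small synthetic stand-in files so it runs without the dataset; unpickling the real files additionally requires NumPy.

```python
import os
import pickle
import tempfile
from typing import List, Tuple


def load_batches(data_path: str, file_names: List[str]) -> Tuple[list, list]:
    """Concatenate the 'data' rows and 'labels' of several pickled batch files."""
    images, labels = [], []
    for name in file_names:
        with open(os.path.join(data_path, name), "rb") as f:
            batch = pickle.load(f, encoding="bytes")  # keys come back as bytes
        images.extend(batch[b"data"])
        labels.extend(batch[b"labels"])
    return images, labels


def row_to_chw(row, channels: int = 3, height: int = 32, width: int = 32):
    """Reshape one flat row (R plane, G plane, B plane) into [C][H][W]."""
    plane = height * width
    return [
        [list(row[c * plane + y * width : c * plane + (y + 1) * width]) for y in range(height)]
        for c in range(channels)
    ]


# Synthetic stand-in files: 5 batches with 2 images each, stored as plain lists.
demo_dir = tempfile.mkdtemp()
names = [f"data_batch_{i}" for i in range(1, 6)]
for i, name in enumerate(names):
    fake = {b"data": [[i] * 3072, [i + 10] * 3072], b"labels": [i, (i + 1) % 10]}
    with open(os.path.join(demo_dir, name), "wb") as f:
        pickle.dump(fake, f)

images, labels = load_batches(demo_dir, names)
first = row_to_chw(images[0])
```

Once all rows are in memory, they can be placed in a dictionary of arrays and handed to NamedArrayDataset as described above.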
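The dataloader configuration is the dataloader_cfg dictionary in the complete code in Section 4.
MNIST likewise cannot be used directly with PaddleScience's built-in dataset API, so all of the data is read first (load_mnist in the functions module) and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset. Its dataloader configuration follows the same pattern.
The toy datasets cannot be used directly with PaddleScience's built-in dataset API either, so all of the data is generated first (load_toy_data in the functions module) and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset. Their dataloader configuration follows the same pattern.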
3.3 Model Construction
WGAN-GP is not built into PaddleScience and has to be implemented separately, so this case defines the following custom models:
- WganGpCifar10Generator and WganGpCifar10Discriminator
- WganGpMnistGenerator and WganGpMnistDiscriminator
- WganGpToyGenerator and WganGpToyDiscriminator
The models are imported from the wgangp_cifar10_model, wgangp_mnist_model, and wgangp_toy_model modules, and their parameters are read from the MODEL.gen_net and MODEL.dis_net entries of each experiment's configuration.
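3.4 Custom Loss
The WGAN-GP loss is relatively complex and has to be implemented by hand. PaddleScience provides an API for custom loss functions, ppsci.loss.FunctionalLoss: define the loss function first, then pass the function itself as an argument to FunctionalLoss. Note that the inputs and outputs of a custom loss function must be dictionaries.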
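The dictionary convention can be illustrated without PaddleScience. In the sketch below, plain floats stand in for tensors, and the three-argument signature and the key names `critic_scores` and `loss_g` are assumptions chosen for illustration; consult the FunctionalLoss API documentation for the exact signature expected by your PaddleScience version.

```python
from typing import Dict, List, Optional


def generator_loss(
    output_dict: Dict[str, List[float]],
    label_dict: Optional[Dict[str, List[float]]] = None,
    weight_dict: Optional[Dict[str, List[float]]] = None,
) -> Dict[str, float]:
    """Adversarial generator loss -mean(D(G(z))), returned as a dictionary."""
    scores = output_dict["critic_scores"]  # hypothetical key holding D(G(z))
    return {"loss_g": -sum(scores) / len(scores)}


result = generator_loss({"critic_scores": [1.0, 3.0]})
```

The function object itself, rather than its return value, is what gets passed on, in the same way the complete code passes `generator_funcs.loss` to `ppsci.loss.FunctionalLoss`.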
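3.4.1 Generator Loss
The Cifar10_Generator loss consists of an adversarial loss and a classification loss. Each term has its own weight; a weight of 0 means that term is not added during training.
The MNIST_Generator loss consists of the adversarial loss only.
The Toy_Generator loss consists of the adversarial loss only.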
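The weighting rule for the Cifar10_Generator loss can be sketched as follows. The function names, the default weights, and the use of cross-entropy for the classification term are assumptions for illustration; plain Python numbers stand in for tensors.

```python
import math
from typing import Dict, List


def adversarial_loss(critic_scores: List[float]) -> float:
    """-mean(D(G(z)))"""
    return -sum(critic_scores) / len(critic_scores)


def classification_loss(class_probs: List[List[float]], labels: List[int]) -> float:
    """Mean cross-entropy between predicted class probabilities and target labels."""
    return -sum(math.log(p[y]) for p, y in zip(class_probs, labels)) / len(labels)


def cifar10_generator_loss(
    critic_scores: List[float],
    class_probs: List[List[float]],
    labels: List[int],
    adv_weight: float = 1.0,
    cls_weight: float = 0.1,
) -> Dict[str, float]:
    total = 0.0
    if adv_weight != 0:  # a weight of 0 drops the term entirely
        total += adv_weight * adversarial_loss(critic_scores)
    if cls_weight != 0:
        total += cls_weight * classification_loss(class_probs, labels)
    return {"loss_g": total}


scores = [2.0, 4.0]
probs = [[0.5, 0.5], [0.5, 0.5]]
targets = [0, 1]
adv_only = cifar10_generator_loss(scores, probs, targets, cls_weight=0.0)
both = cifar10_generator_loss(scores, probs, targets, cls_weight=1.0)
```

Setting `cls_weight=0.0` reduces the CIFAR-10 loss to the same adversarial-only form used by the MNIST and toy generators.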
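3.4.2 Discriminator Loss
The Cifar10_Discriminator loss consists of the Wasserstein loss, the gradient penalty, and a classification loss. Only the classification term has a weight parameter.
The MNIST_Discriminator loss consists of the Wasserstein loss and the gradient penalty.
The Toy_Discriminator loss consists of the Wasserstein loss and the gradient penalty.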
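3.5 Constraint Construction
All experiments build their constraints with ppsci.constraint.SupervisedConstraint. Each of the CIFAR-10, MNIST, and toy dataset experiments creates one constraint for the generator and one for the discriminator; the construction code is part of the train function in the complete code in Section 4.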
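3.6 Optimizer Construction
WGAN-GP uses the Adam optimizer, which can be built directly with ppsci.optimizer.Adam. Each of the CIFAR-10, MNIST, and toy dataset experiments creates one optimizer for the generator and one for the discriminator; see the train function in the complete code in Section 4.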
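For readers unfamiliar with Adam, the sketch below writes out a single update step in plain Python. It is not ppsci.optimizer.Adam; `adam_step` is a hypothetical helper, and the default `lr`, `beta1`, and `beta2` values are illustrative only, since the case reads its own values from the TRAIN.optimizer configuration.

```python
import math
from typing import List, Tuple


def adam_step(
    params: List[float],
    grads: List[float],
    m: List[float],
    v: List[float],
    t: int,
    lr: float = 1e-4,
    beta1: float = 0.5,
    beta2: float = 0.9,
    eps: float = 1e-8,
) -> Tuple[List[float], List[float], List[float]]:
    """One Adam update; t is the 1-based step count used for bias correction."""
    new_params, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = beta1 * mi + (1.0 - beta1) * g       # first-moment estimate
        vi = beta2 * vi + (1.0 - beta2) * g * g   # second-moment estimate
        m_hat = mi / (1.0 - beta1 ** t)           # bias-corrected moments
        v_hat = vi / (1.0 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_params, new_m, new_v


params, m, v = adam_step([1.0], [2.0], [0.0], [0.0], t=1, lr=0.1)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `lr` regardless of the gradient's magnitude.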
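3.7 Solver Construction
The constructed models, constraints, optimizers, and other parameters are passed to ppsci.solver.Solver. Each of the CIFAR-10, MNIST, and toy dataset experiments builds a generator solver and a discriminator solver; see the train function in the complete code in Section 4.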
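3.8 Model Training
In each of the CIFAR-10, MNIST, and toy dataset experiments, every outer epoch first trains the discriminator solver and then the generator solver, clearing the corresponding optimizer's gradients before each call. The generator and discriminator weights are saved afterwards. See the train function in the complete code in Section 4.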
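The alternation between discriminator and generator updates can be sketched with stub callables. The helper below is hypothetical; in particular, five critic steps per generator step is an assumed value for illustration, whereas the case controls the amount of work per phase through its epochs_dis, iters_per_epoch_dis, epochs_gen, and iters_per_epoch_gen settings.

```python
from typing import Callable, List


def alternate_training(
    outer_epochs: int,
    train_critic: Callable[[], None],
    train_generator: Callable[[], None],
    critic_steps: int = 5,
    generator_steps: int = 1,
) -> List[str]:
    """Run critic updates, then generator updates, once per outer epoch."""
    schedule = []
    for _ in range(outer_epochs):
        for _ in range(critic_steps):
            train_critic()
            schedule.append("D")
        for _ in range(generator_steps):
            train_generator()
            schedule.append("G")
    return schedule


counts = {"critic": 0, "generator": 0}


def train_critic() -> None:
    counts["critic"] += 1  # stand-in for one discriminator update


def train_generator() -> None:
    counts["generator"] += 1  # stand-in for one generator update


schedule = alternate_training(2, train_critic, train_generator)
```

Training the critic more often than the generator keeps the critic's estimate of the Wasserstein distance useful before each generator update.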
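3.9 Custom Metric
Only the CIFAR-10 experiment has an evaluation metric, the Inception Score; the MNIST and toy experiments have none. Because an empty metric raises an error, a placeholder metric, invalid_metric, is defined for them, so two metrics are implemented in total.
PaddleScience provides an API for custom metric functions, ppsci.metric.FunctionalMetric: define the metric function first, then pass the function itself as an argument to FunctionalMetric. Note that the inputs and outputs of a custom metric function must be dictionaries.
The Inception Score is implemented by the InceptionScore class and the placeholder by invalid_metric, both imported from the functions module.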
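The Inception Score itself is the exponential of the average KL divergence between each image's predicted class distribution p(y|x) and the marginal p(y). The sketch below computes it from class probabilities that are assumed to be given already; in practice they come from a pretrained Inception classifier, which is outside the scope of this illustration. The function is hypothetical and is not the case's InceptionScore class.

```python
import math
from typing import Sequence


def inception_score(probs: Sequence[Sequence[float]]) -> float:
    """exp( mean_x KL( p(y|x) || p(y) ) ) for per-image class probabilities."""
    n = len(probs)
    num_classes = len(probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(p[k] for p in probs) / n for k in range(num_classes)]
    kl_sum = 0.0
    for p in probs:
        for k in range(num_classes):
            if p[k] > 0:  # terms with p[k] == 0 contribute nothing
                kl_sum += p[k] * (math.log(p[k]) - math.log(marginal[k]))
    return math.exp(kl_sum / n)


uniform_preds = [[0.25, 0.25, 0.25, 0.25]] * 3
confident_preds = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]
score_low = inception_score(uniform_preds)
score_high = inception_score(confident_preds)
```

Uninformative predictions give the minimum score of 1, while confident predictions spread evenly over all classes give the maximum, which equals the number of classes.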
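3.10 Validator Construction
This case builds its validators with ppsci.validate.SupervisedValidator. The CIFAR-10 validator reports the Inception Score, while the MNIST and toy dataset validators use invalid_metric; see the evaluate function in the complete code in Section 4.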
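3.11 Model Evaluation
After the model, validator, and weight path are passed to ppsci.solver.Solver, evaluation is started with solver.eval(). This is the same for the CIFAR-10, MNIST, and toy dataset experiments.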
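3.12 Visualization
After evaluation, the results are visualized as images. The CIFAR-10 experiment saves generated images with show_save_image, the MNIST experiment saves generated and real images with show_mnist, and the toy dataset experiment calls generate_toy_image. The code is at the end of each evaluate function in the complete code in Section 4.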
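Saving a generated image usually means mapping the network output to 8-bit pixel values and writing them in an image format. The sketch below assumes a channel-first image with values in [-1, 1] and writes a binary PPM file using only the standard library; both the value range and the helpers `to_uint8` and `write_ppm` are assumptions for illustration and do not describe how show_save_image works.

```python
import os
import tempfile
from typing import List

Image = List[List[List[float]]]  # [channel][row][column]


def to_uint8(value: float) -> int:
    """Map a value from [-1, 1] to an integer in [0, 255], clamping out-of-range inputs."""
    return max(0, min(255, int(round((value + 1.0) * 127.5))))


def write_ppm(path: str, image: Image) -> None:
    """Write a 3-channel channel-first image as a binary PPM (P6) file."""
    height, width = len(image[0]), len(image[0][0])
    pixels = bytearray()
    for y in range(height):
        for x in range(width):
            for c in range(3):  # PPM stores pixels as interleaved R, G, B bytes
                pixels.append(to_uint8(image[c][y][x]))
    with open(path, "wb") as f:
        f.write(f"P6\n{width} {height}\n255\n".encode("ascii"))
        f.write(bytes(pixels))


# Synthetic 2x2 image: red channel at +1, green at 0, blue at -1.
fake_image = [
    [[1.0, 1.0], [1.0, 1.0]],
    [[0.0, 0.0], [0.0, 0.0]],
    [[-1.0, -1.0], [-1.0, -1.0]],
]
out_path = os.path.join(tempfile.mkdtemp(), "image0.ppm")
write_ppm(out_path, fake_image)
```

The resulting file can be opened by most image viewers; a real pipeline would typically use an imaging library and a compressed format such as PNG instead.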
4. Complete Code
CIFAR-10 experiment
import os
import platform

import hydra
import paddle
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import InceptionScore
from functions import load_cifar10
from functions import show_save_image
from omegaconf import DictConfig
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator

import ppsci
from ppsci.optimizer.lr_scheduler import Linear
from ppsci.utils import logger

os.environ["FLAGS_cudnn_deterministic"] = "1"


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    eval_inception_score = InceptionScore(**cfg["EVAL"]["inceptionscore"])

    # set data
    inputs, labels = load_cifar10(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        metric={
            "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            generator_model.eval()
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_image = generator_model(input_)["fake_data"]
                show_save_image(
                    fake_image[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

    optimizer_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_generator = optimizer_generator(generator_model)
    optimizer_discriminator = optimizer_discriminator(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
MNIST experiment
import os
import platform

import hydra
import paddle
from functions import MnistDisFuncs
from functions import MnistGenFuncs
from functions import invalid_metric
from functions import load_mnist
from functions import show_mnist
from omegaconf import DictConfig
from wgangp_mnist_model import WganGpMnistDiscriminator
from wgangp_mnist_model import WganGpMnistGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={
            "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_data = generator_model(input_)["fake_data"]
                show_mnist(
                    fake_data[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
                show_mnist(
                    input_["real_data"][0],
                    f"{cfg.output_dir}/image_real_{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = MnistDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_mnist.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
Toy dataset experiment
import os
import platform

import hydra
import paddle
from functions import ToyDisFuncs
from functions import ToyGenFuncs
from functions import generate_toy_image
from functions import invalid_metric
from functions import load_toy_data
from omegaconf import DictConfig
from wgangp_toy_model import WganGpToyDiscriminator
from wgangp_toy_model import WganGpToyGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        output_dir=cfg.output_dir,
    )

    # eval
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            input_, _, _ = next(iter(validator.data_loader))
            real_data = input_["real_data"]
            generate_toy_image(
                true_dist=real_data,
                discriminator=discriminator_model,
                path=os.path.join(cfg.output_dir, "image.png"),
            )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = ToyDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_toy.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()