
WGANGP

Note

  1. Before running, download CIFAR-10 and update data_path in wgangp_cifar10.yaml.
  2. Before running, download MNIST and update data_path in wgangp_mnist.yaml.
# CIFAR-10 experiment (training)
python wgangp_cifar10.py
# MNIST experiment (training)
python wgangp_mnist.py
# Toy dataset experiment (training)
python wgangp_toy.py
# CIFAR-10 experiment (evaluation)
# set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_cifar10.pdparams
python wgangp_cifar10.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_cifar10.pdparams
# MNIST experiment (evaluation)
# set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_mnist.pdparams
python wgangp_mnist.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_mnist.pdparams
# Toy dataset experiment (evaluation)
# set EVAL.pretrained_dis_model_path to the local path of the model downloaded from https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_discriminator_toy_8gaussians.pdparams
python wgangp_toy.py mode=eval EVAL.pretrained_gen_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/wgangp/model_generator_toy_8gaussians.pdparams
Pretrained model                          Metric
wgangp_cifar10_gen_pretrained.pdparams    IS: 5.2
wgangp_cifar10_dis_pretrained.pdparams

1. Background

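In digital image processing and machine learning, generative adversarial networks (GANs) have attracted wide attention for their strong image-generation ability. However, training a conventional GAN can be unstable, especially when generating high-resolution images or complex scenes. To address this, researchers proposed the Wasserstein GAN with gradient penalty (WGAN-GP), which both stabilizes training and noticeably improves the quality of the generated images.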

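WGAN-GP trains the discriminator (the critic) to estimate the Wasserstein distance between the real and generated data distributions and trains the generator to reduce it; a gradient penalty keeps the critic approximately 1-Lipschitz, which keeps training smooth and stable. This helps mitigate the mode collapse often seen with conventional GANs and leads to more efficient training and more realistic images.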

2. Model Principles

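WGAN-GP proposes an alternative to weight clipping: penalizing the norm of the gradient of the critic with respect to its input. This stabilizes the training of a wide variety of GAN architectures with almost no hyperparameter tuning.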

2.1 Model Structure

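The WGAN-GP in this example consists of a noise-to-image generator and a CNN discriminator. In the CIFAR-10 experiment it is a conditional adversarial network: both networks also take class labels as input. The overall structure of the model is shown below.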

    noise===>generator===>fake_image==
                                      ==>discriminator===>Wasserstein Loss+Gradient Penalty
                               image==
  • The generator is a convolutional neural network.

  • The discriminator is built from convolutional blocks. It takes an image as input and outputs a score of how real the image is.

2.2 Loss Functions

The discriminator loss combines the Wasserstein loss with a gradient penalty:

\[ L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right] \]

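where \(\mathbb{P}_g\) is the generator distribution, \(\mathbb{P}_r\) is the real-data distribution, and \(\mathbb{P}_{\hat{x}}\) is the distribution of points interpolated between samples of \(\mathbb{P}_r\) and \(\mathbb{P}_g\): \(\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}\) with \(x \sim \mathbb{P}_r\), \(\tilde{x} \sim \mathbb{P}_g\) and \(\epsilon \sim U[0, 1]\). \(\lambda\) is the gradient-penalty coefficient.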

The generator loss is the adversarial loss, i.e. the negated expected critic score of generated samples:

\[ L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x}) \]

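where \(\mathbb{P}_g\) is the generator distribution. The real-data term of \(L_d\) is omitted because it does not depend on the generator.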
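To make the two formulas concrete, here is a small NumPy sketch (not part of the PaddleScience implementation) that evaluates \(L_d\) and \(L_g\) for a hypothetical linear critic \(D(x) = w^\top x + b\). For a linear critic the input gradient is \(w\) at every point, so the penalty term reduces to \(\lambda(\|w\|_2 - 1)^2\), which makes the result easy to check by hand. The interpolation uses the same scheme as compute_gradient_penalty later in this document, \(\hat{x} = x + \alpha(\tilde{x} - x)\) with \(\alpha \sim U[0, 1]\); the function names are illustrative.

```python
import numpy as np


def linear_critic(x, w, b):
    # hypothetical critic D(x) = w^T x + b, applied row-wise to a batch
    return x @ w + b


def linear_critic_input_grad(x, w):
    # dD/dx = w for every sample, independent of x
    return np.tile(w, (x.shape[0], 1))


def wgan_gp_losses(real, fake, w, b, lam=10.0, rng=None):
    rng = np.random.default_rng(0) if rng is None else rng
    # one random mixing weight per sample, broadcast over features
    alpha = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))
    interpolates = real + alpha * (fake - real)
    grads = linear_critic_input_grad(interpolates, w)
    slopes = np.sqrt(np.sum(grads**2, axis=1))  # per-sample L2 norm
    gradient_penalty = lam * np.mean((slopes - 1.0) ** 2)
    d_fake = linear_critic(fake, w, b)
    d_real = linear_critic(real, w, b)
    loss_d = np.mean(d_fake) - np.mean(d_real) + gradient_penalty
    loss_g = -np.mean(d_fake)
    return loss_d, loss_g, gradient_penalty


rng = np.random.default_rng(42)
real = rng.normal(1.0, 0.1, size=(256, 2))
fake = rng.normal(0.0, 0.1, size=(256, 2))
w = np.array([0.6, 0.8])  # ||w||_2 = 1, so the penalty is (almost exactly) 0
loss_d, loss_g, gp = wgan_gp_losses(real, fake, w, b=0.0)
```

With a unit-norm \(w\) only the Wasserstein term remains, and it is negative because the critic scores real samples higher than fake ones; scaling \(w\) up raises the penalty quadratically.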

3. Implementation

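The following sections explain how to implement WGAN-GP with the PaddleScience framework. Only the key steps are covered; for the remaining details, please refer to the API documentation.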

3.1 Datasets

The experiments use the CIFAR-10 dataset, the MNIST dataset and toy datasets (swissroll/8gaussians/25gaussians).

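CIFAR-10 contains 60,000 32x32 color images in 10 classes, with 6,000 images per class; 50,000 of them form the training set and 10,000 the test set.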

CIFAR-10 is distributed in three versions:

Version            Size     md5sum
CIFAR-10 python    163 MB   c58f30108f718f92721af3b95e74349a
CIFAR-10 Matlab    175 MB   70270af85842c9e89bb428ec9976c926
CIFAR-10 binary    162 MB   c32a1d4ab5d03f1284b67883e8d87530

This implementation uses the CIFAR-10 python version.

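MNIST contains 28x28 grayscale images of handwritten digits in 10 classes: 60,000 training images and 10,000 test images. The commonly used pickled version read by load_mnist splits them into 50,000 training, 10,000 validation and 10,000 test images.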

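Toy datasets:

Swissroll: a 3D nonlinear manifold with a continuously curled spiral structure (the loader keeps only its x and z coordinates, so the samples form a 2D spiral);

8gaussians: a 2D synthetic dataset of eight symmetric Gaussian clusters whose centers are evenly spaced on a circle;

25gaussians: a dense mixture of 25 2D Gaussians arranged on a regular 5x5 grid with closely spaced clusters.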

3.2 Building the Dataset API

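Because CIFAR-10 is split across five training data files, it cannot be loaded directly with PaddleScience's built-in dataset APIs. Instead, all the data is read into memory first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The code for loading CIFAR-10 is: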

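import numpy as np


def load_cifar10(input_keys, label_keys, data_path):
    # unpickle (defined in functions.py) returns all training images and their class labels
    images, class_labels = unpickle(data_path)
    images = images.astype("float32")
    # scale pixel values from [0, 255] to [-1, 1)
    images = ((images / 256.0) - 0.5) * 2
    # dequantization: add uniform noise within one quantization step (2 / 256 = 1 / 128)
    noise = np.random.uniform(size=images.shape, low=0.0, high=1.0 / 128)
    images = (images + noise).astype("float32")
    class_labels = np.array(class_labels, dtype="int32")
    # the generator is conditioned on class labels, so they are the model inputs,
    # while the real images are passed to the loss functions as labels
    inputs = {input_keys[0]: class_labels}
    labels = {label_keys[0]: images}
    return inputs, labels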
Here data_path is the path to the CIFAR-10 dataset.
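load_cifar10 relies on an unpickle helper that is not listed in this document. As a hedged sketch, assuming data_path points at the extracted cifar-10-batches-py directory of the official python version (where each batch file is a pickled dictionary with data and labels entries), a helper with the same return values could concatenate the five training batches like this. The name unpickle_cifar10_train is hypothetical and may differ from the example's actual helper.

```python
import os
import pickle

import numpy as np


def unpickle_cifar10_train(data_path):
    """Hypothetical stand-in for the unpickle helper used by load_cifar10.

    Assumes data_path is the extracted cifar-10-batches-py directory that
    contains data_batch_1 ... data_batch_5.
    """
    images, labels = [], []
    for i in range(1, 6):
        with open(os.path.join(data_path, f"data_batch_{i}"), "rb") as f:
            # the official batches were pickled with Python 2, so keys come back as bytes
            batch = pickle.load(f, encoding="bytes")
        images.append(batch[b"data"])  # uint8 array of shape (10000, 3072)
        labels.extend(batch[b"labels"])  # list of 10000 ints in [0, 9]
    return np.concatenate(images, axis=0), labels
```

On the real dataset this yields a (50000, 3072) array, which matches the 50,000 training images described above.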

The dataloader configuration is:

inputs, labels = load_cifar10(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
        "label": labels,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}
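NamedArrayDataset is used here because it accepts dictionaries of arrays. The framework-free sketch below illustrates that idea with a hypothetical SimpleNamedArrayDataset and iterate_batches (not the ppsci classes or their actual behavior): each sample is a pair of dictionaries sliced from the same row of every array, and batch_size together with drop_last decides how many batches an epoch yields.

```python
import numpy as np


class SimpleNamedArrayDataset:
    """Hypothetical, simplified stand-in for a dict-of-arrays dataset."""

    def __init__(self, inputs, labels=None):
        self.inputs = inputs
        self.labels = labels or {}
        lengths = {len(v) for v in list(inputs.values()) + list(self.labels.values())}
        if len(lengths) != 1:
            raise ValueError("all arrays must have the same number of samples")
        self.num_samples = lengths.pop()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # one sample (or batch) = the idx-th row(s) of every input and label array
        return (
            {k: v[idx] for k, v in self.inputs.items()},
            {k: v[idx] for k, v in self.labels.items()},
        )


def iterate_batches(dataset, batch_size, drop_last=True, shuffle=True, seed=0):
    order = np.arange(len(dataset))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    # with drop_last, the final incomplete batch is skipped
    stop = len(order) - len(order) % batch_size if drop_last else len(order)
    for start in range(0, stop, batch_size):
        # fancy indexing gathers a whole batch from every array at once
        yield dataset[order[start : start + batch_size]]
```

Because inputs and labels are indexed with the same row indices, shuffling never breaks the pairing between a class label and its real image.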

MNIST cannot be loaded directly with PaddleScience's built-in dataset APIs either, so all the data is read first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The code for loading MNIST is:

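import gzip
import pickle


def load_mnist(
    data_path,
    input_keys,
):
    # the pickled file holds a (train, valid, test) tuple; each entry is (images, labels)
    with gzip.open(data_path, "rb") as f:
        train_data, _, _ = pickle.load(f, encoding="latin1")
    # only the training images are used; their labels are not needed here
    data, _ = train_data
    data = {input_keys[0]: data}
    return data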
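load_mnist expects data_path to be a gzip-compressed pickle holding a (train, valid, test) tuple whose entries are (images, labels) pairs. As a quick, hedged sanity check of that layout, the snippet below writes a tiny file with synthetic contents to a temporary directory and reads it back with the same access pattern as load_mnist; make_split, write_fake_mnist and read_train_images are illustrative names only.

```python
import gzip
import os
import pickle
import tempfile

import numpy as np


def make_split(n):
    # synthetic (images, labels) pair: flattened 28x28 images in [0, 1), integer labels
    return np.random.rand(n, 784).astype("float32"), np.zeros(n, dtype="int64")


def write_fake_mnist(path, n_train=8):
    # (train, valid, test) tuple, pickled and gzip-compressed
    with gzip.open(path, "wb") as f:
        pickle.dump((make_split(n_train), make_split(2), make_split(2)), f)


def read_train_images(path):
    # same access pattern as load_mnist: keep only the training images
    with gzip.open(path, "rb") as f:
        train_data, _, _ = pickle.load(f, encoding="latin1")
    images, _ = train_data
    return images


path = os.path.join(tempfile.mkdtemp(), "fake_mnist.pkl.gz")
write_fake_mnist(path)
images = read_train_images(path)
```

Running read_train_images on your own file is a cheap way to confirm that data_path in wgangp_mnist.yaml points at a file with this structure before starting a long training run.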

The dataloader configuration is:

inputs = load_mnist(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}

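The toy datasets cannot be loaded with the built-in dataset APIs either, so all samples are generated first and then wrapped with ppsci.data.dataset.array_dataset.NamedArrayDataset.

The code for generating the toy datasets is: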

import random

import numpy as np
from sklearn.datasets import make_swiss_roll


def load_toy_data(input_keys, mode):
    data = []
    if mode == "25gaussians":
        # 5x5 grid of Gaussians centered at (2x, 2y) for x, y in {-2, ..., 2}
        for i in range(100000 // 25):
            for x in range(-2, 3):
                for y in range(-2, 3):
                    point = np.random.randn(2) * 0.05
                    point[0] += 2 * x
                    point[1] += 2 * y
                    data.append(point)
        data = np.array(data, dtype="float32")
        np.random.shuffle(data)
        data /= 2.828  # stdev
    elif mode == "swissroll":
        # keep the x and z coordinates of the 3D swiss roll, giving a 2D spiral
        data = make_swiss_roll(n_samples=100000, noise=0.25)[0]
        data = data.astype("float32")[:, [0, 2]]
        data /= 7.5  # stdev plus a little
    elif mode == "8gaussians":
        scale = 2.0
        centers = [
            (1, 0),
            (-1, 0),
            (0, 1),
            (0, -1),
            (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
            (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
            (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
            (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        ]
        centers = [(scale * x, scale * y) for x, y in centers]
        data = []
        # 100000 // 8 samples in total, each drawn around a randomly chosen center
        for i in range(100000 // 8):
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            data.append(point)
        data = np.array(data, dtype="float32")
        data /= 1.414  # stdev
    else:
        raise ValueError(f"Unsupported toy dataset mode: {mode}")
    data = {input_keys[0]: data}
    return data
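The 8gaussians branch above builds its samples one point at a time in a Python loop. As an illustrative rewrite (not the example's code), the same distribution can be drawn in a vectorized way: the eight listed centers are exactly the unit vectors at angles 0°, 45°, ..., 315°, scaled by 2 and then normalized by 1.414. sample_8gaussians is a hypothetical name.

```python
import numpy as np


def sample_8gaussians(n_samples, scale=2.0, std=0.02, seed=None):
    """Vectorized NumPy variant of the 8gaussians branch (an illustrative rewrite)."""
    rng = np.random.default_rng(seed)
    # 8 unit vectors at angles 0, 45, ..., 315 degrees, scaled to radius `scale`
    angles = np.arange(8) * np.pi / 4
    centers = scale * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    # pick a center uniformly at random for every sample, then add Gaussian noise
    idx = rng.integers(0, 8, size=n_samples)
    points = centers[idx] + rng.normal(0.0, std, size=(n_samples, 2))
    return (points / 1.414).astype("float32")  # same normalization as load_toy_data
```

For example, sample_8gaussians(100000 // 8) produces the same number of points as the loop version, with every point close to a circle of radius 2 / 1.414.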

The dataloader configuration is:

inputs = load_toy_data(**cfg["DATA"])
dataloader_cfg = {
    "dataset": {
        "name": cfg["EVAL"]["dataset"]["name"],
        "input": inputs,
    },
    "sampler": {
        **cfg["TRAIN"]["sampler"],
    },
    "batch_size": cfg["TRAIN"]["batch_size"],
    "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
    "num_workers": cfg["TRAIN"]["num_workers"],
    "drop_last": cfg["TRAIN"]["drop_last"],
}

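3.3 Building the Models

The WGAN-GP networks used here are not built into PaddleScience and must be implemented separately, so we define the custom classes WganGpCifar10Generator, WganGpCifar10Discriminator, WganGpMnistGenerator, WganGpMnistDiscriminator, WganGpToyGenerator and WganGpToyDiscriminator.

The models are constructed as follows:

WganGpCifar10Generator and WganGpCifar10Discriminator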

generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])

WganGpMnistGenerator and WganGpMnistDiscriminator

generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])

WganGpToyGenerator and WganGpToyDiscriminator

generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])

The parameter configuration is:

WganGpCifar10Generator and WganGpCifar10Discriminator

MODEL:
  gen_net:
    input_keys: [ "labels" ]
    output_keys: [ "fake_data" ]
    dim: 128
    output_dim: 3072
    label_num: 10
    use_label: true
  dis_net:
    input_keys: [ "data", "labels" ]
    output_keys: [ "disc_fake", "disc_acgan" ]
    dim: 128
    label_num: 10
    use_label: true

WganGpMnistGenerator and WganGpMnistDiscriminator

MODEL:
  gen_net:
    output_keys: [ "fake_data" ]
    dim: 64
    output_dim: 784
  dis_net:
    input_keys: [ "data" ]
    output_keys: [ "score" ]
    dim: 64

WganGpToyGenerator and WganGpToyDiscriminator

MODEL:
  gen_net:
    output_keys: [ "fake_data" ]
    dim: 512
  dis_net:
    input_keys: [ "data" ]
    output_keys: [ "score" ]
    dim: 512

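3.4 Custom Loss

The WGAN-GP losses are fairly complex and need to be implemented by hand. PaddleScience provides ppsci.loss.FunctionalLoss for custom losses: define the loss function first, then pass the function to FunctionalLoss. Note that a custom loss function must take dictionaries as input and return a dictionary.

3.4.1 Generator Loss

The CIFAR-10 generator loss combines an adversarial loss with a classification (AC-GAN) loss. The adversarial term has an implicit weight of 1 and the classification term is weighted by acgan_scale_g; setting acgan_scale_g to 0 effectively removes the classification term from training.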

from typing import Dict

import paddle


class Cifar10GenFuncs:
    """
    Loss function for cifar10 generator
    Args:
        discriminator_model: discriminator model
        acgan_scale_g: scale of acgan loss for generator
    """

    def __init__(
        self,
        discriminator_model,
        acgan_scale_g=0.1,
    ):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale_g = acgan_scale_g
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        labels = output_dict["labels"]
        outputs = self.discriminator_model({"data": fake_image, "labels": labels})
        disc_fake, disc_fake_acgan = outputs["disc_fake"], outputs["disc_acgan"]
        # adversarial loss: -E[D(G(z))]
        gen_cost = -paddle.mean(disc_fake)
        if disc_fake_acgan is not None:
            # AC-GAN loss: generated images should be classified as their conditioning labels
            gen_acgan_cost = self.crossEntropyLoss(disc_fake_acgan, labels)
            gen_cost += self.acgan_scale_g * gen_acgan_cost
        return {"loss_g": gen_cost}
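The weighting in Cifar10GenFuncs.loss can be checked without Paddle. The NumPy sketch below mirrors that logic, using a hand-written softmax cross entropy in place of paddle.nn.CrossEntropyLoss (which, with its default settings, applies softmax to the logits and averages over the batch). The function names are hypothetical and the inputs are toy values rather than real discriminator outputs.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    # mean cross entropy between softmax(logits) and integer class labels
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])


def cifar10_gen_loss(disc_fake, disc_fake_acgan, labels, acgan_scale_g=0.1):
    # adversarial term: -E[D(G(z))]
    gen_cost = -np.mean(disc_fake)
    if disc_fake_acgan is not None:
        # AC-GAN term, weighted by acgan_scale_g
        gen_cost += acgan_scale_g * softmax_cross_entropy(disc_fake_acgan, labels)
    return {"loss_g": gen_cost}
```

For instance, with all-zero logits over 10 classes the cross entropy is log(10), so the loss is -mean(disc_fake) + 0.1 * log(10) with the default scale, and exactly -mean(disc_fake) when acgan_scale_g is 0.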

The MNIST generator loss contains only the adversarial loss.

class MnistGenFuncs:
    """
    Loss function for mnist generator
    Args
        discriminator_model: discriminator model
    """

    def __init__(self, discriminator_model):
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_data = output_dict["fake_data"]
        score = self.discriminator_model({"data": fake_data})["score"]
        gen_cost = -paddle.mean(score)
        return {"loss_g": gen_cost}

The toy generator loss contains only the adversarial loss.

class ToyGenFuncs:
    """
    Loss function for toy generator
    Args
        discriminator_model: discriminator model
    """

    def __init__(self, discriminator_model):
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, *args):
        fake_data = output_dict["fake_data"]
        outputs = self.discriminator_model({"data": fake_data})
        disc_fake = outputs["score"]
        gen_cost = -paddle.mean(disc_fake)
        return {"loss_g": gen_cost}

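3.4.2 Discriminator Loss

The CIFAR-10 discriminator loss combines the Wasserstein loss, the gradient penalty and a classification loss. Only the classification term has a configurable weight (acgan_scale); the gradient-penalty coefficient is fixed at 10 inside compute_gradient_penalty.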

class Cifar10DisFuncs:
    """
    Loss function for cifar10 discriminator
    Args
        discriminator_model: discriminator model
        acgan_scale: scale of acgan loss for discriminator

    """

    def __init__(self, discriminator_model, acgan_scale):
        self.crossEntropyLoss = paddle.nn.CrossEntropyLoss()
        self.acgan_scale = acgan_scale
        self.discriminator_model = discriminator_model

    def loss(self, output_dict: Dict, label_dict: Dict, *args):
        fake_image = output_dict["fake_data"]
        real_image = label_dict["real_data"]
        labels = output_dict["labels"]
        disc_fake = self.discriminator_model({"data": fake_image, "labels": labels})[
            "disc_fake"
        ]
        out = self.discriminator_model({"data": real_image, "labels": labels})
        disc_real, disc_real_acgan = out["disc_fake"], out["disc_acgan"]
        gradient_penalty = self.compute_gradient_penalty(real_image, fake_image, labels)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_wgan = disc_cost + gradient_penalty
        if disc_real_acgan is not None:
            disc_acgan_cost = self.crossEntropyLoss(disc_real_acgan, labels)
            disc_acgan = disc_acgan_cost.sum()
            disc_cost = disc_wgan + (self.acgan_scale * disc_acgan)
        else:
            disc_cost = disc_wgan
        return {"loss_d": disc_cost}

    def compute_gradient_penalty(self, real_data, fake_data, labels):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates, "labels": labels})[
                "disc_fake"
            ],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = 10 * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty
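compute_gradient_penalty takes the L2 norm of the gradient over axis=1, i.e. per sample over all features, which assumes the data are flattened to shape [batch, features] (as here, where CIFAR-10 images are 3072-dimensional vectors). The NumPy sketch below checks that per-sample slope against central finite differences for a hypothetical smooth critic D(x) = sum(tanh(x)); it illustrates the computation only and does not call the PaddleScience model.

```python
import numpy as np


def critic(x):
    # hypothetical smooth critic: one score per row
    return np.sum(np.tanh(x), axis=1)


def critic_input_grad(x):
    # analytic gradient of sum(tanh(x)) with respect to each input feature
    return 1.0 - np.tanh(x) ** 2


def finite_difference_grad(f, x, h=1e-5):
    grad = np.zeros_like(x)
    for j in range(x.shape[1]):
        step = np.zeros_like(x)
        step[:, j] = h
        # central difference for feature j, computed for all samples at once
        grad[:, j] = (f(x + step) - f(x - step)) / (2 * h)
    return grad


rng = np.random.default_rng(0)
real = rng.normal(size=(4, 6))
fake = rng.normal(size=(4, 6))
alpha = rng.uniform(size=(4, 1))  # one mixing weight per sample, broadcast over features
interpolates = real + alpha * (fake - real)
slopes = np.sqrt(np.sum(critic_input_grad(interpolates) ** 2, axis=1))
slopes_fd = np.sqrt(np.sum(finite_difference_grad(critic, interpolates) ** 2, axis=1))
gradient_penalty = 10 * np.mean((slopes - 1.0) ** 2)
```

If the data were kept as [batch, channels, height, width] tensors instead, summing over axis=1 alone would give one norm per pixel rather than per sample, so the inputs would need to be flattened first.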

The MNIST discriminator loss combines the Wasserstein loss with a gradient penalty weighted by lamda.

class MnistDisFuncs:
    """
    Loss function for mnist discriminator
    Args
        discriminator_model: discriminator model
        lamda: gradient penalty coefficient
    """

    def __init__(self, discriminator_model, lamda):
        self.discriminator_model = discriminator_model
        self.lamda = lamda

    def loss(self, output_dict: Dict, *args):
        real_data = output_dict["real_data"]
        fake_data = output_dict["fake_data"]
        disc_fake = self.discriminator_model({"data": fake_data})["score"]
        disc_real = self.discriminator_model({"data": real_data})["score"]
        gradient_penalty = self.compute_gradient_penalty(real_data, fake_data)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_cost = disc_cost + gradient_penalty
        loss = disc_cost
        return {"loss_d": loss}

    def compute_gradient_penalty(self, real_data, fake_data):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates})["score"],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = self.lamda * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty

The toy discriminator loss combines the Wasserstein loss with a gradient penalty weighted by lamda.

class ToyDisFuncs:
    """
    Loss function for toy discriminator
    Args
        discriminator_model: discriminator model
        lamda: gradient penalty coefficient
    """

    def __init__(self, discriminator_model, lamda):
        self.discriminator_model = discriminator_model
        self.lamda = lamda

    def loss(self, output_dict: Dict, *args):
        real_data = output_dict["real_data"]
        fake_data = output_dict["fake_data"]
        disc_fake = self.discriminator_model({"data": fake_data})["score"]
        disc_real = self.discriminator_model({"data": real_data})["score"]
        gradient_penalty = self.compute_gradient_penalty(real_data, fake_data)
        disc_cost = paddle.mean(disc_fake) - paddle.mean(disc_real)
        disc_cost = disc_cost + gradient_penalty
        loss = disc_cost
        return {"loss_d": loss}

    def compute_gradient_penalty(self, real_data, fake_data):
        differences = fake_data - real_data
        alpha = paddle.rand([fake_data.shape[0], 1])
        interpolates = real_data + (alpha * differences)
        gradients = paddle.grad(
            outputs=self.discriminator_model({"data": interpolates})["score"],
            inputs=interpolates,
            create_graph=True,
            retain_graph=False,
        )[0]
        slopes = paddle.sqrt(paddle.sum(paddle.square(gradients), axis=1))
        gradient_penalty = self.lamda * paddle.mean((slopes - 1.0) ** 2)
        return gradient_penalty

3.5 Building Constraints

All experiments build their constraints with ppsci.constraint.SupervisedConstraint.

The code is as follows:

CIFAR-10 experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

MNIST experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"real_data": lambda out: out["real_data"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

Toy dataset experiment

constraint_generator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    name="constraint_generator",
)
constraint_generator_dict = {constraint_generator.name: constraint_generator}

constraint_discriminator = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
    output_expr={"real_data": lambda out: out["real_data"]},
    name="constraint_discriminator",
)
constraint_discriminator_dict = {
    constraint_discriminator.name: constraint_discriminator
}

3.6 Building Optimizers

WGAN-GP uses the Adam optimizer, which can be built directly with ppsci.optimizer.Adam:

CIFAR-10 experiment

lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

optimizer_generator = ppsci.optimizer.Adam(
    learning_rate=lr_scheduler_generator,
    beta1=cfg["TRAIN"]["optimizer"]["beta1"],
    beta2=cfg["TRAIN"]["optimizer"]["beta2"],
)
optimizer_discriminator = ppsci.optimizer.Adam(
    learning_rate=lr_scheduler_discriminator,
    beta1=cfg["TRAIN"]["optimizer"]["beta1"],
    beta2=cfg["TRAIN"]["optimizer"]["beta2"],
)
optimizer_generator = optimizer_generator(generator_model)
optimizer_discriminator = optimizer_discriminator(discriminator_model)
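The CIFAR-10 experiment decays both learning rates with ppsci.optimizer.lr_scheduler.Linear, whose parameters come from the config file (not shown here). Assuming it behaves like a plain linear decay from an initial learning rate to a final one over a fixed number of steps (an assumption; check the ppsci API documentation for the exact semantics), the schedule can be sketched as follows. linear_decay_lr and the numbers in the usage line are illustrative only.

```python
def linear_decay_lr(step, total_steps, base_lr, end_lr=0.0):
    """Hypothetical linear schedule: base_lr at step 0, end_lr at and after total_steps."""
    # fraction of the schedule completed, capped at 1
    progress = min(step, total_steps) / total_steps
    return base_lr + (end_lr - base_lr) * progress


# illustrative values: a schedule over 100000 steps starting from 1e-4
schedule = [linear_decay_lr(s, 100000, 1e-4) for s in (0, 50000, 100000)]
```

Separate schedulers for the generator and the discriminator make sense because the two solvers run a different number of steps per outer epoch.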

MNIST experiment

optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
optimizer_generator = optimizer(generator_model)
optimizer_discriminator = optimizer(discriminator_model)

Toy dataset experiment

optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])

optimizer_generator = optimizer(generator_model)
optimizer_discriminator = optimizer(discriminator_model)

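3.7 Building Solvers

Pass the constructed model, constraints, optimizer and other parameters to ppsci.solver.Solver. Note that both solvers take generator_model as their model and load cfg.TRAIN.pretrained_gen_model_path: at every step the generator produces fake_data, and the loss functions call the discriminator internally. Which network is updated depends on the optimizer given to each solver, since optimizer_generator holds only the generator's parameters and optimizer_discriminator only the discriminator's.

CIFAR-10 experiment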

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

MNIST experiment

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

Toy dataset experiment

solver_generator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "generator"),
    constraint=constraint_generator_dict,
    optimizer=optimizer_generator,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)
solver_discriminator = ppsci.solver.Solver(
    model=generator_model,
    output_dir=os.path.join(cfg.output_dir, "discriminator"),
    constraint=constraint_discriminator_dict,
    optimizer=optimizer_discriminator,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
)

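3.8 Model Training

Training alternates between the two solvers: in each outer epoch the discriminator solver runs first, then the generator solver. Since both losses back-propagate through both networks, each optimizer's gradients are cleared before its solver runs, so that gradients accumulated during the other stage are not applied.

CIFAR-10 experiment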

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()

MNIST experiment

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()

Toy dataset experiment

for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    optimizer_discriminator.clear_grad()
    solver_discriminator.train()
    optimizer_generator.clear_grad()
    solver_generator.train()

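3.9 Custom Metrics

Only the CIFAR-10 experiment has an evaluation metric, the Inception Score; the MNIST and toy experiments have none. Because an empty metric raises an error, a placeholder metric is defined for them, so two metrics are implemented in total.

PaddleScience provides ppsci.metric.FunctionalMetric for custom metrics: define the metric function first, then pass the function to FunctionalMetric. Note that a custom metric function must take dictionaries as input and return a dictionary.

The Inception Score is implemented as follows: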

from typing import Dict

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.vision import transforms


class InceptionScore:
    """
    Inception Score
    Args:
        eps: epsilon to avoid log(0)
        splits: number of splits
        batch_size: batch size used when running Inception-v3
    """

    def __init__(self, eps=1e-16, splits=10, batch_size=64):
        self.inception_v3 = paddle.vision.inception_v3(pretrained=True)
        self.inception_v3.fc.bias.set_value(
            paddle.to_tensor(np.zeros(self.inception_v3.fc.bias.shape, dtype="float32"))
        )
        self.inception_v3.eval()
        self.eps = eps
        self.splits = splits
        self.softmax = paddle.nn.Softmax(axis=1)
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def inception_score(self, output_dict: Dict, label_dict, *args):
        with paddle.no_grad():
            images = output_dict["fake_data"]
            images = images.reshape((-1, 3, 32, 32))
            # map generator output from [-1, 1] to [0, 255.99]
            images = (images + 1.0) * (255.99 / 2)
            predict = []
            # run Inception-v3 batch by batch; the last batch may be smaller
            for start in range(0, images.shape[0], self.batch_size):
                image = images[start : start + self.batch_size]
                image = F.interpolate(image, size=(299, 299), mode="bilinear")
                image = image / 255
                image = self.transform(image)
                predict.append(self.inception_v3(image))
            predict = paddle.concat(predict, axis=0)
            predict = self.softmax(predict) + self.eps
            scores = []
            split_size = predict.shape[0] // self.splits
            for i in range(self.splits):
                part = predict[i * split_size : (i + 1) * split_size]
                # KL(p(y|x) || p(y)), with p(y) estimated on this split
                kl = part * (paddle.log(part) - paddle.log(paddle.mean(part, 0)))
                kl = paddle.mean(paddle.sum(kl, 1))
                scores.append(paddle.exp(kl))
            scores = paddle.to_tensor(scores)
            return {"inception_score": paddle.mean(scores)}
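The score computed above is \(\exp(\mathbb{E}_x[\mathrm{KL}(p(y|x)\,\|\,p(y))])\), where \(p(y|x)\) is the Inception-v3 softmax output for a generated image and \(p(y)\) is its average over one split. The NumPy sketch below reproduces only the split-and-average part, taking ready-made class probabilities instead of running Inception-v3; inception_score_from_probs is a hypothetical helper. It shows the two extremes: identical predictions for every image give a score of 1, while confident predictions spread evenly over K classes give a score of K.

```python
import numpy as np


def inception_score_from_probs(probs, splits=10, eps=1e-16):
    """IS = exp(E_x[KL(p(y|x) || p(y))]), averaged over `splits` chunks.

    probs: array of shape (N, num_classes) whose rows are softmax outputs.
    """
    probs = probs + eps
    split_size = probs.shape[0] // splits
    scores = []
    for i in range(splits):
        part = probs[i * split_size : (i + 1) * split_size]
        # marginal p(y) estimated on this chunk
        p_y = np.mean(part, axis=0, keepdims=True)
        kl = np.sum(part * (np.log(part) - np.log(p_y)), axis=1)
        scores.append(np.exp(np.mean(kl)))
    return float(np.mean(scores))
```

A higher score therefore requires both confident per-image predictions and diversity across images, which is why it is used to judge GAN samples.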

The invalid_metric placeholder is defined as follows:

def invalid_metric(*args, **kwargs):
    return {"invalid_metric": 0}

3.10 Building Validators

This example builds its validators with ppsci.validate.SupervisedValidator.

CIFAR-10 experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    output_expr={"labels": lambda out: out["labels"]},
    metric={
        "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
    },
    name="val",
)
validator_dict = {validator.name: validator}

MNIST experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    metric={
        "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
    },
    name="val",
)
validator_dict = {validator.name: validator}

Toy dataset experiment

validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
    metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
    name="val",
)
validator_dict = {validator.name: validator}

3.11 Model Evaluation

After passing the model, validators and weight path to ppsci.solver.Solver, start evaluation with solver.eval().

CIFAR-10 experiment

solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# evaluation
solver.eval()

MNIST experiment

# initialize solver
solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# evaluation
solver.eval()

Toy dataset experiment

solver = ppsci.solver.Solver(
    model=generator_model,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
    output_dir=cfg.output_dir,
)

# eval
solver.eval()

3.12 Visualization

After evaluation, the results are visualized as images:

CIFAR-10 experiment

if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        generator_model.eval()
        for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.batch:
                break
            fake_image = generator_model(input_)["fake_data"]
            show_save_image(
                fake_image[0],
                f"{cfg.output_dir}/image{batch_idx}.png",
            )
    print(f"The visualizations are saved to {cfg.output_dir}")
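show_save_image comes from the example's functions module and is not listed here. Assuming it receives one flattened generator output of length 3072 in [-1, 1] (the range implied by load_cifar10's scaling and undone the same way in InceptionScore), a hypothetical conversion to a displayable 32x32 RGB array could look like this. The CHW reshape matches the CIFAR-10 python layout, where the first 1024 values are the red channel, then green, then blue.

```python
import numpy as np


def to_uint8_image(flat_image):
    """Hypothetical helper: flattened CHW vector in [-1, 1] -> (32, 32, 3) uint8 array."""
    chw = np.asarray(flat_image, dtype="float32").reshape(3, 32, 32)
    # undo the [-1, 1] scaling used for training, then clip to valid pixel values
    pixels = np.clip((chw + 1.0) * (255.99 / 2), 0, 255).astype("uint8")
    # image libraries expect height x width x channels
    return np.transpose(pixels, (1, 2, 0))
```

The resulting array could then be written to disk with an image library, for example PIL.Image.fromarray(to_uint8_image(x)).save(path).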

MNIST experiment

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.batch:
                break
            fake_data = generator_model(input_)["fake_data"]
            show_mnist(
                fake_data[0],
                f"{cfg.output_dir}/image{batch_idx}.png",
            )
            show_mnist(
                input_["real_data"][0],
                f"{cfg.output_dir}/image_real_{batch_idx}.png",
            )
    print(f"The visualizations are saved to {cfg.output_dir}")

Toy dataset experiment

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        input_, _, _ = next(iter(validator.data_loader))
        real_data = input_["real_data"]
        generate_toy_image(
            true_dist=real_data,
            discriminator=discriminator_model,
            path=os.path.join(cfg.output_dir, "image.png"),
        )
    print(f"The visualizations are saved to {cfg.output_dir}")

4. Full Code

CIFAR-10 experiment

import os
import platform

import hydra
import paddle
from functions import Cifar10DisFuncs
from functions import Cifar10GenFuncs
from functions import InceptionScore
from functions import load_cifar10
from functions import show_save_image
from omegaconf import DictConfig
from wgangp_cifar10_model import WganGpCifar10Discriminator
from wgangp_cifar10_model import WganGpCifar10Generator

import ppsci
from ppsci.optimizer.lr_scheduler import Linear
from ppsci.utils import logger

os.environ["FLAGS_cudnn_deterministic"] = "1"


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    eval_inception_score = InceptionScore(**cfg["EVAL"]["inceptionscore"])

    # set data
    inputs, labels = load_cifar10(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        metric={
            "IS": ppsci.metric.FunctionalMetric(eval_inception_score.inception_score),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            generator_model.eval()
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_image = generator_model(input_)["fake_data"]
                show_save_image(
                    fake_image[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpCifar10Generator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpCifar10Discriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = Cifar10GenFuncs(
        **cfg["LOSS"]["gen"], discriminator_model=discriminator_model
    )
    discriminator_funcs = Cifar10DisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs, labels = load_cifar10(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
            "label": labels,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"labels": lambda out: out["labels"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    lr_scheduler_generator = Linear(**cfg["TRAIN"]["lr_scheduler_gen"])()
    lr_scheduler_discriminator = Linear(**cfg["TRAIN"]["lr_scheduler_dis"])()

    optimizer_generator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_generator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_discriminator = ppsci.optimizer.Adam(
        learning_rate=lr_scheduler_discriminator,
        beta1=cfg["TRAIN"]["optimizer"]["beta1"],
        beta2=cfg["TRAIN"]["optimizer"]["beta2"],
    )
    optimizer_generator = optimizer_generator(generator_model)
    optimizer_discriminator = optimizer_discriminator(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
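    # model=generator_model is intentional here: the generator produces the fake
    # samples, discriminator_funcs.loss scores them with discriminator_model, and
    # optimizer_discriminator (built from discriminator_model) only updates the
    # discriminator's parameters.
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,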
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_cifar10.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
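The CIFAR-10 validator above reports IS through `eval_inception_score.inception_score`, whose source is not shown here. To illustrate what that metric measures, the sketch below computes the Inception Score from class-probability vectors p(y|x), assuming those probabilities were already produced by a pretrained classifier (an Inception network in the standard definition). The helper name `inception_score_from_probs` is hypothetical, and it uses a single split, whereas common implementations average the score over several splits.

```python
import math
from typing import Sequence


def inception_score_from_probs(probs: Sequence[Sequence[float]], eps: float = 1e-12) -> float:
    """Hypothetical helper: IS = exp(mean_x KL(p(y|x) || p(y))) over a single split.

    `probs` holds one row of class probabilities p(y|x) per generated image; the rows
    are assumed to come from a pretrained classifier and are not computed here.
    """
    n = len(probs)
    num_classes = len(probs[0])
    # Marginal label distribution p(y), estimated as the mean of the rows.
    marginal = [sum(row[k] for row in probs) / n for k in range(num_classes)]
    mean_kl = 0.0
    for row in probs:
        # KL(p(y|x) || p(y)); classes with p(y|x) == 0 contribute nothing.
        kl = sum(
            p * (math.log(p + eps) - math.log(q + eps))
            for p, q in zip(row, marginal)
            if p > 0.0
        )
        mean_kl += kl / n
    return math.exp(mean_kl)


if __name__ == "__main__":
    # Confident and diverse predictions: the score approaches the number of classes.
    diverse = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
    # Uninformative predictions: the score is 1.
    uniform = [[0.25, 0.25, 0.25, 0.25]] * 4
    print(round(inception_score_from_probs(diverse), 6))  # 4.0
    print(round(inception_score_from_probs(uniform), 6))  # 1.0
```

A high score requires each sample to be classified confidently (low-entropy p(y|x)) while the samples as a whole cover many classes (high-entropy p(y)); the score is bounded above by the number of classes.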

Experiment on MNIST

import os
import platform

import hydra
import paddle
from functions import MnistDisFuncs
from functions import MnistGenFuncs
from functions import invalid_metric
from functions import load_mnist
from functions import show_mnist
from omegaconf import DictConfig
from wgangp_mnist_model import WganGpMnistDiscriminator
from wgangp_mnist_model import WganGpMnistGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={
            "MAE": ppsci.metric.FunctionalMetric(invalid_metric),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, _, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.batch:
                    break
                fake_data = generator_model(input_)["fake_data"]
                show_mnist(
                    fake_data[0],
                    f"{cfg.output_dir}/image{batch_idx}.png",
                )
                show_mnist(
                    input_["real_data"][0],
                    f"{cfg.output_dir}/image_real_{batch_idx}.png",
                )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpMnistGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpMnistDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = MnistGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = MnistDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_mnist(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])
    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_mnist.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
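`MnistDisFuncs.loss` (defined in `functions.py`, not shown) is configured from `cfg["LOSS"]["dis"]` and is meant to implement the critic loss L_d from Section 2.2. The framework-free sketch below spells out that formula for a hypothetical one-hidden-layer critic D(x) = Σ_k a_k tanh(w_k · x), whose input gradient has a closed form, so no autograd is needed. It follows the WGAN-GP recipe of sampling x̂ on straight lines between paired real and fake samples with ε ~ U[0, 1]. The names `critic`, `critic_input_grad` and `critic_loss`, and the default λ = 10 (the value used in the WGAN-GP paper), are assumptions and are not taken from `functions.py`.

```python
import math
import random
from typing import List, Optional, Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


def critic(x: Vector, a: Vector, W: Sequence[Vector]) -> float:
    """Hypothetical critic D(x) = sum_k a_k * tanh(w_k . x)."""
    return sum(a_k * math.tanh(dot(w_k, x)) for a_k, w_k in zip(a, W))


def critic_input_grad(x: Vector, a: Vector, W: Sequence[Vector]) -> List[float]:
    """Closed-form gradient of D with respect to its input x."""
    grad = [0.0] * len(x)
    for a_k, w_k in zip(a, W):
        # d tanh(z)/dz = 1 - tanh(z)^2, chained through z = w_k . x
        scale = a_k * (1.0 - math.tanh(dot(w_k, x)) ** 2)
        for j, w_kj in enumerate(w_k):
            grad[j] += scale * w_kj
    return grad


def critic_loss(real: Sequence[Vector], fake: Sequence[Vector], a: Vector,
                W: Sequence[Vector], lam: float = 10.0,
                rng: Optional[random.Random] = None) -> float:
    """L_d = E[D(fake)] - E[D(real)] + lam * E[(||grad D(x_hat)||_2 - 1)^2].

    `real` and `fake` are assumed to be batches of the same size.
    """
    if rng is None:
        rng = random.Random()
    n = len(real)
    wgan_term = sum(critic(f, a, W) for f in fake) / n - sum(critic(r, a, W) for r in real) / n
    penalty = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()  # one interpolation coefficient per (real, fake) pair
        x_hat = [eps * ri + (1.0 - eps) * fi for ri, fi in zip(r, f)]
        grad_norm = math.sqrt(sum(g * g for g in critic_input_grad(x_hat, a, W)))
        penalty += (grad_norm - 1.0) ** 2 / n
    return wgan_term + lam * penalty


if __name__ == "__main__":
    rng = random.Random(0)
    a = [0.5, -1.2]
    W = [[0.3, -0.7], [1.1, 0.4]]
    real = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(8)]
    fake = [[rng.gauss(2.0, 1.0), rng.gauss(2.0, 1.0)] for _ in range(8)]
    print(critic_loss(real, fake, a, W, lam=10.0, rng=rng))
```

The penalty is applied to the critic's input gradient rather than to its weights, which is the change WGAN-GP makes relative to weight clipping.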

Experiment on the toy dataset

import os
import platform

import hydra
import paddle
from functions import ToyDisFuncs
from functions import ToyGenFuncs
from functions import generate_toy_image
from functions import invalid_metric
from functions import load_toy_data
from omegaconf import DictConfig
from wgangp_toy_model import WganGpToyDiscriminator
from wgangp_toy_model import WganGpToyGenerator

import ppsci
from ppsci.utils import logger


def evaluate(cfg: DictConfig):
    # set model
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.EVAL.pretrained_dis_model_path and os.path.exists(
        cfg.EVAL.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.EVAL.pretrained_dis_model_path))
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    valid_dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "batch_size": cfg["EVAL"]["batch_size"],
        "use_shared_memory": cfg["EVAL"]["use_shared_memory"],
        "num_workers": cfg["EVAL"]["num_workers"]
        if platform.system() != "Windows"
        else 0,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        metric={"invalid_metric": ppsci.metric.FunctionalMetric(invalid_metric)},
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=generator_model,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_gen_model_path,
        output_dir=cfg.output_dir,
    )

    # eval
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            input_, _, _ = next(iter(validator.data_loader))
            real_data = input_["real_data"]
            generate_toy_image(
                true_dist=real_data,
                discriminator=discriminator_model,
                path=os.path.join(cfg.output_dir, "image.png"),
            )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # set model
    generator_model = WganGpToyGenerator(**cfg["MODEL"]["gen_net"])
    discriminator_model = WganGpToyDiscriminator(**cfg["MODEL"]["dis_net"])
    if cfg.TRAIN.pretrained_dis_model_path and os.path.exists(
        cfg.TRAIN.pretrained_dis_model_path
    ):
        discriminator_model.load_dict(paddle.load(cfg.TRAIN.pretrained_dis_model_path))

    # set Loss
    generator_funcs = ToyGenFuncs(discriminator_model=discriminator_model)
    discriminator_funcs = ToyDisFuncs(
        **cfg["LOSS"]["dis"], discriminator_model=discriminator_model
    )

    # set dataloader
    inputs = load_toy_data(**cfg["DATA"])
    dataloader_cfg = {
        "dataset": {
            "name": cfg["EVAL"]["dataset"]["name"],
            "input": inputs,
        },
        "sampler": {
            **cfg["TRAIN"]["sampler"],
        },
        "batch_size": cfg["TRAIN"]["batch_size"],
        "use_shared_memory": cfg["TRAIN"]["use_shared_memory"],
        "num_workers": cfg["TRAIN"]["num_workers"],
        "drop_last": cfg["TRAIN"]["drop_last"],
    }

    # set constraint
    constraint_generator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(generator_funcs.loss),
        name="constraint_generator",
    )
    constraint_generator_dict = {constraint_generator.name: constraint_generator}

    constraint_discriminator = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(discriminator_funcs.loss),
        output_expr={"real_data": lambda out: out["real_data"]},
        name="constraint_discriminator",
    )
    constraint_discriminator_dict = {
        constraint_discriminator.name: constraint_discriminator
    }

    # set optimizer
    optimizer = ppsci.optimizer.Adam(**cfg["TRAIN"]["optimizer"])

    optimizer_generator = optimizer(generator_model)
    optimizer_discriminator = optimizer(discriminator_model)

    # initialize solver
    solver_generator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "generator"),
        constraint=constraint_generator_dict,
        optimizer=optimizer_generator,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )
    solver_discriminator = ppsci.solver.Solver(
        model=generator_model,
        output_dir=os.path.join(cfg.output_dir, "discriminator"),
        constraint=constraint_discriminator_dict,
        optimizer=optimizer_discriminator,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
        pretrained_model_path=cfg.TRAIN.pretrained_gen_model_path,
    )

    # train
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        optimizer_discriminator.clear_grad()
        solver_discriminator.train()
        optimizer_generator.clear_grad()
        solver_generator.train()

    # save model weight
    paddle.save(
        generator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_generator.pdparams"),
    )
    paddle.save(
        discriminator_model.state_dict(),
        os.path.join(cfg.output_dir, "model_discriminator.pdparams"),
    )


@hydra.main(version_base=None, config_path="./conf", config_name="wgangp_toy.yaml")
def main(cfg: DictConfig):
    ppsci.utils.misc.set_random_seed(cfg["seed"])
    logger.init_logger(
        cfg.LOGGER.name, log_file=os.path.join(cfg.output_dir, cfg.LOGGER.log_file)
    )
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
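The toy experiment's pretrained checkpoints are named `*_toy_8gaussians`, i.e. the eight-Gaussian mixture used in the toy experiments of the WGAN-GP paper. `load_toy_data` lives in `functions.py` and is not shown, so the sampler below is only a sketch of that distribution, assuming it follows the original WGAN-GP reference code: eight centers evenly spaced on a circle of radius `scale`, isotropic noise with standard deviation `std`, and every point divided by 1.414. The function name `sample_8gaussians` and its defaults are assumptions, not the PaddleScience implementation.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_8gaussians(
    batch_size: int,
    scale: float = 2.0,
    std: float = 0.02,
    rng: Optional[random.Random] = None,
) -> List[Tuple[float, float]]:
    """Hypothetical sampler for a mixture of 8 Gaussians placed on a circle."""
    if rng is None:
        rng = random.Random()
    s = 1.0 / math.sqrt(2.0)
    unit_centers = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0),
                    (s, s), (s, -s), (-s, s), (-s, -s)]
    centers = [(scale * cx, scale * cy) for cx, cy in unit_centers]
    batch = []
    for _ in range(batch_size):
        cx, cy = rng.choice(centers)  # pick one mixture component uniformly
        x = cx + rng.gauss(0.0, std)
        y = cy + rng.gauss(0.0, std)
        batch.append((x / 1.414, y / 1.414))  # rescaling assumed from the reference code
    return batch


if __name__ == "__main__":
    points = sample_8gaussians(256, rng=random.Random(0))
    print(len(points), points[0])
```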

6. References