
VelocityGAN

Note

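  1. Before running, we recommend taking a quick look at the dataset and how its data is read.
  2. Download the OpenFWI dataset into the corresponding subdirectory of the FWIOpenData directory (e.g. Flatvel_A).
  3. Set the anno parameter in the yaml configuration file to match the dataset.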
Training:
python velocityGAN.py
Evaluation:
python velocityGAN.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/velocitygan/velocitygan_pretrained.pdparams
Pretrained model | Metrics
velocitygan_pretrained.pdparams | MAE: 0.0669, RMSE: 0.0947, SSIM: 0.8511

1. Background

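Subsurface velocity maps play an important role in geoscience. They describe how fast seismic waves propagate through each region underground and provide key information for probing the Earth's internal structure. Seismic full-waveform inversion is widely used to reconstruct subsurface velocity maps. The traditional physics-driven approach is a numerical optimization process that runs for many iterations, solving the wave equation each time. This is computationally expensive and usually converges only to a local optimum, which limits image accuracy. Data-driven deep learning methods can alleviate these problems and produce more accurate velocity maps in less time.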

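VelocityGAN is one such method. It is an end-to-end framework that generates high-quality velocity maps directly from raw seismic waveform data. Its paper reports that VelocityGAN outperforms traditional physics-driven waveform inversion methods and achieves state-of-the-art (SOTA) performance on data-driven benchmarks.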

2. Model principles

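As a data-driven deep learning method, VelocityGAN learns the mapping from waveform data to velocity maps directly, without solving the wave equation. This section gives only a brief overview of the model; for details, see VelocityGAN: Data-Driven Full-Waveform Inversion Using Conditional Adversarial Networks.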

2.1 Model architecture

VelocityGAN is a conditional adversarial network made up of an image-to-image generator and a CNN discriminator. The figure below shows the overall architecture.

[Figure: overall architecture of VelocityGAN]

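  • The Generator is a convolutional neural network with an encoder-decoder structure. The encoder extracts features from the seismic waveform data and progressively compresses them into a latent vector; the decoder then infers the corresponding velocity map from this latent vector.

  • The Discriminator consists of 9 convolutional blocks. It takes a velocity map as input and outputs a score for how realistic the map is.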

2.2 Loss functions

The discriminator loss uses the Wasserstein loss with a gradient penalty:

\[ L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right] \]

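where \(\mathbb{P}_g\) is the distribution of generated samples, \(\mathbb{P}_r\) is the distribution of real data, \(\mathbb{P}_{\hat{x}}\) is the distribution of samples interpolated between \(\mathbb{P}_g\) and \(\mathbb{P}_r\) (each \(\hat{x} = \alpha x + (1-\alpha)\tilde{x}\) with \(\alpha\) drawn uniformly at random from \([0, 1)\), as in compute_gradient_penalty), and \(\lambda\) is the gradient-penalty weight (lambda_gp in the code).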
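To make the three terms of \(L_d\) concrete, here is a minimal pure-Python sketch. It assumes a toy critic \(D(x) = \tanh(w \cdot x)\) on 2-D points instead of VelocityDiscriminator, estimates \(\nabla_{\hat{x}} D(\hat{x})\) with central finite differences instead of autograd, and uses \(\lambda = 10\) purely as an illustrative value.

```python
import math
import random

W = [0.8, -0.5]  # weights of a hypothetical toy critic, for illustration only


def critic(x):
    """Toy critic D(x) = tanh(w . x); a stand-in for VelocityDiscriminator."""
    return math.tanh(sum(wi * xi for wi, xi in zip(W, x)))


def grad_norm(f, x, h=1e-5):
    """||grad f(x)||_2 estimated with central finite differences."""
    sq = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        sq += ((f(xp) - f(xm)) / (2 * h)) ** 2
    return math.sqrt(sq)


def discriminator_loss(real, fake, lam=10.0, seed=0):
    """Return (Wasserstein term, gradient penalty, L_d) for lists of samples."""
    rng = random.Random(seed)
    mean_fake = sum(critic(x) for x in fake) / len(fake)
    mean_real = sum(critic(x) for x in real) / len(real)
    wasserstein = mean_fake - mean_real  # E[D(x~)] - E[D(x)]
    penalty = 0.0
    for xr, xf in zip(real, fake):
        alpha = rng.random()  # one interpolation factor per (real, fake) pair
        x_hat = [alpha * a + (1 - alpha) * b for a, b in zip(xr, xf)]
        penalty += (grad_norm(critic, x_hat) - 1.0) ** 2
    penalty /= len(real)
    return wasserstein, penalty, wasserstein + lam * penalty


real = [[1.0, 0.0], [0.9, -0.2]]
fake = [[0.0, 1.0], [0.1, 0.8]]
w_term, gp_term, total = discriminator_loss(real, fake)
print(f"Wasserstein: {w_term:.4f}  penalty: {gp_term:.4f}  L_d: {total:.4f}")
```

Because the toy critic scores the real points higher than the generated ones, the Wasserstein term comes out negative; the penalty term only cares about how far the critic's gradient norm at the interpolated points is from 1.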

The generator loss combines the adversarial loss \(- \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x})\) with content losses (MAE and MSE):

\[ L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x}) + \frac{\lambda_1}{w\cdot h} \sum_{i=1}^{w} \sum_{j=1}^{h} \left| \tilde{v}(i,j) - v(i,j) \right| + \frac{\lambda_2}{w\cdot h}\sum_{i=1}^{w} \sum_{j=1}^{h} \left( \tilde{v}(i,j) - v(i,j) \right)^2 \]

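where \(w\) and \(h\) are the width and height of the velocity map, \(v(\cdot)\) and \(\tilde{v}(\cdot)\) are the true and predicted pixel values, and \(\lambda_1\) and \(\lambda_2\) are hyperparameters that balance the two content terms (lambda_g1v and lambda_g2v in the code, where the adversarial term additionally carries its own weight, lambda_adv).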
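As a quick arithmetic check of \(L_g\), the sketch below evaluates it for a 2×2 predicted and true velocity map with made-up critic scores. The \(\lambda\) values are arbitrary examples, not the weights used in this case's configuration.

```python
def generator_loss(pred, true, critic_scores, lam1=1.0, lam2=1.0):
    """L_g = -mean(D(x~)) + lam1 * MAE(pred, true) + lam2 * MSE(pred, true).

    pred, true: 2-D lists (w x h) of velocity values.
    critic_scores: discriminator outputs D(x~) for the generated maps.
    """
    diffs = [p - t for row_p, row_t in zip(pred, true) for p, t in zip(row_p, row_t)]
    n = len(diffs)  # n = w * h
    mae = sum(abs(d) for d in diffs) / n
    mse = sum(d * d for d in diffs) / n
    adversarial = -sum(critic_scores) / len(critic_scores)
    return adversarial + lam1 * mae + lam2 * mse


pred = [[0.5, -0.2], [0.1, 0.0]]
true = [[0.4, -0.2], [0.3, 0.0]]
# diffs = 0.1, 0.0, -0.2, 0.0 -> MAE = 0.3 / 4 = 0.075, MSE = 0.05 / 4 = 0.0125
loss = generator_loss(pred, true, critic_scores=[0.2, 0.4])
print(round(loss, 4))  # -0.3 + 0.075 + 0.0125 = -0.2125
```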

3. Model implementation

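The following sections explain how to implement VelocityGAN with the PaddleScience framework. Only the key steps are covered here; for the remaining details, see the API documentation.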

3.1 Dataset

This case uses the OpenFWI datasets released as open source by the SMILE Team.

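OpenFWI contains 12 datasets grouped into four families: Vel Family, Fault Family, Style Family and Kimberlina Family. This case mainly uses the first two families, whose configurations are shown below: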

[Figures: configuration tables of the Vel Family and Fault Family datasets]

Each dataset contains waveform data and the corresponding velocity maps. The figure below shows one example velocity map from each dataset.

[Figure: example velocity map from each dataset]

As shown, the Vel Family covers both flat and curved geological interfaces, while the Fault Family adds geological faults on top of these.

Each sample consists of one velocity map and five waveform records, as shown below.

[Figure: one sample, consisting of a velocity map and five waveform records]

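Here, the five red stars in a row represent five seismic sources on the surface, and 70 receivers are also placed on the surface. The seismic waves travel downward and are reflected back; each receiver records a value every 0.001 s, for 1000 time steps in total. The waveform data of each sample therefore has shape (5, 1000, 70).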
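The numbers above fix the size of one sample. The short calculation below reproduces the (5, 1000, 70) shape and the 1 s record length; the byte count assumes float32 storage, which is an assumption rather than something stated in this document, and is only meant as a rough guide to memory needs.

```python
n_sources = 5        # red stars: seismic sources on the surface
n_timesteps = 1000   # one record every 0.001 s
n_receivers = 70     # receivers on the surface
dt = 0.001           # sampling interval in seconds

shape = (n_sources, n_timesteps, n_receivers)
n_values = n_sources * n_timesteps * n_receivers
record_length_s = n_timesteps * dt
bytes_per_sample = n_values * 4  # assuming float32 (4 bytes per value)

print(shape, n_values, record_length_s, bytes_per_sample)
# (5, 1000, 70) 350000 1.0 1400000
```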

Note: none of the data was acquired in the field; all of it is synthetically generated. For details, see OpenFWI: Large-Scale Multi-Structural Benchmark Datasets for Seismic Full Waveform Inversion.

3.2 Building the dataset API

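Each dataset consists of 120 data files, so passing every file path individually is cumbersome. To simplify loading, all paths are collected into a single text file, and the data is read by parsing the paths in that file one after another. Because of this loading scheme, the built-in PaddleScience dataset APIs cannot be used, so we implemented a custom ppsci.data.dataset.FWIDataset.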
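The path-list idea can be illustrated with a minimal reader. In this sketch the line format (a waveform-data path and a velocity-map path separated by whitespace on each line), the function name read_anno and the file names are all hypothetical; check the anno files shipped with the case for the format FWIDataset actually expects.

```python
import os
import tempfile


def read_anno(anno_path):
    """Parse a hypothetical anno file: one 'data_path label_path' pair per line."""
    pairs = []
    with open(anno_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            data_path, label_path = line.split()
            pairs.append((data_path, label_path))
    return pairs


# Build a small example anno file with hypothetical file names.
with tempfile.TemporaryDirectory() as tmp:
    anno = os.path.join(tmp, "flatvel_a_train.txt")
    with open(anno, "w") as f:
        for i in range(1, 4):
            f.write(f"FWIOpenData/FlatVel_A/data/data{i}.npy\t"
                    f"FWIOpenData/FlatVel_A/model/model{i}.npy\n")
    pairs = read_anno(anno)

print(len(pairs), pairs[0])
```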

The dataloader configuration is as follows:

# set dataloader config
dataloader_cfg = {
    "dataset": {
        "name": "FWIDataset",
        "input_keys": ("data",),
        "label_keys": ("real_image",),
        "anno": cfg.TRAIN.dataset.anno,
        "preload": cfg.TRAIN.dataset.preload,
        "sample_ratio": cfg.TRAIN.dataset.sample_ratio,
        "file_size": ctx["file_size"],
        "transform_data": transform_data,
        "transform_label": transform_label,
    },
    "sampler": {
        "name": "BatchSampler",
        "shuffle": cfg.TRAIN.sampler.shuffle,
        "drop_last": cfg.TRAIN.sampler.drop_last,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "use_shared_memory": cfg.TRAIN.use_shared_memory,
    "num_workers": cfg.TRAIN.num_workers,
}
Here, dataset uses our custom FWIDataset, and anno is the path to the text file that lists the paths of all data files.
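For intuition about the sampler settings, the sketch below computes how many batches one pass over the data yields with and without drop_last. It only illustrates the usual batching arithmetic with a made-up sample count; the number of iterations the solver actually runs is set separately in the configuration.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches produced by one pass over num_samples samples."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # the last batch may be smaller


# Example: 1000 samples (hypothetical) with batch_size = 64.
print(num_batches(1000, 64, drop_last=True), num_batches(1000, 64, drop_last=False))
# 15 16
```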

3.3 Building the model

VelocityGAN is not built into PaddleScience, so this case implements it separately as the custom classes ppsci.arch.VelocityGenerator and ppsci.arch.VelocityDiscriminator.

The models are built as follows:

# set model
model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)
model_dis = ppsci.arch.VelocityDiscriminator(**cfg.MODEL.dis_net)

The corresponding parameter configuration is:

# model settings
MODEL:
  gen_net:
    input_keys: ["data"]
    output_keys: ["fake_image"]
    dim1: 32
    dim2: 64
    dim3: 128
    dim4: 256
    dim5: 512
    sample_spatial: 1.0
  dis_net:
    input_keys: ["image"]
    output_keys: ["score"]
    dim1: 32
    dim2: 64
    dim3: 128
    dim4: 256

3.4 Custom loss

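VelocityGAN's loss functions are somewhat involved and need to be implemented by hand. PaddleScience provides an API for custom loss functions, ppsci.loss.FunctionalLoss: first define the loss function, then pass that function to FunctionalLoss as its argument. Note that a custom loss function must take dictionaries as input and return a dictionary.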
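The dictionary-in, dictionary-out contract can be sketched without any framework. The function toy_loss_func below is hypothetical: it follows the key names and weight names used in this case, uses plain Python lists in place of paddle tensors, and skips any term whose weight is 0.

```python
def toy_loss_func(output_dict, label_dict, weight):
    """Hypothetical custom loss: dicts in, dict out, in the style of loss_func_gen.

    output_dict["fake_image"] and label_dict["real_image"] are flat lists here;
    in the real case they are paddle tensors.
    """
    pred = output_dict["fake_image"]
    label = label_dict["real_image"]
    n = len(pred)
    l1 = sum(abs(p - t) for p, t in zip(pred, label)) / n
    l2 = sum((p - t) ** 2 for p, t in zip(pred, label)) / n
    loss = 0.0
    if weight["lambda_g1v"]:  # a zero weight leaves the term out entirely
        loss += weight["lambda_g1v"] * l1
    if weight["lambda_g2v"]:
        loss += weight["lambda_g2v"] * l2
    return {"loss_g": loss}


out = toy_loss_func(
    {"fake_image": [0.0, 1.0]},
    {"real_image": [0.5, 0.5]},
    weight={"lambda_g1v": 1.0, "lambda_g2v": 0.0},
)
print(out)  # {'loss_g': 0.5}
```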

3.4.1 Generator loss

The generator loss combines an L1 loss, an L2 loss and an adversarial loss. Each of the three terms has its own weight; a weight of 0 means that term is not added during training.

def loss_func_gen(self, output_dict, label_dict, *args):
    """Calculate loss of generator.
        The loss includes L1 loss, L2 loss, and adversarial loss. Each of these losses has a corresponding weight,
        and if the weight of any loss is zero, it means that this loss component is not added during training.

    Args:
        output_dict: Output dict of model.
        label_dict: Label dict.

    Returns:
        Loss of generator.
    """
    l1loss = paddle.nn.L1Loss()
    l2loss = paddle.nn.MSELoss()

    pred = output_dict["fake_image"]
    label = label_dict["real_image"]

    loss_g1v = l1loss(pred, label)
    loss_g2v = l2loss(pred, label)

    loss = (
        self.weight["lambda_g1v"] * loss_g1v + self.weight["lambda_g2v"] * loss_g2v
    )

    loss_adv = -paddle.mean(self.model_dis({"image": pred})["score"])

    loss += self.weight["lambda_adv"] * loss_adv

    return {"loss_g": loss}

3.4.2 Discriminator loss

The discriminator loss combines the Wasserstein loss and a gradient penalty; only the gradient-penalty term has a weight parameter.

def loss_func_dis(self, output_dict, label_dict, *args):
    """Calculate loss of discriminator.
        The discriminator's loss includes Wasserstein loss and gradient penalty, and only the gradient penalty has a weight parameter.

    Args:
        output_dict: Output dict of model.
        label_dict: Label dict.

    Returns:
        Loss of discriminator.
    """
    pred = output_dict["fake_image"]
    pred.stop_gradient = True
    label = label_dict["real_image"]

    gradient_penalty = self.compute_gradient_penalty(label, pred)

    loss_real = paddle.mean(self.model_dis({"image": label})["score"])
    loss_fake = paddle.mean(self.model_dis({"image": pred})["score"])

    loss = -loss_real + loss_fake + gradient_penalty * self.weight["lambda_gp"]

    return {"loss_d": loss}

def compute_gradient_penalty(self, real_samples, fake_samples):
    """Calculate the gradient penalty.
        Generate a random interpolation factor, create mixed samples, process through the discriminator,
        compute the gradient of the output, apply L2 norm and constrain it to 1, and finally obtain the gradient penalty.

    Args:
        real_samples: Ground truth data from dataset.
        fake_samples: Generated data from generator.

    Returns:
        Gradient penalty.
    """
    alpha = paddle.rand([real_samples.shape[0], 1, 1, 1], dtype=real_samples.dtype)
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.stop_gradient = False  # Allow gradients to be calculated
    d_interpolates = self.model_dis({"image": interpolates})["score"]

    gradients = paddle.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    gradients = gradients.reshape([gradients.shape[0], -1])
    gradient_penalty = paddle.mean((paddle.norm(gradients, p=2, axis=1) - 1) ** 2)
    return gradient_penalty

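Note:

pred.stop_gradient = True

excludes pred from gradient computation. Here pred serves only as an input to the discriminator, so its gradient is not needed. Moreover, pred is the generator's output: if its gradient were not stopped, gradients of the generator's parameters would accumulate during discriminator training and end up affecting the generator's training on its first batch.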

3.5 Building constraints

This case builds its constraints with ppsci.constraint.SupervisedConstraint.

The code is as follows:

# set constraint
constraint_gen = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
    output_expr={"fake_image": lambda out: out["fake_image"]},
    name="cst_gen",
)
constraint_gen_dict = {constraint_gen.name: constraint_gen}

constraint_dis = ppsci.constraint.SupervisedConstraint(
    dataloader_cfg=dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(dis_funcs.loss_func_dis),
    output_expr={"fake_image": lambda out: out["fake_image"]},
    name="cst_dis",
)
constraint_dis_dict = {constraint_dis.name: constraint_dis}

Here, output_expr specifies how output_dict is built, and name is the constraint's name, which is used to index it later.

Once built, the constraints are wrapped in dictionaries so they can be passed to ppsci.solver.Solver.

3.6 Building the optimizer

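VelocityGAN uses the AdamW optimizer, which can be built directly with ppsci.optimizer.AdamW; the same configuration is then instantiated once for each model: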

# set optimizer
optimizer = ppsci.optimizer.AdamW(
    learning_rate=cfg.TRAIN.learning_rate, weight_decay=cfg.TRAIN.weight_decay
)
optimizer_g = optimizer(model_gen)
optimizer_d = optimizer(model_dis)
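To show what weight_decay does in AdamW, here is a single-parameter sketch of the AdamW update (Adam with decoupled weight decay). The beta1, beta2 and eps values are common defaults assumed for illustration, and Paddle's internal implementation may order the operations differently.

```python
import math


def adamw_step(p, g, m, v, t, lr, weight_decay, beta1=0.9, beta2=0.999, eps=1e-8):
    """One AdamW update for a scalar parameter p with gradient g at step t (t >= 1)."""
    m = beta1 * m + (1 - beta1) * g          # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: applied to p directly, not added to the gradient.
    p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p)
    return p, m, v


p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1, lr=0.1, weight_decay=0.01)
print(round(p, 6))  # 0.899
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is about 1, so the parameter moves by roughly lr × (1 + weight_decay × p).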

3.7 Building the solver

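Pass the built models, constraints, optimizers and other parameters to ppsci.solver.Solver. Both solvers use model_gen as their model: the generator produces fake_image, the discriminator is called inside the loss functions, and the optimizer given to each solver (optimizer_g or optimizer_d) determines which network's parameters are updated.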

# initialize solver
solver_gen = ppsci.solver.Solver(
    model=model_gen,
    output_dir=cfg.output_dir,
    constraint=constraint_gen_dict,
    optimizer=optimizer_g,
    epochs=cfg.TRAIN.epochs_gen,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
)

solver_dis = ppsci.solver.Solver(
    model=model_gen,
    output_dir=cfg.output_dir,
    constraint=constraint_dis_dict,
    optimizer=optimizer_d,
    epochs=cfg.TRAIN.epochs_dis,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
)

3.8 Model training

# training
for i in range(cfg.TRAIN.epochs):
    logger.message(f"\nEpoch: {i + 1}\n")
    solver_dis.train()
    solver_gen.train()

3.9 Custom metric

This case is evaluated with MAE (Mean Absolute Error), RMSE (Root Mean Squared Error) and SSIM (Structural SIMilarity). PaddleScience provides APIs for MAE and RMSE, while SSIM has to be implemented separately.
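For reference, MAE and RMSE reduce to simple formulas over all pixels. The plain-Python sketch below computes them on a toy 4-pixel example to illustrate the definitions; it is not a reimplementation of ppsci.metric.MAE or ppsci.metric.RMSE, which operate on tensors.

```python
import math


def mae(pred, true):
    """Mean absolute error over all values."""
    return sum(abs(p - t) for p, t in zip(pred, true)) / len(pred)


def rmse(pred, true):
    """Root mean squared error over all values."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(pred))


pred = [0.0, 0.5, 1.0, -1.0]
true = [0.0, 0.0, 1.0, 0.0]
print(mae(pred, true), rmse(pred, true))  # 0.375 and sqrt(0.3125), about 0.559
```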

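PaddleScience provides an API for custom metric functions, ppsci.metric.FunctionalMetric: first define the metric function, then pass that function to FunctionalMetric as its argument. Note that a custom metric function must take dictionaries as input and return a dictionary.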

SSIM is implemented as follows:

from math import exp

import paddle


class SSIM(paddle.nn.Layer):
    """
    SSIM is used to measure the similarity between two images.

    Attributes:
        window_size (int): The size of the gaussian window used for computing SSIM. Defaults to 11.
        size_average (bool): If True, the SSIM values across spatial dimensions are averaged. Defaults to True.

    Methods:
        forward(img1, img2): Computes the SSIM score between two images using a gaussian filter defined by `window`.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        _, channel, _, _ = img1.shape

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.place.is_gpu_place():
                window = window.cuda(img1.place.gpu_device_id())
            window = window.astype(img1.dtype)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def gaussian(window_size, sigma):
    gauss = paddle.to_tensor(
        data=[
            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
            for x in range(window_size)
        ],
        dtype="float32",
    )
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = (
        paddle.mm(_1D_window, _1D_window.t())
        .astype("float32")
        .unsqueeze(0)
        .unsqueeze(0)
    )
    window = _2D_window.expand([channel, 1, window_size, window_size])
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = paddle.nn.functional.conv2d(
        x=img1, weight=window, padding=window_size // 2, groups=channel
    )
    mu2 = paddle.nn.functional.conv2d(
        x=img2, weight=window, padding=window_size // 2, groups=channel
    )

    mu1_sq = mu1.pow(y=2)
    mu2_sq = mu2.pow(y=2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = (
        paddle.nn.functional.conv2d(
            x=img1 * img1, weight=window, padding=window_size // 2, groups=channel
        )
        - mu1_sq
    )
    sigma2_sq = (
        paddle.nn.functional.conv2d(
            x=img2 * img2, weight=window, padding=window_size // 2, groups=channel
        )
        - mu2_sq
    )
    sigma12 = (
        paddle.nn.functional.conv2d(
            x=img1 * img2, weight=window, padding=window_size // 2, groups=channel
        )
        - mu1_mu2
    )

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = (
        (2 * mu1_mu2 + C1)
        * (2 * sigma12 + C2)
        / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    )

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(axis=1).mean(axis=1).mean(axis=1)


def ssim_metirc(output_dict, label_dict):
    ssim_loss = SSIM(window_size=11)
    metric_dict = {}

    for key in label_dict:
        ssim = ssim_loss(label_dict[key] / 2 + 0.5, output_dict[key] / 2 + 0.5)
        metric_dict[key] = ssim

    return metric_dict
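The SSIM formula in _ssim can be checked by hand with a simplified, global version: instead of a sliding Gaussian window, the sketch below computes the means, variances and covariance once over the whole image. This is a deliberate simplification, so its values generally differ from the windowed SSIM above; C1, C2 and the x / 2 + 0.5 rescaling are taken from the code above (the rescaling suggests the maps are normalized to [-1, 1]).

```python
def global_ssim(img1, img2, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM computed once over the whole image (no sliding window).

    img1, img2: flat lists of pixel values in [0, 1].
    """
    n = len(img1)
    mu1 = sum(img1) / n
    mu2 = sum(img2) / n
    var1 = sum((a - mu1) * (a - mu1) for a in img1) / n
    var2 = sum((b - mu2) * (b - mu2) for b in img2) / n
    cov = sum((a - mu1) * (b - mu2) for a, b in zip(img1, img2)) / n
    return ((2 * mu1 * mu2 + c1) * (2 * cov + c2)) / (
        (mu1 * mu1 + mu2 * mu2 + c1) * (var1 + var2 + c2)
    )


# Map values from [-1, 1] to [0, 1] with x / 2 + 0.5, as ssim_metirc does.
true = [v / 2 + 0.5 for v in [-1.0, -0.5, 0.5, 1.0]]
pred = [v / 2 + 0.5 for v in [-0.9, -0.5, 0.4, 1.0]]
print(global_ssim(true, true), global_ssim(true, pred))
```

Identical images give an SSIM of 1; any mismatch in mean, contrast or structure lowers it.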

3.10 Building the validator

This case builds the validator with ppsci.validate.SupervisedValidator.

# set validator
validator = ppsci.validate.SupervisedValidator(
    dataloader_cfg=valid_dataloader_cfg,
    loss=ppsci.loss.MAELoss("mean"),
    output_expr={"real_image": lambda out: out["fake_image"]},
    metric={
        "MAE": ppsci.metric.MAE(),
        "RMSE": ppsci.metric.RMSE(),
        "SSIM": ppsci.metric.FunctionalMetric(func_module.ssim_metirc),
    },
    name="val",
)
validator_dict = {validator.name: validator}
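In the validator, output_expr maps the generator's fake_image output to the key real_image, so that predictions and labels share a key when metrics compare them (ssim_metirc, for example, iterates over label_dict and indexes output_dict with the same key). The framework-free sketch below shows that renaming step with made-up values; apply_output_expr and key_matched_abs_error are hypothetical helpers, not PaddleScience functions.

```python
def apply_output_expr(model_out, output_expr):
    """Build output_dict by evaluating each expression on the raw model output."""
    return {key: expr(model_out) for key, expr in output_expr.items()}


def key_matched_abs_error(output_dict, label_dict):
    """Per-key mean absolute error; every label key must exist in output_dict."""
    return {
        key: sum(abs(p - t) for p, t in zip(output_dict[key], label)) / len(label)
        for key, label in label_dict.items()
    }


model_out = {"fake_image": [0.1, 0.4]}  # raw generator output (toy values)
output_expr = {"real_image": lambda out: out["fake_image"]}  # rename to the label key
label_dict = {"real_image": [0.0, 0.5]}

output_dict = apply_output_expr(model_out, output_expr)
print(key_matched_abs_error(output_dict, label_dict))
```

Without the renaming, output_dict would only contain fake_image and looking up real_image would raise a KeyError.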

3.11 Model evaluation

Pass the model, the validator and the path to the pretrained weights to ppsci.solver.Solver, then start evaluation with solver.eval().

# initialize solver
solver = ppsci.solver.Solver(
    model=model_gen,
    validator=validator_dict,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
)

# evaluation
solver.eval()

3.12 Visualization

After evaluation, the results are visualized as images with the following code:

# visualization
if cfg.VIS.vis:
    with solver.no_grad_context_manager(True):
        for batch_idx, (input_, label_, _) in enumerate(validator.data_loader):
            if batch_idx + 1 > cfg.VIS.vb:
                break
            fake_image = model_gen(input_)["fake_image"].numpy()
            real_image = label_["real_image"].numpy()
            for i in range(cfg.VIS.vsa):
                plot_velocity(
                    fake_image[i, 0],
                    real_image[i, 0],
                    f"{cfg.output_dir}/V_{batch_idx}_{i}.png",
                )
    print(f"The visualizations are saved to {cfg.output_dir}")

4. Complete code

velocityGAN.py
import json
import os
import sys

import functions as func_module
import hydra
import paddle
from functions import plot_velocity
from omegaconf import DictConfig

import ppsci
from ppsci.utils import logger

os.environ["FLAGS_embedding_deterministic"] = "1"
os.environ["FLAGS_cudnn_deterministic"] = "1"
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["NCCL_ALGO"] = "Tree"


def evaluate(cfg: DictConfig):
    # get dataset configuration information
    with open(cfg.DATASET_CONFIG) as f:
        try:
            ctx = json.load(f)[cfg.DATASET]
        except KeyError:
            print("Unsupported dataset.")
            sys.exit()

    if cfg.file_size is not None:
        ctx["file_size"] = cfg.file_size

    # get data transformation
    transform_data, transform_label = func_module.create_transform(ctx, cfg.k)

    # set model
    model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)

    # set valid_dataloader_cfg
    valid_dataloader_cfg = {
        "dataset": {
            "name": "FWIDataset",
            "input_keys": ("data",),
            "label_keys": ("real_image",),
            "anno": cfg.EVAL.dataset.anno,
            "preload": cfg.EVAL.dataset.preload,
            "sample_ratio": cfg.EVAL.dataset.sample_ratio,
            "file_size": ctx["file_size"],
            "transform_data": transform_data,
            "transform_label": transform_label,
        },
        "batch_size": cfg.EVAL.batch_size,
        "use_shared_memory": cfg.EVAL.use_shared_memory,
        "num_workers": cfg.EVAL.num_workers,
    }

    # set validator
    validator = ppsci.validate.SupervisedValidator(
        dataloader_cfg=valid_dataloader_cfg,
        loss=ppsci.loss.MAELoss("mean"),
        output_expr={"real_image": lambda out: out["fake_image"]},
        metric={
            "MAE": ppsci.metric.MAE(),
            "RMSE": ppsci.metric.RMSE(),
            "SSIM": ppsci.metric.FunctionalMetric(func_module.ssim_metirc),
        },
        name="val",
    )
    validator_dict = {validator.name: validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model=model_gen,
        validator=validator_dict,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
    )

    # evaluation
    solver.eval()

    # visualization
    if cfg.VIS.vis:
        with solver.no_grad_context_manager(True):
            for batch_idx, (input_, label_, _) in enumerate(validator.data_loader):
                if batch_idx + 1 > cfg.VIS.vb:
                    break
                fake_image = model_gen(input_)["fake_image"].numpy()
                real_image = label_["real_image"].numpy()
                for i in range(cfg.VIS.vsa):
                    plot_velocity(
                        fake_image[i, 0],
                        real_image[i, 0],
                        f"{cfg.output_dir}/V_{batch_idx}_{i}.png",
                    )
        print(f"The visualizations are saved to {cfg.output_dir}")


def train(cfg: DictConfig):
    # get dataset configuration information
    with open(cfg.DATASET_CONFIG) as f:
        try:
            ctx = json.load(f)[cfg.DATASET]
        except KeyError:
            print("Unsupported dataset.")
            sys.exit()

    if cfg.file_size is not None:
        ctx["file_size"] = cfg.file_size

    # get data transformation
    transform_data, transform_label = func_module.create_transform(ctx, cfg.k)

    # set model
    model_gen = ppsci.arch.VelocityGenerator(**cfg.MODEL.gen_net)
    model_dis = ppsci.arch.VelocityDiscriminator(**cfg.MODEL.dis_net)

    # set class for loss function
    gen_funcs = func_module.GenFuncs(model_dis, cfg.WEIGHT_DICT.gen)
    dis_funcs = func_module.DisFuncs(model_dis, cfg.WEIGHT_DICT.dis)

    # set dataloader config
    dataloader_cfg = {
        "dataset": {
            "name": "FWIDataset",
            "input_keys": ("data",),
            "label_keys": ("real_image",),
            "anno": cfg.TRAIN.dataset.anno,
            "preload": cfg.TRAIN.dataset.preload,
            "sample_ratio": cfg.TRAIN.dataset.sample_ratio,
            "file_size": ctx["file_size"],
            "transform_data": transform_data,
            "transform_label": transform_label,
        },
        "sampler": {
            "name": "BatchSampler",
            "shuffle": cfg.TRAIN.sampler.shuffle,
            "drop_last": cfg.TRAIN.sampler.drop_last,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "use_shared_memory": cfg.TRAIN.use_shared_memory,
        "num_workers": cfg.TRAIN.num_workers,
    }

    # set constraint
    constraint_gen = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(gen_funcs.loss_func_gen),
        output_expr={"fake_image": lambda out: out["fake_image"]},
        name="cst_gen",
    )
    constraint_gen_dict = {constraint_gen.name: constraint_gen}

    constraint_dis = ppsci.constraint.SupervisedConstraint(
        dataloader_cfg=dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(dis_funcs.loss_func_dis),
        output_expr={"fake_image": lambda out: out["fake_image"]},
        name="cst_dis",
    )
    constraint_dis_dict = {constraint_dis.name: constraint_dis}

    # set optimizer
    optimizer = ppsci.optimizer.AdamW(
        learning_rate=cfg.TRAIN.learning_rate, weight_decay=cfg.TRAIN.weight_decay
    )
    optimizer_g = optimizer(model_gen)
    optimizer_d = optimizer(model_dis)

    # initialize solver
    solver_gen = ppsci.solver.Solver(
        model=model_gen,
        output_dir=cfg.output_dir,
        constraint=constraint_gen_dict,
        optimizer=optimizer_g,
        epochs=cfg.TRAIN.epochs_gen,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_gen,
    )

    solver_dis = ppsci.solver.Solver(
        model=model_gen,
        output_dir=cfg.output_dir,
        constraint=constraint_dis_dict,
        optimizer=optimizer_d,
        epochs=cfg.TRAIN.epochs_dis,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch_dis,
    )

    # training
    for i in range(cfg.TRAIN.epochs):
        logger.message(f"\nEpoch: {i + 1}\n")
        solver_dis.train()
        solver_gen.train()

    # save model weight
    paddle.save(
        model_gen.state_dict(), os.path.join(cfg.output_dir, "model_gen.pdparams")
    )


@hydra.main(version_base=None, config_path="./conf", config_name="velocityGAN.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()

5. Results

Results of training on the FlatVel-A dataset:

MAE RMSE SSIM
0.0669 0.0947 0.8511

[Figures: visualized results on FlatVel-A]

6. References