
FunDiff

Warning

This document reproduces only the turbulence_mass_transfer task from the FunDiff paper.

python main.py -cn fae.yaml

python main.py -cn diffusion.yaml FAE.pretrained_model_path=/your/fae/pretrained/model/path
python main.py -cn diffusion.yaml mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fundiff/fundiff_turbulence_mass_transfer_dit_pretrained.pdparams
Pretrained model: fundiff_turbulence_mass_transfer_dit_pretrained.pdparams
Metrics:
Mean relative p error: 0.066
Max relative p error: 0.159
Min relative p error: 0.029
Std relative p error: 0.027
Mean relative sdf error: 0.085
Max relative sdf error: 0.307
Min relative sdf error: 0.022
Std relative sdf error: 0.0499
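
The metrics listed here are statistics of a per-sample relative L2 error, which compute_error in main.py evaluates as the norm of (prediction - reference) divided by the norm of the reference over each flattened field. The following is a minimal pure-Python sketch of that quantity; the function name relative_l2_error and the toy numbers are illustrative and are not taken from the dataset.

```python
import math


def relative_l2_error(pred, ref):
    """Relative L2 error ||pred - ref|| / ||ref|| for one flattened sample."""
    diff_norm = math.sqrt(sum((p - r) ** 2 for p, r in zip(pred, ref)))
    ref_norm = math.sqrt(sum(r ** 2 for r in ref))
    return diff_norm / ref_norm


# Toy samples (made-up values), one error per sample.
errors = [
    relative_l2_error([1.1, 2.0, 2.9], [1.0, 2.0, 3.0]),
    relative_l2_error([0.0, 4.0], [0.0, 5.0]),
]
mean_error = sum(errors) / len(errors)
print(errors, mean_error)
```

The mean, max, min, and standard deviation are then taken over the list of per-sample errors.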

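1. Background

Recent advances in generative models, especially diffusion models and flow matching, have achieved remarkable success in synthesizing discrete data such as images and videos. Applying these models to physical applications remains challenging, however, because the quantities of interest are continuous functions governed by complex physical laws. The paper introduces \(\textbf{FunDiff}\), a novel framework for generative modeling in function spaces. FunDiff combines a latent diffusion process with a function autoencoder architecture to handle input functions with varying discretizations, generate continuous functions that can be evaluated at arbitrary locations, and seamlessly incorporate physical priors. These priors are enforced through architectural constraints or physics-informed loss functions, so that generated samples satisfy fundamental physical laws. The authors theoretically establish minimax optimality guarantees for density estimation in function spaces, showing that diffusion-based estimators achieve optimal convergence rates under suitable regularity conditions. The results demonstrate the practical effectiveness of FunDiff across applications in fluid dynamics and solid mechanics. Empirically, the method generates physically consistent samples that agree closely with the target distribution, and it is robust to noisy and low-resolution data.

2. Problem Definition

Given the known physical fields \(u\) and \(v\), solve for the physical fields \(p\) and \(sdf\).

3. Solving the Problem

The following sections explain step by step how to turn this problem into PaddleScience code and solve it with deep learning. To help you understand PaddleScience quickly, only the key steps such as model construction, equation construction, and computational domain construction are described; for the remaining details, see the API documentation.

3.1 Training the FAE

3.1.1 Building the FAE Model

In the FunDiff model, the FAE module adopts the Perceiver architecture. Its inputs are a physical field \(x\) and query coordinates \(coords\), and its output is the value \(u\) of that physical field at the query coordinates. The model is therefore built as follows: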

fae.yaml
# model settings
FAE:
  input_keys: [coords, x]
  output_keys: [u]
  # x_dim: [2, 200, 100, 1]
  # c_dim: [2, 256, 512]

  encoder:
    in_dim: 1
    patch_size: [10, 5]
    emb_dim: 256
    num_latents: 256
    grid_size: [200, 100]
    depth: 8
    num_heads: 8
    mlp_ratio: 2
    layer_norm_eps: 1e-05
  decoder:
    in_dim: 2
    period: null
    fourier_freq: 1.0
    dec_emb_dim: 256
    dec_depth: 4
    dec_num_heads: 8
    mlp_ratio: 2
    num_mlp_layers: 2
    out_dim: 1
    layer_norm_eps: 1e-05
main.py
            raise


def train_fae(cfg: DictConfig):
    # Initialize model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,

3.1.2 Building the Constraint

The FAE is trained as an autoencoder (encoder plus decoder), so the label \(u\) is simply the input field itself.

main.py
    encoder,
    decoder,
)

# init constraint
train_dataloader_cfg = {
    "dataset": {
        "name": "TMTDataset",
        "input_keys": cfg.FAE.input_keys,
        "label_keys": cfg.FAE.output_keys,
        "data_path": cfg.DATA_PATH,
        "num_train": cfg.num_train,
        "mode": "train",
        "stage": cfg.stage,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 0,
}

sup_cst = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.MSELoss(),
)
# NOTE: A hacky way to plug the batch parser into the constraint.
sample_batch = next(iter(sup_cst.data_loader))
_, h, w, _ = sample_batch.shape

class PostProcessDataLoader:
    def __init__(self, dataloader, parser: BatchParser):
        self.dataloader = dataloader
        self.parser = parser

    def __iter__(self):
        for batch in self.dataloader:
            yield self.parser.random_query(batch)

    def __len__(self):
        return len(self.dataloader)

sup_cst.data_loader = PostProcessDataLoader(
    sup_cst.data_loader,
    BatchParser(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        cfg.TRAIN.num_queries,
        h,
        w,
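
The internals of BatchParser.random_query are not shown in this document. As a rough illustration of the idea (sample num_queries grid points, and use the field values at those points as labels), here is a hypothetical pure-Python version. The function below and its normalization of coordinates to [0, 1] are assumptions made for the sketch, loosely mirroring the np.linspace(0, 1, ...) grid used in the evaluation code.

```python
import random


def random_query(field, num_queries, rng):
    """Sample grid points without replacement; return (coords, values).

    Hypothetical stand-in for a query sampler: coords are normalized to [0, 1].
    """
    h, w = len(field), len(field[0])
    flat_ids = rng.sample(range(h * w), num_queries)
    coords, values = [], []
    for idx in flat_ids:
        i, j = divmod(idx, w)
        coords.append((i / (h - 1), j / (w - 1)))
        values.append(field[i][j])
    return coords, values


# Toy 3 x 4 field whose value encodes its own flat index.
field = [[float(i * 4 + j) for j in range(4)] for i in range(3)]
coords, values = random_query(field, num_queries=5, rng=random.Random(0))
print(coords, values)
```

Training on randomly chosen query points, instead of the full grid, is what lets the decoder be evaluated at arbitrary coordinates later.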

3.2 Training the DiT

3.2.1 Building the DiT Model

The DiT model is built as follows:

diffusion.yaml
DIT:
  # input_keys: [u, v, p, sdf]
  # output_keys: [v_t_err]
  in_dim: 512
  depth: 8
  emb_dim: 512
  mlp_ratio: 2
  num_heads: 16
  seq_len: 256
  out_dim: 512
  with_condition: true
main.py
    solver.train()


def train_diffusion(cfg: DictConfig):
    # Initialize fae model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )
    # Load pretrained fae params and freeze fae
    save_load.load_pretrain(
        fae,
        cfg.FAE.pretrained_model_path,
    )
    fae.freeze()

    # Initialize dit and wrap encoder&decoder&dit into one model for convenience
    dit = DiT(**cfg.DIT)
    model = ModelWrapper(
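
The DiT dimensions in diffusion.yaml line up with the FAE settings: ModelWrapper concatenates two encoder latents along the last axis, so in_dim = out_dim = 512 corresponds to 2 x emb_dim = 2 x 256, and seq_len = 256 corresponds to num_latents. The sketch below illustrates this with nested lists, assuming each latent has shape (num_latents, emb_dim) per sample; concat_last and split_last are hypothetical helpers written for this example.

```python
def concat_last(a, b):
    """Concatenate two (seq_len, dim) latents along the feature axis."""
    return [row_a + row_b for row_a, row_b in zip(a, b)]


def split_last(z):
    """Split a (seq_len, 2 * dim) latent back into two (seq_len, dim) halves."""
    half = len(z[0]) // 2
    return [row[:half] for row in z], [row[half:] for row in z]


seq_len, emb_dim = 256, 256  # num_latents and emb_dim from fae.yaml
z_p = [[0.0] * emb_dim for _ in range(seq_len)]
z_sdf = [[1.0] * emb_dim for _ in range(seq_len)]
z_1 = concat_last(z_p, z_sdf)
print(len(z_1), len(z_1[0]))  # 256 512
```

The evaluation code performs the inverse split (first half for p, second half for sdf) before decoding each latent.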

3.2.2 Building the Constraint

In the FunDiff model, the DiT is trained with the rectified flow algorithm, whose loss function is:

\[ \mathcal{L}(\theta) = \mathbb{E}_{\mathbf{z}, t, {\epsilon}} \left[ \left\| \hat{\mathbf{v}}_\theta(\mathbf{x}, t) - (\mathbf{z} - \mathbf{x}) \right\|^2 \right] \]

The corresponding forward computation is implemented as follows:

main.py
class ModelWrapper(nn.Layer):
    def __init__(self, encoder: Encoder, decoder: Decoder, dit: DiT):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # kept in the wrapper so encoder, decoder and dit params can be loaded together at inference
        self.dit = dit

    def forward(self, batch: Dict[str, paddle.Tensor]):
        # define the forward pass for diffusion training process
        # just ignore the non-training branch code
        with paddle.no_grad():
            # data in batch have been downsampled already
            u = batch["u"]
            v = batch["v"]
            z_u = self.encoder(u)
            z_v = self.encoder(v)
            z_c = paddle.concat([z_u, z_v], axis=-1)

            if self.training:
                p = batch["p"]
                sdf = batch["sdf"]
                z_p = self.encoder(p)
                z_sdf = self.encoder(sdf)
                z_1 = paddle.concat([z_p, z_sdf], axis=-1)
                z_0 = paddle.randn(z_1.shape)  # (b, 200, 512)
                t = paddle.uniform(
                    [z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]],
                    min=0.0,
                    max=1.0,
                )
                z_t = t * (z_1 - z_0) + z_0
                v_t = z_1 - z_0
            else:
                raise NotImplementedError("ModelWrapper.forward only supports training mode")

        # only training dit
        v_t_pred = self.dit(z_t, t.flatten(), z_c)

        if self.training:
            return {
                "v_t_err": v_t - v_t_pred,
            }
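
To make the training target concrete: in the forward pass, z_0 is Gaussian noise, z_1 is the concatenated latent of p and sdf, z_t = z_0 + t (z_1 - z_0) is a point on the straight line between them, and the regression target is the constant velocity v_t = z_1 - z_0. The following toy sketch reproduces that construction on plain Python lists; the vectors and the zero "prediction" are placeholders, not model outputs.

```python
import random


def rectified_flow_pair(z0, z1, t):
    """Return the interpolated state z_t and the target velocity z1 - z0."""
    v_t = [b - a for a, b in zip(z0, z1)]
    z_t = [a + t * v for a, v in zip(z0, v_t)]
    return z_t, v_t


def mse(pred, target):
    """Mean squared error, matching the (v_t_err ** 2).mean() form of the loss."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


rng = random.Random(0)
z1 = [0.5, -1.0, 2.0]                    # stands in for the encoded target latent
z0 = [rng.gauss(0.0, 1.0) for _ in z1]   # standard normal noise
t = rng.uniform(0.0, 1.0)                # one time value per sample
z_t, v_t = rectified_flow_pair(z0, z1, t)
loss = mse([0.0] * len(v_t), v_t)        # loss of a dummy predictor that outputs zeros
print(t, z_t, v_t, loss)
```

Because the path is a straight line, following the true velocity from z_t for the remaining time 1 - t lands exactly on z_1, which is why a few Euler steps can suffice at sampling time.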

The overall constraint is built as follows:

main.py
    decoder,
    dit,
)

# init constraint
train_dataloader_cfg = {
    "dataset": {
        "name": "TMTDataset",
        "input_keys": cfg.FAE.input_keys,
        "label_keys": cfg.FAE.output_keys,
        "data_path": cfg.DATA_PATH,
        "num_train": cfg.num_train,
        "mode": "train",
        "stage": cfg.stage,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
}
sup_cst = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(
        lambda inp, label, weight: {"v_t": (inp["v_t_err"] ** 2).mean()}
    ),
)

# NOTE: A hacky way to plug the batch parser into the constraint.
sample_batch = next(iter(sup_cst.data_loader))
_, h, w, _ = sample_batch.shape

class PostProcessDataLoader:
    def __init__(self, dataloader, parser: BatchParser):
        self.dataloader = dataloader
        self.parser = parser

    def __iter__(self):
        for batch in self.dataloader:
            yield self.parser.random_downsample(batch)

    def __len__(self):
        return len(self.dataloader)

sup_cst.data_loader = PostProcessDataLoader(
    sup_cst.data_loader,
    BatchParser(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        None,
        h,
        w,
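
The body of BatchParser.random_downsample is not reproduced in this document. One plausible reading, given TRAIN.solution: [1, 2, 4] and the batch[:, ::d, ::d, ...] slicing in the evaluation code, is that each batch is subsampled with a randomly chosen stride. The sketch below illustrates only that assumed behavior; the function is hypothetical and may differ from the real implementation.

```python
import random


def random_downsample(field, strides, rng):
    """Pick a stride d at random and keep every d-th row and column (assumed behavior)."""
    d = rng.choice(strides)
    return [row[::d] for row in field[::d]], d


# A 200 x 100 grid, as in the FAE configuration.
field = [[0.0] * 100 for _ in range(200)]
sub, d = random_downsample(field, [1, 2, 4], random.Random(0))
print(d, len(sub), len(sub[0]))
```

Training on several resolutions in this way is consistent with the paper's claim that the model handles inputs with different discretizations.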

3.3 Hyperparameter Settings

The FAE is trained for 100000 steps with an initial learning rate of 0.001.

fae.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 5000
  iters_per_epoch: -1
  num_queries: 1024
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0

The DiT is trained for 100000 steps with an initial learning rate of 0.001.

diffusion.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 10000
  iters_per_epoch: -1
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0
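
Note that the epochs and iters_per_epoch values in these YAML files are overwritten at runtime: main.py sets iters_per_epoch to the length of the data loader, recomputes epochs as steps // iters_per_epoch, and divides lr_scheduler.warmup_epoch by iters_per_epoch. The sketch below repeats that arithmetic; the value 125 for iters_per_epoch is an assumed example, since the real number depends on the dataset size and batch size.

```python
def steps_to_epochs(steps, iters_per_epoch, warmup_steps):
    """Convert step-based settings into the epoch-based values the solver uses."""
    epochs = steps // iters_per_epoch
    warmup_epoch = warmup_steps / iters_per_epoch
    return epochs, warmup_epoch


# steps and warmup come from the YAML; iters_per_epoch = 125 is a made-up example.
print(steps_to_epochs(steps=100000, iters_per_epoch=125, warmup_steps=2000))  # (800, 16.0)
```

In other words, TRAIN.steps is the setting that actually controls the training length, and warmup_epoch: 2000 in the YAML is effectively a number of warmup steps.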

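3.4 Building the Optimizer

During training, an optimizer updates the model parameters. Both the FAE and the DiT use an Adam-family optimizer (main.py instantiates ppsci.optimizer.AdamW), together with the ExponentialDecay learning rate schedule that is widely used in machine learning.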

fae.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 5000
  iters_per_epoch: -1
  num_queries: 1024
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0
main.py
logger.debug(
    f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
)

# Create learning rate schedule and optimizer
lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
    epochs=cfg.TRAIN.epochs,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch,
    **cfg.TRAIN.lr_scheduler,
)()
optimizer = ppsci.optimizer.AdamW(
    lr,
    beta1=cfg.TRAIN.beta1,
    beta2=cfg.TRAIN.beta2,
diffusion.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 10000
  iters_per_epoch: -1
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0

  lr_scheduler:
    # epochs: ${TRAIN.epochs}
    learning_rate: 0.001
    gamma: 0.9
    decay_steps: 5000
    by_epoch: false
    warmup_epoch: 2000

  batch_size: 64
  pretrained_model_path: null
  checkpoint_path: null
main.py
logger.debug(
    f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
)

# Create learning rate schedule and optimizer
lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
    epochs=cfg.TRAIN.epochs,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch,
    **cfg.TRAIN.lr_scheduler,
)()
optimizer = ppsci.optimizer.AdamW(
    lr,
    beta1=cfg.TRAIN.beta1,
    beta2=cfg.TRAIN.beta2,
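
For intuition about the lr_scheduler settings (learning_rate 0.001, gamma 0.9, decay_steps 5000, 2000 warmup steps), here is a simplified per-step schedule. It assumes a linear warmup from zero followed by a smooth decay of the form lr * gamma ** (step / decay_steps) counted from step 0; the exact behavior of ppsci.optimizer.lr_scheduler.ExponentialDecay (warmup start value, decay offset) may differ, so treat this as an approximation only.

```python
def lr_at(step, base_lr=1e-3, gamma=0.9, decay_steps=5000, warmup_steps=2000):
    """Approximate learning rate at a given step (assumed warmup and decay form)."""
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr (assumption).
        return base_lr * step / warmup_steps
    # Smooth exponential decay: multiplied by gamma every decay_steps steps.
    return base_lr * gamma ** (step / decay_steps)


for step in (0, 1000, 5000, 50000, 100000):
    print(step, lr_at(step))
```

Under these assumptions the learning rate shrinks by a factor of 0.9 every 5000 steps, so after 100000 steps it is about 0.9 ** 20, roughly 12%, of its initial value.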

3.5 Model Training

With the settings above in place, pass the instantiated objects to ppsci.solver.Solver in order and start training.

    weight_decay=cfg.TRAIN.weight_decay,
    grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
)(fae)

# init solver
solver = ppsci.solver.Solver(
    fae,
    {"sup": sup_cst},
    optimizer=optimizer,
    weight_decay=cfg.TRAIN.weight_decay,
    grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
)(dit)

# init solver
solver = ppsci.solver.Solver(
    model,
    {"sup": sup_cst},
    optimizer=optimizer,

4. Complete Code

main.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NOTE: Code below is reproduced from https://github.com/sifanexisted/fundiff with paddle backend
"""
from __future__ import annotations

from os import path as osp
from typing import Dict

import hydra
import matplotlib.pyplot as plt
import numpy as np
import paddle
from matplotlib import gridspec
from model import FAE
from model import Decoder
from model import DiT
from model import Encoder
from omegaconf import DictConfig
from paddle import nn
from tqdm import tqdm

import ppsci
from ppsci.data.dataset.tmtdataset import BatchParser
from ppsci.utils import save_load
from ppsci.utils.misc import logger

dtype = paddle.get_default_dtype()


class ModelWrapper(nn.Layer):
    def __init__(self, encoder: Encoder, decoder: Decoder, dit: DiT):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # kept in the wrapper so encoder, decoder and dit params can be loaded together at inference
        self.dit = dit

    def forward(self, batch: Dict[str, paddle.Tensor]):
        # define the forward pass for diffusion training process
        # just ignore the non-training branch code
        with paddle.no_grad():
            # data in batch have been downsampled already
            u = batch["u"]
            v = batch["v"]
            z_u = self.encoder(u)
            z_v = self.encoder(v)
            z_c = paddle.concat([z_u, z_v], axis=-1)

            if self.training:
                p = batch["p"]
                sdf = batch["sdf"]
                z_p = self.encoder(p)
                z_sdf = self.encoder(sdf)
                z_1 = paddle.concat([z_p, z_sdf], axis=-1)
                z_0 = paddle.randn(z_1.shape)  # (b, 200, 512)
                t = paddle.uniform(
                    [z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]],
                    min=0.0,
                    max=1.0,
                )
                z_t = t * (z_1 - z_0) + z_0
                v_t = z_1 - z_0
            else:
                raise NotImplementedError("ModelWrapper.forward only supports training mode")

        # only training dit
        v_t_pred = self.dit(z_t, t.flatten(), z_c)

        if self.training:
            return {
                "v_t_err": v_t - v_t_pred,
            }
        else:
            raise NotImplementedError("ModelWrapper.forward only supports training mode")


def train_fae(cfg: DictConfig):
    # Initialize model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )

    # init constraint
    train_dataloader_cfg = {
        "dataset": {
            "name": "TMTDataset",
            "input_keys": cfg.FAE.input_keys,
            "label_keys": cfg.FAE.output_keys,
            "data_path": cfg.DATA_PATH,
            "num_train": cfg.num_train,
            "mode": "train",
            "stage": cfg.stage,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 0,
    }

    sup_cst = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.MSELoss(),
    )
    # NOTE: A hacky way to plug the batch parser into the constraint.
    sample_batch = next(iter(sup_cst.data_loader))
    _, h, w, _ = sample_batch.shape

    class PostProcessDataLoader:
        def __init__(self, dataloader, parser: BatchParser):
            self.dataloader = dataloader
            self.parser = parser

        def __iter__(self):
            for batch in self.dataloader:
                yield self.parser.random_query(batch)

        def __len__(self):
            return len(self.dataloader)

    sup_cst.data_loader = PostProcessDataLoader(
        sup_cst.data_loader,
        BatchParser(
            cfg.FAE.input_keys,
            cfg.FAE.output_keys,
            cfg.TRAIN.num_queries,
            h,
            w,
            cfg.TRAIN.solution,
        ),
    )
    sup_cst.data_iter = iter(sup_cst.data_loader)

    # reset epochs & iters_per_epoch
    cfg.TRAIN.iters_per_epoch = len(sup_cst.data_loader)
    logger.debug(f"cfg.TRAIN.iters_per_epoch = {cfg.TRAIN.iters_per_epoch}")
    cfg.TRAIN.epochs = cfg.TRAIN.steps // cfg.TRAIN.iters_per_epoch
    cfg.TRAIN.lr_scheduler.warmup_epoch /= cfg.TRAIN.iters_per_epoch
    logger.debug(
        f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
    )

    # Create learning rate schedule and optimizer
    lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch,
        **cfg.TRAIN.lr_scheduler,
    )()
    optimizer = ppsci.optimizer.AdamW(
        lr,
        beta1=cfg.TRAIN.beta1,
        beta2=cfg.TRAIN.beta2,
        epsilon=cfg.TRAIN.eps,
        weight_decay=cfg.TRAIN.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
    )(fae)

    # init solver
    solver = ppsci.solver.Solver(
        fae,
        {"sup": sup_cst},
        optimizer=optimizer,
        cfg=cfg,
    )
    # train
    solver.train()


def train_diffusion(cfg: DictConfig):
    # Initialize fae model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )
    # Load pretrained fae params and freeze fae
    save_load.load_pretrain(
        fae,
        cfg.FAE.pretrained_model_path,
    )
    fae.freeze()

    # Initialize dit and wrap encoder&decoder&dit into one model for convenience
    dit = DiT(**cfg.DIT)
    model = ModelWrapper(
        encoder,
        decoder,
        dit,
    )

    # init constraint
    train_dataloader_cfg = {
        "dataset": {
            "name": "TMTDataset",
            "input_keys": cfg.FAE.input_keys,
            "label_keys": cfg.FAE.output_keys,
            "data_path": cfg.DATA_PATH,
            "num_train": cfg.num_train,
            "mode": "train",
            "stage": cfg.stage,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
    }
    sup_cst = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(
            lambda inp, label, weight: {"v_t": (inp["v_t_err"] ** 2).mean()}
        ),
    )

    # NOTE: A hacky way to plug the batch parser into the constraint.
    sample_batch = next(iter(sup_cst.data_loader))
    _, h, w, _ = sample_batch.shape

    class PostProcessDataLoader:
        def __init__(self, dataloader, parser: BatchParser):
            self.dataloader = dataloader
            self.parser = parser

        def __iter__(self):
            for batch in self.dataloader:
                yield self.parser.random_downsample(batch)

        def __len__(self):
            return len(self.dataloader)

    sup_cst.data_loader = PostProcessDataLoader(
        sup_cst.data_loader,
        BatchParser(
            cfg.FAE.input_keys,
            cfg.FAE.output_keys,
            None,
            h,
            w,
            cfg.TRAIN.solution,
        ),
    )
    sup_cst.data_iter = iter(sup_cst.data_loader)

    # reset epochs & iters_per_epoch
    cfg.TRAIN.iters_per_epoch = len(sup_cst.data_loader)
    logger.debug(f"cfg.TRAIN.iters_per_epoch = {cfg.TRAIN.iters_per_epoch}")
    cfg.TRAIN.epochs = cfg.TRAIN.steps // cfg.TRAIN.iters_per_epoch
    cfg.TRAIN.lr_scheduler.warmup_epoch /= cfg.TRAIN.iters_per_epoch
    logger.debug(
        f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
    )

    # Create learning rate schedule and optimizer
    lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch,
        **cfg.TRAIN.lr_scheduler,
    )()
    optimizer = ppsci.optimizer.AdamW(
        lr,
        beta1=cfg.TRAIN.beta1,
        beta2=cfg.TRAIN.beta2,
        epsilon=cfg.TRAIN.eps,
        weight_decay=cfg.TRAIN.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
    )(dit)

    # init solver
    solver = ppsci.solver.Solver(
        model,
        {"sup": sup_cst},
        optimizer=optimizer,
        cfg=cfg,
    )
    # train
    solver.train()


@paddle.no_grad()
def evaluate(cfg: DictConfig):
    # Initialize encoder & decoder & dit model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    dit = DiT(**cfg.DIT)
    # wrap encoder&decoder&dit into one model for convenience
    model = ModelWrapper(
        encoder,
        decoder,
        dit,
    )
    save_load.load_pretrain(model, cfg.EVAL.pretrained_model_path)

    # init evaluate data
    eval_dataset = ppsci.data.dataset.TMTDataset(
        input_keys=cfg.FAE.input_keys,
        label_keys=cfg.FAE.output_keys,
        data_path=cfg.DATA_PATH,
        num_train=cfg.num_train,
        mode="test",
        stage=cfg.stage,
    )
    eval_loader = paddle.io.DataLoader(eval_dataset, batch_size=cfg.EVAL.batch_size)

    h, w = 200, 100
    x_coords = np.linspace(0, 1, h, dtype=dtype)
    y_coords = np.linspace(0, 1, w, dtype=dtype)
    x_coords, y_coords = np.meshgrid(x_coords, y_coords, indexing="ij")
    coords = np.hstack([x_coords.reshape(-1, 1), y_coords.reshape(-1, 1)], dtype=dtype)[
        None, ...
    ]

    noise_level = 0.1
    d = 1
    u_input_list = []
    v_input_list = []
    p_pred_list = []
    sdf_pred_list = []
    p_true_list = []
    sdf_true_list = []

    def sample_ode(
        z0: paddle.Tensor = None,
        c: paddle.Tensor = None,
        num_steps: int = None,
        use_conditioning: bool = False,
    ) -> paddle.Tensor:
        dt = 1 / num_steps
        traj = [z0]

        z = z0
        for i in tqdm(range(num_steps)):
            t = paddle.ones([z.shape[0]]) * i / num_steps
            if use_conditioning:
                pred = dit(z, t, c)
            else:
                pred = dit(z, t)
            z = z + pred * dt
            traj.append(z)
        return z, traj

    for batch in tqdm(eval_loader):
        u: paddle.Tensor = batch[:, ::d, ::d, 0:1]
        v: paddle.Tensor = batch[:, ::d, ::d, 1:2]
        p: paddle.Tensor = batch[..., 2:3]
        sdf: paddle.Tensor = batch[..., 3:4]

        # add noise to input field
        u = u + noise_level * paddle.randn(u.shape)
        v = v + noise_level * paddle.randn(v.shape)

        # compute condition latent code from given field
        z_u = encoder(u)
        z_v = encoder(v)
        z_c = paddle.concat([z_u, z_v], axis=-1)  # (b, l, 2c)

        # random sample z0 from standard normal distribution
        z0 = paddle.randn(shape=z_c.shape)

        # integrate ODE
        z1_new, _ = sample_ode(
            z0=z0,
            c=z_c,
            num_steps=cfg.EVAL.num_steps,
            use_conditioning=cfg.EVAL.use_conditioning,
        )

        # decode latent code to output field
        c_dim = z_c.shape[-1]
        z_p_new = z1_new[..., : c_dim // 2]
        z_sdf_new = z1_new[..., c_dim // 2 :]

        p_pred = decoder(z_p_new, coords)
        sdf_pred = decoder(z_sdf_new, coords)

        p_pred = p_pred.reshape([-1, h, w])
        sdf_pred = sdf_pred.reshape([-1, h, w])

        u_input_list.append(u)
        v_input_list.append(v)

        p_pred_list.append(p_pred)
        sdf_pred_list.append(sdf_pred)

        p_true_list.append(p)
        sdf_true_list.append(sdf)

    # Concatenate all results
    u_input = paddle.concat(u_input_list, axis=0).squeeze()
    v_input = paddle.concat(v_input_list, axis=0).squeeze()
    p_pred = paddle.concat(p_pred_list, axis=0)
    sdf_pred = paddle.concat(sdf_pred_list, axis=0)
    p_true = paddle.concat(p_true_list, axis=0).squeeze()
    sdf_true = paddle.concat(sdf_true_list, axis=0).squeeze()

    def compute_error(pred, y):
        return paddle.linalg.norm(
            pred.flatten(1) - y.flatten(1), axis=1, p="fro"
        ) / paddle.linalg.norm(y.flatten(1), axis=1, p="fro")

    # Compute errors
    error = compute_error(p_pred, p_true)
    logger.info(f"Mean relative p error: {paddle.mean(error).item():.4f}")
    logger.info(f"Max relative p error: {paddle.max(error).item():.4f}")
    logger.info(f"Min relative p error: {paddle.min(error).item():.4f}")
    logger.info(f"Std relative p error: {paddle.std(error, unbiased=True).item():.4f}")

    error = compute_error(sdf_pred, sdf_true)
    logger.info(f"Mean relative sdf error: {paddle.mean(error).item():.4f}")
    logger.info(f"Max relative sdf error: {paddle.max(error).item():.4f}")
    logger.info(f"Min relative sdf error: {paddle.min(error).item():.4f}")
    logger.info(
        f"Std relative sdf error: {paddle.std(error, unbiased=True).item():.4f}"
    )

    for k in range(u_input.shape[0]):
        if k >= 4:
            break

        fig = plt.figure(figsize=(20, 5))
        gs = gridspec.GridSpec(
            2, 5, width_ratios=[0.8, 1, 1, 1, 0.05], wspace=0.3, hspace=0.3
        )

        ax_input_u = fig.add_subplot(gs[0, 0])
        ax_input_u.set_title("Input U")
        im = ax_input_u.imshow(u_input[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_input_u)

        ax_input_v = fig.add_subplot(gs[1, 0])
        ax_input_v.set_title("Input V")
        im = ax_input_v.imshow(v_input[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_input_v)

        # Reference / Prediction / Error of P
        ax_ref = fig.add_subplot(gs[0, 1])
        ax_ref.set_title("Reference P")
        im = ax_ref.imshow(p_true[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_ref)

        ax_pred = fig.add_subplot(gs[0, 2])
        ax_pred.set_title("Prediction P")
        im = ax_pred.imshow(p_pred[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_pred)

        ax_err = fig.add_subplot(gs[0, 3])
        ax_err.set_title("Absolute Error P")
        im = ax_err.imshow(paddle.abs(p_pred[k].T - p_true[k].T), cmap="jet")
        plt.colorbar(im, ax=ax_err)

        # Reference / Prediction / Error of SDF
        ax_ref2 = fig.add_subplot(gs[1, 1])
        ax_ref2.set_title("Reference SDF")
        im = ax_ref2.imshow(sdf_true[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_ref2)

        ax_pred2 = fig.add_subplot(gs[1, 2])
        ax_pred2.set_title("Prediction SDF")
        im = ax_pred2.imshow(sdf_pred[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_pred2)

        ax_err2 = fig.add_subplot(gs[1, 3])
        ax_err2.set_title("Absolute Error SDF")
        im = ax_err2.imshow(paddle.abs(sdf_pred[k].T - sdf_true[k].T), cmap="jet")
        plt.colorbar(im, ax=ax_err2)

        plt.tight_layout()
        plt.savefig(osp.join(cfg.output_dir, f"result_of_sample_{k}.png"), dpi=300)
        plt.close()


def export(cfg: DictConfig):
    raise NotImplementedError


def inference(cfg: DictConfig):
    raise NotImplementedError


@hydra.main(version_base=None, config_path="./conf", config_name="fae.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        if cfg.stage == "fae":
            train_fae(cfg)
        elif cfg.stage == "dit":
            train_diffusion(cfg)
        else:
            raise ValueError(
                f"cfg.stage should be 'fae' or 'dit', but got '{cfg.stage}'"
            )
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()
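
The sample_ode helper inside evaluate integrates the learned velocity field with the forward Euler method, z <- z + v(z, t) * dt, from t = 0 to t = 1. The standalone sketch below shows the same loop on plain lists; the velocity function is a toy stand-in for the trained DiT that returns the exact straight-line velocity toward a fixed target.

```python
def sample_ode(z0, velocity, num_steps):
    """Forward Euler integration of dz/dt = velocity(z, t) over t in [0, 1]."""
    dt = 1.0 / num_steps
    z = list(z0)
    for i in range(num_steps):
        t = i / num_steps
        v = velocity(z, t)
        z = [zi + vi * dt for zi, vi in zip(z, v)]
    return z


target = [0.5, -1.0, 2.0]  # toy "data" latent
z0 = [0.0, 0.0, 0.0]       # toy starting point (noise in the real pipeline)


def toy_velocity(z, t):
    # Constant velocity along the straight line from z0 to target.
    return [b - a for a, b in zip(z0, target)]


z1 = sample_ode(z0, toy_velocity, num_steps=10)
print(z1)
```

For a perfectly straight flow like this toy one, Euler integration reaches the target for any number of steps; with a learned, imperfect velocity field, EVAL.num_steps trades accuracy against sampling cost.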

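5. Results

The model is evaluated on the test set, and some of the results are shown below.

result_of_sample_2.jpg

For the functions \(p(x, coord | u,v)\) and \(sdf(x, coord | u,v)\), the model predictions agree closely with the reference results.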

6. References