
FunDiff

Warning

This document only reproduces the turbulence_mass_transfer task from the FunDiff paper.

Note

Please download the tmt.npy dataset file from https://drive.google.com/drive/folders/1GX5uG_3R-yfuP9nMIk0v7ChuEytYwYPW?usp=drive_link first.

python main.py -cn fae.yaml

python main.py -cn diffusion.yaml FAE.pretrained_model_path=/your/fae/pretrained/model/path
python main.py -cn diffusion.yaml mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/fundiff/fundiff_turbulence_mass_transfer_dit_pretrained.pdparams
Pretrained Model: fundiff_turbulence_mass_transfer_dit_pretrained.pdparams
Metrics:
Mean relative p error: 0.066
Max relative p error: 0.159
Min relative p error: 0.029
Std relative p error: 0.027
Mean relative sdf error: 0.085
Max relative sdf error: 0.307
Min relative sdf error: 0.022
Std relative sdf error: 0.0499

1. Background Introduction

Recent advances in generative models (especially diffusion models and flow matching) have achieved remarkable success in synthesizing discrete data such as images and videos. However, applying these models to physical applications remains challenging because the physical quantities of interest are continuous functions governed by complex physical laws. This paper introduces \(\textbf{FunDiff}\), a novel framework for function space generative models. FunDiff combines latent diffusion processes with functional autoencoder architectures to handle input functions with varying degrees of discretization, generate continuous functions that can be evaluated at arbitrary positions, and seamlessly integrate physical priors. These priors are enforced through architectural constraints or physics-based loss functions, ensuring that generated samples satisfy fundamental physical laws. The authors theoretically establish minimax optimality guarantees for density estimation in function space, showing that diffusion-based estimators can achieve optimal convergence rates under appropriate regularity conditions. Results demonstrate the practical effectiveness of FunDiff in various applications such as fluid dynamics and solid mechanics. Empirical results show that the authors' method is capable of generating physically consistent samples that are highly consistent with the target distribution and exhibit robustness to noisy data and low-resolution data.

2. Problem Definition

Given known physical fields \(u\) and \(v\), solve for physical fields \(p\) and \(sdf\).

3. Problem Solving

Next, we explain step by step how to express the problem in PaddleScience code and solve it with deep learning. To keep the introduction to PaddleScience short, only key steps such as model construction, equation construction, and computational domain construction are described below; for other details, please refer to the API Documentation.

3.1 Training FAE

3.1.1 FAE Model Construction

In the FunDiff model, the FAE module adopts the Perceiver architecture. Its inputs are the physical field \(x\) and the query coordinates \(coords\), and its output is the value \(u\) of that physical field at the query coordinates. The model construction code is therefore as follows:

fae.yaml
# model settings
FAE:
  input_keys: [coords, x]
  output_keys: [u]
  # x_dim: [2, 200, 100, 1]
  # c_dim: [2, 256, 512]

  encoder:
    in_dim: 1
    patch_size: [10, 5]
    emb_dim: 256
    num_latents: 256
    grid_size: [200, 100]
    depth: 8
    num_heads: 8
    mlp_ratio: 2
    layer_norm_eps: 1e-05
  decoder:
    in_dim: 2
    period: null
    fourier_freq: 1.0
    dec_emb_dim: 256
    dec_depth: 4
    dec_num_heads: 8
    mlp_ratio: 2
    num_mlp_layers: 2
    out_dim: 1
    layer_norm_eps: 1e-05
main.py
            raise


def train_fae(cfg: DictConfig):
    # Initialize model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
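
To get a feel for the sizes implied by the encoder settings above, the following minimal sketch counts the patch tokens produced from the grid. It assumes the encoder cuts the grid into non-overlapping patches, which the configuration suggests but does not state; `num_patches` is a hypothetical helper, not part of the FunDiff code.

```python
# Hypothetical helper: count the tokens an encoder would see if the grid is
# cut into non-overlapping patches (an assumption about the patch embedding).
def num_patches(grid_size, patch_size):
    count = 1
    for g, p in zip(grid_size, patch_size):
        if g % p != 0:
            raise ValueError(f"grid dimension {g} is not divisible by patch dimension {p}")
        count *= g // p
    return count


grid_size = (200, 100)  # FAE.encoder.grid_size
patch_size = (10, 5)    # FAE.encoder.patch_size
num_latents = 256       # FAE.encoder.num_latents
emb_dim = 256           # FAE.encoder.emb_dim

tokens = num_patches(grid_size, patch_size)  # 20 * 20 patches
latent_shape = (num_latents, emb_dim)        # per-sample latent code size
print(tokens, latent_shape)
```

Concatenating two such latent codes along the last axis gives a width of 512 over 256 latents, which matches the `in_dim: 512` and `seq_len: 256` settings of the DiT configuration shown later.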

3.1.2 Constraint Construction

FAE is trained as an autoencoder (encoder followed by decoder), so the label is the input field itself, evaluated at the query coordinates as \(u\).

main.py
    encoder,
    decoder,
)

# init constraint
train_dataloader_cfg = {
    "dataset": {
        "name": "TMTDataset",
        "input_keys": cfg.FAE.input_keys,
        "label_keys": cfg.FAE.output_keys,
        "data_path": cfg.DATA_PATH,
        "num_train": cfg.num_train,
        "mode": "train",
        "stage": cfg.stage,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": 0,
}

sup_cst = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.MSELoss(),
)
# NOTE: A hacky way to plug the batch parser into the constraint.
sample_batch = next(iter(sup_cst.data_loader))
_, h, w, _ = sample_batch.shape

class PostProcessDataLoader:
    def __init__(self, dataloader, parser: BatchParser):
        self.dataloader = dataloader
        self.parser = parser

    def __iter__(self):
        for batch in self.dataloader:
            yield self.parser.random_query(batch)

    def __len__(self):
        return len(self.dataloader)

sup_cst.data_loader = PostProcessDataLoader(
    sup_cst.data_loader,
    BatchParser(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        cfg.TRAIN.num_queries,
        h,
        w,
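
The excerpt above wraps the data loader so that each batch is turned into query coordinates and labels. The following NumPy sketch illustrates one way such a random query step could work. It is a hypothetical stand-in and not the actual `BatchParser.random_query` implementation; the normalized coordinate grid mirrors the one built in `evaluate`.

```python
import numpy as np


def random_query(batch, num_queries, rng):
    """Hypothetical stand-in for a query sampler; not the BatchParser code."""
    b, h, w, c = batch.shape
    # Normalized grid coordinates in [0, 1], one (x, y) pair per grid point.
    xs, ys = np.meshgrid(np.linspace(0, 1, h), np.linspace(0, 1, w), indexing="ij")
    coords_all = np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (h*w, 2)
    # Pick the same random subset of grid points for every sample in the batch.
    idx = rng.choice(h * w, size=num_queries, replace=False)
    coords = coords_all[idx]                     # (num_queries, 2)
    values = batch.reshape(b, h * w, c)[:, idx]  # (b, num_queries, c)
    inputs = {"coords": coords, "x": batch}
    labels = {"u": values}
    return inputs, labels


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 200, 100, 1)).astype("float32")
inputs, labels = random_query(batch, num_queries=1024, rng=rng)
print(inputs["coords"].shape, labels["u"].shape)
```

Sampling only `num_queries` points per step keeps the decoder cost independent of the grid resolution, while the labels remain plain samples of the input field.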

3.2 Training DiT

3.2.1 DiT Model Construction

The model construction of DiT is as follows:

diffusion.yaml
DIT:
  # input_keys: [u, v, p, sdf]
  # output_keys: [v_t_err]
  in_dim: 512
  depth: 8
  emb_dim: 512
  mlp_ratio: 2
  num_heads: 16
  seq_len: 256
  out_dim: 512
  with_condition: true
main.py
    solver.train()


def train_diffusion(cfg: DictConfig):
    # Initialize fae model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )
    # Load pretrained fae params and freeze fae
    save_load.load_pretrain(
        fae,
        cfg.FAE.pretrained_model_path,
    )
    fae.freeze()

    # Initialize dit and wrap encoder&decoder&dit into one model for convenience
    dit = DiT(**cfg.DIT)
    model = ModelWrapper(

3.2.2 Constraint Construction

In the FunDiff model, DiT is trained with the rectified flow algorithm, whose loss function is:

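\[ \mathcal{L}(\theta) = \mathbb{E}_{\mathbf{z}_0, \mathbf{z}_1, t} \left[ \left\| \hat{\mathbf{v}}_\theta(\mathbf{z}_t, t, \mathbf{z}_c) - (\mathbf{z}_1 - \mathbf{z}_0) \right\|^2 \right], \quad \mathbf{z}_t = (1 - t)\,\mathbf{z}_0 + t\,\mathbf{z}_1 \]

Here \(\mathbf{z}_0\) is Gaussian noise, \(\mathbf{z}_1\) is the latent code of the target fields \(p\) and \(sdf\), \(\mathbf{z}_c\) is the latent code of the conditioning fields \(u\) and \(v\), and \(t\) is sampled uniformly from \([0, 1]\).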

Its corresponding forward calculation implementation code is as follows:

main.py
class ModelWrapper(nn.Layer):
    def __init__(self, encoder: Encoder, decoder: Decoder, dit: DiT):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # kept in the wrapper so encoder, decoder and dit params can be loaded together at inference
        self.dit = dit

    def forward(self, batch: Dict[str, paddle.Tensor]):
        # define the forward pass for diffusion training process
        # just ignore the non-training branch code
        with paddle.no_grad():
            # data in batch have been downsampled already
            u = batch["u"]
            v = batch["v"]
            z_u = self.encoder(u)
            z_v = self.encoder(v)
            z_c = paddle.concat([z_u, z_v], axis=-1)

            if self.training:
                p = batch["p"]
                sdf = batch["sdf"]
                z_p = self.encoder(p)
                z_sdf = self.encoder(sdf)
                z_1 = paddle.concat([z_p, z_sdf], axis=-1)
                z_0 = paddle.randn(z_1.shape)  # (b, 256, 512)
                t = paddle.uniform(
                    [z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]],
                    min=0.0,
                    max=1.0,
                )
                z_t = t * (z_1 - z_0) + z_0
                v_t = z_1 - z_0
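            else:
                # a bare `raise` has no active exception here, so raise explicitly
                raise NotImplementedError("forward() only implements the training branch")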

        # only training dit
        v_t_pred = self.dit(z_t, t.flatten(), z_c)

        if self.training:
            return {
                "v_t_err": v_t - v_t_pred,
            }
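
To make the role of the regression target concrete, the NumPy sketch below builds one rectified flow training pair and then integrates the flow with Euler steps, in the same way `sample_ode` does in the complete code. The `ideal_velocity` function is a hypothetical perfect predictor standing in for the trained DiT, and the array shapes are reduced for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
z_1 = rng.standard_normal((4, 8, 16))  # stands in for the encoded target fields
z_0 = rng.standard_normal(z_1.shape)   # Gaussian noise sample

# Training pair: one time value per sample, broadcast over the other axes.
t = rng.uniform(0.0, 1.0, size=(z_1.shape[0], 1, 1))
z_t = t * (z_1 - z_0) + z_0  # point on the straight line from z_0 to z_1
v_t = z_1 - z_0              # regression target for the velocity network


def ideal_velocity(z, t):
    # Hypothetical perfect predictor: returns the constant straight-line velocity.
    return z_1 - z_0


# Euler integration from t = 0 to t = 1.
num_steps = 10
dt = 1.0 / num_steps
z = z_0.copy()
for i in range(num_steps):
    z = z + ideal_velocity(z, i / num_steps) * dt

print(np.abs(z - z_1).max())
```

With a perfect velocity field the Euler trajectory ends at `z_1`, which is why regressing onto `z_1 - z_0` lets a trained model transport noise to target latent codes at sampling time.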

The overall constraint construction is as follows:

main.py
    decoder,
    dit,
)

# init constraint
train_dataloader_cfg = {
    "dataset": {
        "name": "TMTDataset",
        "input_keys": cfg.FAE.input_keys,
        "label_keys": cfg.FAE.output_keys,
        "data_path": cfg.DATA_PATH,
        "num_train": cfg.num_train,
        "mode": "train",
        "stage": cfg.stage,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
}
sup_cst = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    loss=ppsci.loss.FunctionalLoss(
        lambda inp, label, weight: {"v_t": (inp["v_t_err"] ** 2).mean()}
    ),
)

# NOTE: A hacky way to plug the batch parser into the constraint.
sample_batch = next(iter(sup_cst.data_loader))
_, h, w, _ = sample_batch.shape

class PostProcessDataLoader:
    def __init__(self, dataloader, parser: BatchParser):
        self.dataloader = dataloader
        self.parser = parser

    def __iter__(self):
        for batch in self.dataloader:
            yield self.parser.random_downsample(batch)

    def __len__(self):
        return len(self.dataloader)

sup_cst.data_loader = PostProcessDataLoader(
    sup_cst.data_loader,
    BatchParser(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        None,
        h,
        w,
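
For DiT training, the wrapped data loader yields downsampled fields instead of query points. The sketch below shows one plausible reading of that step: pick a random factor from `TRAIN.solution` and subsample the grid with that stride. This is an assumption for illustration only, not the actual `BatchParser.random_downsample` code; the channel order (u, v, p, sdf) follows the slicing used in `evaluate`.

```python
import numpy as np


def random_downsample(batch, factors, rng):
    """Hypothetical strided downsampling; not the BatchParser implementation."""
    s = int(rng.choice(factors))
    fields = batch[:, ::s, ::s, :]
    # Split channels in the order used by evaluate(): u, v, p, sdf.
    names = ("u", "v", "p", "sdf")
    return {name: fields[..., k : k + 1] for k, name in enumerate(names)}, s


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 200, 100, 4)).astype("float32")
out, s = random_downsample(batch, factors=[1, 2, 4], rng=rng)
print(s, out["u"].shape)
```

Training on several resolutions in this way is consistent with the stated goal of handling inputs with varying degrees of discretization.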

3.3 Hyperparameter Setting

FAE uses 100,000 training steps and an initial learning rate of 0.001.

fae.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 5000
  iters_per_epoch: -1
  num_queries: 1024
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0

DiT uses 100,000 training steps and an initial learning rate of 0.001.

diffusion.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 10000
  iters_per_epoch: -1
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0

3.4 Optimizer Construction

During training, the optimizer is called to update the model parameters. Both FAE and DiT use the AdamW optimizer (ppsci.optimizer.AdamW), combined with the ExponentialDecay learning rate schedule (ppsci.optimizer.lr_scheduler.ExponentialDecay).

fae.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 5000
  iters_per_epoch: -1
  num_queries: 1024
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0
main.py
logger.debug(
    f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
)

# Create learning rate schedule and optimizer
lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
    epochs=cfg.TRAIN.epochs,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch,
    **cfg.TRAIN.lr_scheduler,
)()
optimizer = ppsci.optimizer.AdamW(
    lr,
    beta1=cfg.TRAIN.beta1,
    beta2=cfg.TRAIN.beta2,
diffusion.yaml
# training settings
TRAIN:
  steps: 100000
  epochs: 10000
  iters_per_epoch: -1
  solution: [1, 2, 4]
  save_freq: 0
  eval_during_train: false
  eval_freq: 10
  optim: adam
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8
  weight_decay: 1e-5
  clip_norm: 1.0

  lr_scheduler:
    # epochs: ${TRAIN.epochs}
    learning_rate: 0.001
    gamma: 0.9
    decay_steps: 5000
    by_epoch: false
    warmup_epoch: 2000

  batch_size: 64
  pretrained_model_path: null
  checkpoint_path: null
main.py
logger.debug(
    f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
)

# Create learning rate schedule and optimizer
lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
    epochs=cfg.TRAIN.epochs,
    iters_per_epoch=cfg.TRAIN.iters_per_epoch,
    **cfg.TRAIN.lr_scheduler,
)()
optimizer = ppsci.optimizer.AdamW(
    lr,
    beta1=cfg.TRAIN.beta1,
    beta2=cfg.TRAIN.beta2,
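
The `lr_scheduler` values above can be read as a warmup phase followed by exponential decay. The pure Python sketch below assumes a linear warmup over 2000 steps and a continuous decay of the form `base_lr * gamma ** (step / decay_steps)`; the exact behavior of `ppsci.optimizer.lr_scheduler.ExponentialDecay` may differ. It also mirrors how the complete code converts step counts into epochs, using a hypothetical `iters_per_epoch` of 125.

```python
def lr_at(step, base_lr=0.001, gamma=0.9, decay_steps=5000, warmup_steps=2000):
    """Assumed schedule: linear warmup, then continuous exponential decay."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * gamma ** (step / decay_steps)


# Step-to-epoch bookkeeping as done in main.py before the scheduler is built.
steps = 100000
iters_per_epoch = 125  # hypothetical value; the real one is len(data_loader)
epochs = steps // iters_per_epoch
warmup_epoch = 2000 / iters_per_epoch

for step in (0, 1000, 2000, 5000, 100000):
    print(step, lr_at(step))
print(epochs, warmup_epoch)
```

Because `by_epoch` is false, the configuration values are expressed in steps, which is why `warmup_epoch` is divided by `iters_per_epoch` in the training code.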

3.5 Model Training

After completing the above settings, you only need to pass the instantiated objects to ppsci.solver.Solver in order, and then start training.

main.py
    weight_decay=cfg.TRAIN.weight_decay,
    grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
)(fae)

# init solver
solver = ppsci.solver.Solver(
    fae,
    {"sup": sup_cst},
    optimizer=optimizer,
main.py
    weight_decay=cfg.TRAIN.weight_decay,
    grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
)(dit)

# init solver
solver = ppsci.solver.Solver(
    model,
    {"sup": sup_cst},
    optimizer=optimizer,

4. Complete Code

main.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
NOTE: Code below is reproduced from https://github.com/sifanexisted/fundiff with paddle backend
"""
from __future__ import annotations

from os import path as osp
from typing import Dict

import hydra
import matplotlib.pyplot as plt
import numpy as np
import paddle
from matplotlib import gridspec
from model import FAE
from model import Decoder
from model import DiT
from model import Encoder
from omegaconf import DictConfig
from paddle import nn
from tqdm import tqdm

import ppsci
from ppsci.data.dataset.tmtdataset import BatchParser
from ppsci.utils import save_load
from ppsci.utils.misc import logger

dtype = paddle.get_default_dtype()


class ModelWrapper(nn.Layer):
    def __init__(self, encoder: Encoder, decoder: Decoder, dit: DiT):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder  # kept in the wrapper so encoder, decoder and dit params can be loaded together at inference
        self.dit = dit

    def forward(self, batch: Dict[str, paddle.Tensor]):
        # define the forward pass for diffusion training process
        # just ignore the non-training branch code
        with paddle.no_grad():
            # data in batch have been downsampled already
            u = batch["u"]
            v = batch["v"]
            z_u = self.encoder(u)
            z_v = self.encoder(v)
            z_c = paddle.concat([z_u, z_v], axis=-1)

            if self.training:
                p = batch["p"]
                sdf = batch["sdf"]
                z_p = self.encoder(p)
                z_sdf = self.encoder(sdf)
                z_1 = paddle.concat([z_p, z_sdf], axis=-1)
                z_0 = paddle.randn(z_1.shape)  # (b, 256, 512)
                t = paddle.uniform(
                    [z_1.shape[0], *[1 for _ in range(z_1.ndim - 1)]],
                    min=0.0,
                    max=1.0,
                )
                z_t = t * (z_1 - z_0) + z_0
                v_t = z_1 - z_0
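            else:
                # a bare `raise` has no active exception here, so raise explicitly
                raise NotImplementedError("forward() only implements the training branch")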

        # only training dit
        v_t_pred = self.dit(z_t, t.flatten(), z_c)

        if self.training:
            return {
                "v_t_err": v_t - v_t_pred,
            }
        else:
            raise NotImplementedError("forward() only implements the training branch")


def train_fae(cfg: DictConfig):
    # Initialize model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )

    # init constraint
    train_dataloader_cfg = {
        "dataset": {
            "name": "TMTDataset",
            "input_keys": cfg.FAE.input_keys,
            "label_keys": cfg.FAE.output_keys,
            "data_path": cfg.DATA_PATH,
            "num_train": cfg.num_train,
            "mode": "train",
            "stage": cfg.stage,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 0,
    }

    sup_cst = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.MSELoss(),
    )
    # NOTE: A hacky way to plug the batch parser into the constraint.
    sample_batch = next(iter(sup_cst.data_loader))
    _, h, w, _ = sample_batch.shape

    class PostProcessDataLoader:
        def __init__(self, dataloader, parser: BatchParser):
            self.dataloader = dataloader
            self.parser = parser

        def __iter__(self):
            for batch in self.dataloader:
                yield self.parser.random_query(batch)

        def __len__(self):
            return len(self.dataloader)

    sup_cst.data_loader = PostProcessDataLoader(
        sup_cst.data_loader,
        BatchParser(
            cfg.FAE.input_keys,
            cfg.FAE.output_keys,
            cfg.TRAIN.num_queries,
            h,
            w,
            cfg.TRAIN.solution,
        ),
    )
    sup_cst.data_iter = iter(sup_cst.data_loader)

    # reset epochs & iters_per_epoch
    cfg.TRAIN.iters_per_epoch = len(sup_cst.data_loader)
    logger.debug(f"cfg.TRAIN.iters_per_epoch = {cfg.TRAIN.iters_per_epoch}")
    cfg.TRAIN.epochs = cfg.TRAIN.steps // cfg.TRAIN.iters_per_epoch
    cfg.TRAIN.lr_scheduler.warmup_epoch /= cfg.TRAIN.iters_per_epoch
    logger.debug(
        f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
    )

    # Create learning rate schedule and optimizer
    lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch,
        **cfg.TRAIN.lr_scheduler,
    )()
    optimizer = ppsci.optimizer.AdamW(
        lr,
        beta1=cfg.TRAIN.beta1,
        beta2=cfg.TRAIN.beta2,
        epsilon=cfg.TRAIN.eps,
        weight_decay=cfg.TRAIN.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
    )(fae)

    # init solver
    solver = ppsci.solver.Solver(
        fae,
        {"sup": sup_cst},
        optimizer=optimizer,
        cfg=cfg,
    )
    # train
    solver.train()


def train_diffusion(cfg: DictConfig):
    # Initialize fae model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    fae = FAE(
        cfg.FAE.input_keys,
        cfg.FAE.output_keys,
        encoder,
        decoder,
    )
    # Load pretrained fae params and freeze fae
    save_load.load_pretrain(
        fae,
        cfg.FAE.pretrained_model_path,
    )
    fae.freeze()

    # Initialize dit and wrap encoder&decoder&dit into one model for convenience
    dit = DiT(**cfg.DIT)
    model = ModelWrapper(
        encoder,
        decoder,
        dit,
    )

    # init constraint
    train_dataloader_cfg = {
        "dataset": {
            "name": "TMTDataset",
            "input_keys": cfg.FAE.input_keys,
            "label_keys": cfg.FAE.output_keys,
            "data_path": cfg.DATA_PATH,
            "num_train": cfg.num_train,
            "mode": "train",
            "stage": cfg.stage,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
    }
    sup_cst = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.FunctionalLoss(
            lambda inp, label, weight: {"v_t": (inp["v_t_err"] ** 2).mean()}
        ),
    )

    # NOTE: A hacky way to plug the batch parser into the constraint.
    sample_batch = next(iter(sup_cst.data_loader))
    _, h, w, _ = sample_batch.shape

    class PostProcessDataLoader:
        def __init__(self, dataloader, parser: BatchParser):
            self.dataloader = dataloader
            self.parser = parser

        def __iter__(self):
            for batch in self.dataloader:
                yield self.parser.random_downsample(batch)

        def __len__(self):
            return len(self.dataloader)

    sup_cst.data_loader = PostProcessDataLoader(
        sup_cst.data_loader,
        BatchParser(
            cfg.FAE.input_keys,
            cfg.FAE.output_keys,
            None,
            h,
            w,
            cfg.TRAIN.solution,
        ),
    )
    sup_cst.data_iter = iter(sup_cst.data_loader)

    # reset epochs & iters_per_epoch
    cfg.TRAIN.iters_per_epoch = len(sup_cst.data_loader)
    logger.debug(f"cfg.TRAIN.iters_per_epoch = {cfg.TRAIN.iters_per_epoch}")
    cfg.TRAIN.epochs = cfg.TRAIN.steps // cfg.TRAIN.iters_per_epoch
    cfg.TRAIN.lr_scheduler.warmup_epoch /= cfg.TRAIN.iters_per_epoch
    logger.debug(
        f"cfg.TRAIN.lr_scheduler.warmup_epoch = {cfg.TRAIN.lr_scheduler.warmup_epoch}"
    )

    # Create learning rate schedule and optimizer
    lr = ppsci.optimizer.lr_scheduler.ExponentialDecay(
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=cfg.TRAIN.iters_per_epoch,
        **cfg.TRAIN.lr_scheduler,
    )()
    optimizer = ppsci.optimizer.AdamW(
        lr,
        beta1=cfg.TRAIN.beta1,
        beta2=cfg.TRAIN.beta2,
        epsilon=cfg.TRAIN.eps,
        weight_decay=cfg.TRAIN.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(cfg.TRAIN.clip_norm),
    )(dit)

    # init solver
    solver = ppsci.solver.Solver(
        model,
        {"sup": sup_cst},
        optimizer=optimizer,
        cfg=cfg,
    )
    # train
    solver.train()


@paddle.no_grad()
def evaluate(cfg: DictConfig):
    # Initialize encoder & decoder & dit model
    encoder = Encoder(**cfg.FAE.encoder)
    decoder = Decoder(**cfg.FAE.decoder)
    dit = DiT(**cfg.DIT)
    # wrap encoder&decoder&dit into one model for convenience
    model = ModelWrapper(
        encoder,
        decoder,
        dit,
    )
    save_load.load_pretrain(model, cfg.EVAL.pretrained_model_path)

    # init evaluate data
    eval_dataset = ppsci.data.dataset.TMTDataset(
        input_keys=cfg.FAE.input_keys,
        label_keys=cfg.FAE.output_keys,
        data_path=cfg.DATA_PATH,
        num_train=cfg.num_train,
        mode="test",
        stage=cfg.stage,
    )
    eval_loader = paddle.io.DataLoader(eval_dataset, batch_size=cfg.EVAL.batch_size)

    h, w = 200, 100
    x_coords = np.linspace(0, 1, h, dtype=dtype)
    y_coords = np.linspace(0, 1, w, dtype=dtype)
    x_coords, y_coords = np.meshgrid(x_coords, y_coords, indexing="ij")
    coords = np.hstack([x_coords.reshape(-1, 1), y_coords.reshape(-1, 1)], dtype=dtype)[
        None, ...
    ]

    noise_level = 0.1
    d = 1
    u_input_list = []
    v_input_list = []
    p_pred_list = []
    sdf_pred_list = []
    p_true_list = []
    sdf_true_list = []

    def sample_ode(
        z0: paddle.Tensor = None,
        c: paddle.Tensor = None,
        num_steps: int = None,
        use_conditioning: bool = False,
    ) -> paddle.Tensor:
        dt = 1 / num_steps
        traj = [z0]

        z = z0
        for i in tqdm(range(num_steps)):
            t = paddle.ones([z.shape[0]]) * i / num_steps
            if use_conditioning:
                pred = dit(z, t, c)
            else:
                pred = dit(z, t)
            z = z + pred * dt
            traj.append(z)
        return z, traj

    for batch in tqdm(eval_loader):
        u: paddle.Tensor = batch[:, ::d, ::d, 0:1]
        v: paddle.Tensor = batch[:, ::d, ::d, 1:2]
        p: paddle.Tensor = batch[..., 2:3]
        sdf: paddle.Tensor = batch[..., 3:4]

        # add noise to input field
        u = u + noise_level * paddle.randn(u.shape)
        v = v + noise_level * paddle.randn(v.shape)

        # compute condition latent code from given field
        z_u = encoder(u)
        z_v = encoder(v)
        z_c = paddle.concat([z_u, z_v], axis=-1)  # (b, l, 2c)

        # random sample z0 from standard normal distribution
        z0 = paddle.randn(shape=z_c.shape)

        # integrate ODE
        z1_new, _ = sample_ode(
            z0=z0,
            c=z_c,
            num_steps=cfg.EVAL.num_steps,
            use_conditioning=cfg.EVAL.use_conditioning,
        )

        # decode latent code to output field
        c_dim = z_c.shape[-1]
        z_p_new = z1_new[..., : c_dim // 2]
        z_sdf_new = z1_new[..., c_dim // 2 :]

        p_pred = decoder(z_p_new, coords)
        sdf_pred = decoder(z_sdf_new, coords)

        p_pred = p_pred.reshape([-1, h, w])
        sdf_pred = sdf_pred.reshape([-1, h, w])

        u_input_list.append(u)
        v_input_list.append(v)

        p_pred_list.append(p_pred)
        sdf_pred_list.append(sdf_pred)

        p_true_list.append(p)
        sdf_true_list.append(sdf)

    # Concatenate all results
    u_input = paddle.concat(u_input_list, axis=0).squeeze()
    v_input = paddle.concat(v_input_list, axis=0).squeeze()
    p_pred = paddle.concat(p_pred_list, axis=0)
    sdf_pred = paddle.concat(sdf_pred_list, axis=0)
    p_true = paddle.concat(p_true_list, axis=0).squeeze()
    sdf_true = paddle.concat(sdf_true_list, axis=0).squeeze()

    def compute_error(pred, y):
        return paddle.linalg.norm(
            pred.flatten(1) - y.flatten(1), axis=1, p="fro"
        ) / paddle.linalg.norm(y.flatten(1), axis=1, p="fro")

    # Compute errors
    error = compute_error(p_pred, p_true)
    logger.info(f"Mean relative p error: {paddle.mean(error).item():.4f}")
    logger.info(f"Max relative p error: {paddle.max(error).item():.4f}")
    logger.info(f"Min relative p error: {paddle.min(error).item():.4f}")
    logger.info(f"Std relative p error: {paddle.std(error, unbiased=True).item():.4f}")

    error = compute_error(sdf_pred, sdf_true)
    logger.info(f"Mean relative sdf error: {paddle.mean(error).item():.4f}")
    logger.info(f"Max relative sdf error: {paddle.max(error).item():.4f}")
    logger.info(f"Min relative sdf error: {paddle.min(error).item():.4f}")
    logger.info(
        f"Std relative sdf error: {paddle.std(error, unbiased=True).item():.4f}"
    )

    for k in range(u_input.shape[0]):
        if k >= 4:
            break

        fig = plt.figure(figsize=(20, 5))
        gs = gridspec.GridSpec(
            2, 5, width_ratios=[0.8, 1, 1, 1, 0.05], wspace=0.3, hspace=0.3
        )

        ax_input_u = fig.add_subplot(gs[0, 0])
        ax_input_u.set_title("Input U")
        im = ax_input_u.imshow(u_input[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_input_u)

        ax_input_v = fig.add_subplot(gs[1, 0])
        ax_input_v.set_title("Input V")
        im = ax_input_v.imshow(v_input[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_input_v)

        # Reference / Prediction / Error of P
        ax_ref = fig.add_subplot(gs[0, 1])
        ax_ref.set_title("Reference P")
        im = ax_ref.imshow(p_true[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_ref)

        ax_pred = fig.add_subplot(gs[0, 2])
        ax_pred.set_title("Prediction P")
        im = ax_pred.imshow(p_pred[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_pred)

        ax_err = fig.add_subplot(gs[0, 3])
        ax_err.set_title("Absolute Error P")
        im = ax_err.imshow(paddle.abs(p_pred[k].T - p_true[k].T), cmap="jet")
        plt.colorbar(im, ax=ax_err)

        # Reference / Prediction / Error of SDF
        ax_ref2 = fig.add_subplot(gs[1, 1])
        ax_ref2.set_title("Reference SDF")
        im = ax_ref2.imshow(sdf_true[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_ref2)

        ax_pred2 = fig.add_subplot(gs[1, 2])
        ax_pred2.set_title("Prediction SDF")
        im = ax_pred2.imshow(sdf_pred[k].T, cmap="jet")
        plt.colorbar(im, ax=ax_pred2)

        ax_err2 = fig.add_subplot(gs[1, 3])
        ax_err2.set_title("Absolute Error SDF")
        im = ax_err2.imshow(paddle.abs(sdf_pred[k].T - sdf_true[k].T), cmap="jet")
        plt.colorbar(im, ax=ax_err2)

        plt.tight_layout()
        plt.savefig(osp.join(cfg.output_dir, f"result_of_sample_{k}.png"), dpi=300)
        plt.close()


def export(cfg: DictConfig):
    raise NotImplementedError


def inference(cfg: DictConfig):
    raise NotImplementedError


@hydra.main(version_base=None, config_path="./conf", config_name="fae.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        if cfg.stage == "fae":
            train_fae(cfg)
        elif cfg.stage == "dit":
            train_diffusion(cfg)
        else:
            raise ValueError(
                f"cfg.stage should be 'fae' or 'dit', but got {cfg.stage}"
            )
    elif cfg.mode == "eval":
        evaluate(cfg)
    elif cfg.mode == "export":
        export(cfg)
    elif cfg.mode == "infer":
        inference(cfg)
    else:
        raise ValueError(
            f"cfg.mode should be in ['train', 'eval', 'export', 'infer'], but got '{cfg.mode}'"
        )


if __name__ == "__main__":
    main()

5. Result Display

The model is evaluated on the test set; some of the results are shown below:

result_of_sample_2.jpg

It can be seen that for functions \(p(x, coord | u,v)\) and \(sdf(x, coord | u,v)\), the model's prediction results are basically consistent with the reference results.
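
The metrics reported for the pretrained model are per-sample relative L2 errors, computed by `compute_error` in the complete code. The following NumPy sketch reproduces that metric on synthetic data; `relative_l2_error` is an illustrative name and the arrays are made up for the example.

```python
import numpy as np


def relative_l2_error(pred, ref):
    # Flatten each sample and compare L2 norms, as compute_error does with paddle.
    pred = pred.reshape(pred.shape[0], -1)
    ref = ref.reshape(ref.shape[0], -1)
    return np.linalg.norm(pred - ref, axis=1) / np.linalg.norm(ref, axis=1)


# Synthetic check: scale a reference field by known factors.
ref = np.ones((3, 200, 100))
pred = ref * np.array([1.0, 1.1, 0.5]).reshape(3, 1, 1)
err = relative_l2_error(pred, ref)

# Same summary statistics as logged by evaluate(); ddof=1 matches unbiased=True.
print(err.mean(), err.max(), err.min(), err.std(ddof=1))
```

Because the error is normalized by the norm of the reference field, values for \(p\) and \(sdf\) can be compared even though the two fields have different magnitudes.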

6. References