
Climateformer

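Before training or evaluation, download the ERA5 dataset files.

Before evaluation, download the pretrained model or train one yourself.

The 2018 ERA5 data used for evaluation is available for download from the following links: 2018.h5, mean.nc, std.nc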

Training:
python main.py
Evaluation (using the released pretrained weights):
python main.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/climateformer/climateformer.pdparams"
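As a quick pre-flight check, the sketch below lists the files ERA5ClimateDataset would look for in a data directory and reports any that are missing. It only mirrors the file names used in the dataset code (mean.nc, std.nc and one <year>.h5 per year). The helper name missing_era5_files and the path ./data/era5 are hypothetical, and the years you pass should match the TRAIN_YEARS or EVAL_YEARS in your config.

```python
import os
from typing import List, Sequence


def missing_era5_files(file_path: str, years: Sequence[str]) -> List[str]:
    """List the files ERA5ClimateDataset would read but cannot find.

    The dataset opens mean.nc and std.nc in file_path, then one <year>.h5
    file per requested year.
    """
    expected = ["mean.nc", "std.nc"] + [f"{year}.h5" for year in years]
    return [
        name
        for name in expected
        if not os.path.exists(os.path.join(file_path, name))
    ]


if __name__ == "__main__":
    # Hypothetical local directory holding the downloaded evaluation files
    missing = missing_era5_files("./data/era5", ["2018"])
    print("all files present" if not missing else f"missing: {missing}")
```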

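1. Background

Long-term climate prediction is concerned with forecasting how the weather will change over the coming weeks to months. Such forecasts usually cover several meteorological variables, such as temperature, humidity, and wind speed, which depend on one another in complex spatiotemporal ways. Accurate climate prediction is important for disaster prevention and mitigation, agriculture, aviation and aerospace, and other fields. Traditional forecasting relies mainly on physical equations and numerical weather prediction (NWP); with the rapid progress of deep learning, data-driven models have shown increasingly strong predictive skill.

Climateformer is a spatiotemporal deep learning framework for long-term climate prediction, designed to learn and emulate how the climate system evolves over long periods. It consists of three modules. The encoder maps multi-source climate variables from different parts of the Earth system (such as sea surface temperature, pressure, and wind fields) to a latent representation of the current "climate state". The core evolution module, built from Transformer blocks, captures long-range temporal dependencies between these states across weeks or even months. The decoder maps the evolved state to averages of key climate variables over several future periods (in this case, weekly means of the input variables). Climateformer aims to make multi-variable climate forecasting more efficient and accurate, providing more reliable data for meteorological services.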

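2. Model Principles

This section briefly introduces how Climateformer works.

2.1 Encoder

This module applies a 3×3 convolution (ConvSC), an overlapping patch embedding that downsamples the feature map by a factor of 4, and one Transformer block with spatial-reduction attention, to extract the spatial features of each input frame: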

ppsci/arch/climateformer.py
class Encoder(nn.Layer):
    def __init__(self, C_in: int, C_hid: int, N_S: int):
        super().__init__()
        strides = stride_generator(N_S)

        self.enc0 = ConvSC(C_in, C_hid, stride=strides[0])
        self.enc1 = OverlapPatchEmbed(
            img_size=256, patch_size=7, stride=4, in_chans=C_hid, embed_dim=C_hid
        )
        self.enc2 = Block(
            dim=C_hid,
            num_heads=1,
            mlp_ratio=4,
            qkv_bias=None,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            sr_ratio=8,
        )
        self.norm1 = nn.LayerNorm(C_hid)

    def forward(self, x):
        B = x.shape[0]
        latent = []
        x = self.enc0(x)
        latent.append(x)
        x, H, W = self.enc1(x)
        x = self.enc2(x, H, W)
        x = self.norm1(x)
        x = x.reshape([B, H, W, -1]).transpose(perm=[0, 3, 1, 2]).contiguous()
        latent.append(x)

        return latent
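The spatial size of the encoder output follows from the standard convolution size formula. The sketch below traces it for a 192 × 256 input (the size used in the Climateformer docstring example) and also counts the key/value tokens left after spatial-reduction attention with sr_ratio=8. The helper names are illustrative and are not part of ppsci.

```python
from typing import Tuple


def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a Conv2D along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_grid(h: int, w: int, sr_ratio: int = 8) -> Tuple[Tuple[int, int], int, int]:
    """Return ((H', W'), query tokens, key/value tokens) for one encoder pass."""
    # enc0: ConvSC, 3x3 kernel, padding 1, stride strides[0] == 1 -> same size
    h, w = conv2d_out(h, 3, 1, 1), conv2d_out(w, 3, 1, 1)
    # enc1: OverlapPatchEmbed, patch_size 7, stride 4, padding 7 // 2 == 3
    h, w = conv2d_out(h, 7, 4, 3), conv2d_out(w, 7, 4, 3)
    # Attention.sr: Conv2D with kernel_size == stride == sr_ratio, no padding
    hk, wk = conv2d_out(h, sr_ratio, sr_ratio, 0), conv2d_out(w, sr_ratio, sr_ratio, 0)
    return (h, w), h * w, hk * wk


print(encoder_grid(192, 256))  # ((48, 64), 3072, 48)
```

So each of the 3072 query tokens attends to only 48 reduced key/value tokens, which is what keeps attention affordable at this resolution.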

2.2 Evolution Module

This module stacks N_T Transformer blocks (N_T defaults to 4) to learn global temporal dynamics:

ppsci/arch/climateformer.py
class MidXnet(nn.Layer):
    def __init__(
        self,
        channel_in: int,
        channel_hid: int,
        N_T: int,
        incep_ker: Tuple[int, ...] = (3, 5, 7, 11),
        groups: int = 8,
    ):
        super().__init__()

        self.N_T = N_T
        dpr = [x.item() for x in np.linspace(0, 0.1, N_T)]
        enc_layers = []
        for i in range(N_T):
            enc_layers.append(
                Block(
                    dim=channel_in,
                    num_heads=4,
                    mlp_ratio=4,
                    qkv_bias=None,
                    qk_scale=None,
                    drop=0.0,
                    attn_drop=0.0,
                    drop_path=dpr[i],
                    norm_layer=nn.LayerNorm,
                    sr_ratio=8,
                )
            )

        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        # B TC H W

        x = x.reshape([B, T * C, H, W])
        # B HW TC
        x = x.flatten(2).transpose(perm=[0, 2, 1])

        # encoder
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z, H, W)

        return z
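MidXnet folds time into the channel axis before attention, so each token is one spatial location whose feature vector concatenates all T frames. The NumPy sketch below reproduces the reshapes from MidXnet.forward and the inverse used in Climateformer.forward, and checks that the round trip recovers the [B, T·C, H, W] layout. NumPy stands in for Paddle here; both reshape in row-major order. The function names are hypothetical.

```python
import numpy as np


def to_tokens(x: np.ndarray) -> np.ndarray:
    """[B, T, C, H, W] -> [B, H*W, T*C], mirroring MidXnet.forward."""
    b, t, c, h, w = x.shape
    x = x.reshape(b, t * c, h, w)   # fold time into channels
    x = x.reshape(b, t * c, h * w)  # flatten(2)
    return x.transpose(0, 2, 1)     # transpose(perm=[0, 2, 1])


def from_tokens(z: np.ndarray, h: int, w: int) -> np.ndarray:
    """[B, H*W, T*C] -> [B, T*C, H, W], mirroring Climateformer.forward."""
    return z.transpose(0, 2, 1).reshape(z.shape[0], -1, h, w)


x = np.random.rand(2, 6, 4, 3, 5)  # B=2, T=6, C=4, H=3, W=5
tokens = to_tokens(x)
print(tokens.shape)  # (2, 15, 24)
print(np.array_equal(from_tokens(tokens, 3, 5), x.reshape(2, 24, 3, 5)))  # True
```

Because time lives in the channel axis, attention mixes information across spatial locations, while the linear projections inside each block mix the T frames.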

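2.3 Decoder

This module uses N_S ConvSC layers (transposed convolutions wherever the stride is 2) followed by a 1×1 readout convolution to decode the spatiotemporal representation into future values of multiple meteorological variables: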

ppsci/arch/climateformer.py
class Decoder(nn.Layer):
    def __init__(self, C_hid: int, C_out: int, N_S: int):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)

        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True),
        )
        self.readout = nn.Conv2D(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        for i in range(0, len(self.dec)):
            hid = self.dec[i](hid)
        Y = self.readout(hid)
        return Y
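Only the stride-2 entries of the schedule change the spatial size. The sketch below copies stride_generator from ppsci/arch/climateformer.py and applies the Conv2DTranspose size formula with the kernel, padding and output_padding values that ConvSC uses; decoder_out and conv_transpose2d_out are hypothetical helpers. By this arithmetic, N_S=4 yields two stride-2 layers, i.e. a 4× upsampling that undoes the stride-4 patch embedding in the encoder.

```python
from typing import List


def stride_generator(N: int, reverse: bool = False) -> List[int]:
    # Copied from ppsci/arch/climateformer.py: the schedule 1, 2, 1, 2, ...
    strides = [1, 2] * 10
    if reverse:
        return list(reversed(strides[:N]))
    return strides[:N]


def conv_transpose2d_out(size: int, kernel: int, stride: int, padding: int,
                         output_padding: int) -> int:
    """Output length of a Conv2DTranspose along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def decoder_out(size: int, n_s: int) -> int:
    """Spatial size after the decoder's ConvSC stack (hypothetical helper)."""
    for s in stride_generator(n_s, reverse=True):
        if s == 1:
            # ConvSC uses a plain 3x3 conv with padding 1 here: size unchanged
            continue
        # ConvSC with stride 2: kernel 3, padding 1, output_padding = stride // 2
        size = conv_transpose2d_out(size, 3, s, 1, s // 2)
    return size


print(decoder_out(48, 4), decoder_out(64, 4))  # 192 256
```

With N_S=3 or N_S=6 the factor would be 2× or 8×, the output would no longer match H × W, and the final reshape in Climateformer.forward would fail, so N_S appears to be tied to the encoder's fixed stride of 4.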

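2.4 Climateformer Architecture

The overall structure of the model is shown in the figure below:

Figure: Climateformer network architecture

Climateformer first merges the batch and time axes, so that each input frame (the past several weekly means of multiple meteorological variables) is encoded separately by the feature embedding layers: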

ppsci/arch/climateformer.py
# encoded
embed = self.enc(x)
_, C_4, H_4, W_4 = embed[-1].shape

The evolution module then learns the dynamics of these spatial features and predicts the features of the next several weekly-mean frames:

ppsci/arch/climateformer.py
# translator
z = embed[-1].reshape([B, T, C_4, H_4, W_4])
hid = self.hid1(z)
hid = hid.transpose(perm=[0, 2, 1]).reshape([B, -1, H_4, W_4])

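Finally, a stack of convolutional layers decodes the spatiotemporal representation into weekly means of multiple meteorological variables for the coming weeks to months. The low-level encoder features embed[0] are also passed to the decoder, although the Decoder.forward shown in this document does not use them: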

ppsci/arch/climateformer.py
# decoded
y = self.dec(hid, embed[0])
y = y.reshape([B, T, self.num_classes, H, W])
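The readout layer emits T · num_classes channels, and the final reshape splits them time-major: readout channel k becomes time step t = k // num_classes and variable c = k % num_classes. The NumPy sketch below fills each channel with its own index to make that ordering visible; the sizes are arbitrary illustration values.

```python
import numpy as np

B, T, NUM_CLASSES, H, W = 1, 6, 4, 2, 3

# Stand-in for the readout output [B, T * num_classes, H, W]: every channel
# is filled with its own channel index.
readout = np.broadcast_to(
    np.arange(T * NUM_CLASSES).reshape(1, -1, 1, 1), (B, T * NUM_CLASSES, H, W)
)

# Same split as y.reshape([B, T, self.num_classes, H, W]) in Climateformer
y = readout.reshape(B, T, NUM_CLASSES, H, W)

t, c = 2, 3
print(y[0, t, c, 0, 0], t * NUM_CLASSES + c)  # 11 11
```

Each sample's prediction therefore has the same [T, C, H, W] layout as the dataset labels, presumably with num_classes equal to the 12 label channels used in this case.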

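3. Model Training

3.1 Dataset

This case uses a preprocessed ERA5Climate dataset, a subset of the ERA5 reanalysis. ERA5Climate contains a variety of atmospheric, land, and ocean variables over a study region spanning 140°E to 70°W and 55°N to the equator, at a spatial resolution of 0.25°. The data cover 2016 to 2020 with hourly estimates of the weather state, which also makes them suitable for short- to medium-range multi-variable prediction. In this case the time step is one week: each frame is the mean over 7 × 24 hours.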
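The weekly framing can be sketched as follows, following the indexing in ERA5ClimateDataset.__getitem__: a sample starting at hour `start` takes sq_length weeks of hourly frames as input and the next sq_length weeks as target, and each week is averaged over its 168 hours. The function name make_sample is hypothetical, and a tiny synthetic array replaces the real [hours, C, H, W] data.

```python
import numpy as np

HOURS_PER_WEEK = 24 * 7  # group_size in ERA5ClimateDataset


def make_sample(hourly: np.ndarray, start: int, sq_length: int = 6):
    """Return (input, target) weekly means, each shaped [sq_length, C, H, W]."""
    span = sq_length * HOURS_PER_WEEK
    x_hourly = hourly[start : start + span]
    y_hourly = hourly[start + span : start + 2 * span]
    # [sq_length * 168, C, H, W] -> [sq_length, 168, C, H, W], then mean over hours
    x = x_hourly.reshape(sq_length, HOURS_PER_WEEK, *hourly.shape[1:]).mean(axis=1)
    y = y_hourly.reshape(sq_length, HOURS_PER_WEEK, *hourly.shape[1:]).mean(axis=1)
    return x, y


# Synthetic "data": 4 weeks of hourly frames, 1 channel, 1 x 1 grid,
# where each value is simply its hour index.
hourly = np.arange(4 * HOURS_PER_WEEK, dtype=np.float64).reshape(-1, 1, 1, 1)
x, y = make_sample(hourly, start=0, sq_length=2)
print(x.ravel().tolist(), y.ravel().tolist())  # [83.5, 251.5] [419.5, 587.5]
```

Because a new sample starts at every hour rather than every week, consecutive samples overlap heavily.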

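The dataset is stored as a T × C × H × W array holding the value of each meteorological variable at each location and time. T is the length of the time series; C is the channel dimension, and this case uses temperature, relative humidity, eastward wind, and northward wind at 3 pressure levels; H and W are the height and width of the latitude-longitude grid. The data are split by year into training, validation, and test sets in a 7:2:1 ratio. The mean and standard deviation of each variable are precomputed and used to normalize the data.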
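The dataset code reshapes mean and std to [C, 1, 1], so they broadcast over a [T, C, H, W] sample and normalize each channel separately. Model outputs are therefore in normalized units. The sketch below shows the forward z-score and the inverse you would apply to read predictions in physical units; the denormalize helper and the statistics are illustrative assumptions, since the dataset code shown here only implements the forward direction.

```python
import numpy as np


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Per-channel z-score; x is [T, C, H, W], mean and std have C entries."""
    # Reshape to [C, 1, 1] so the statistics broadcast over T, H and W,
    # as ERA5ClimateDataset does with mean_ds["mean"].values.reshape(-1, 1, 1)
    return (x - mean.reshape(-1, 1, 1)) / std.reshape(-1, 1, 1)


def denormalize(z: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Inverse of normalize: map normalized values back to physical units."""
    return z * std.reshape(-1, 1, 1) + mean.reshape(-1, 1, 1)


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 12, 4, 5))       # [T, C, H, W], 12 channels as in this case
mean = np.arange(12, dtype=np.float64)   # illustrative per-channel statistics
std = np.arange(1, 13, dtype=np.float64)
z = normalize(x, mean, std)
print(np.allclose(denormalize(z, mean, std), x))  # True
```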

3.2 Model Training

3.2.1 Model Construction

This case is built on the Climateformer model. In PaddleScience code:

examples/climateformer/main.py
# set model
model = ppsci.arch.Climateformer(**cfg.MODEL)

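3.2.2 Constraint Construction

This case solves the problem in a data-driven way, so it uses PaddleScience's built-in SupervisedConstraint to build a supervised constraint. Before defining the constraint, specify the parameters it uses for data loading.

The training dataloader is configured as follows: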

examples/climateformer/main.py
# set train dataloader config
if not cfg.USE_SAMPLED_DATA:
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5ClimateDataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "size": (cfg.IMG_H, cfg.IMG_W),
            "years": cfg.TRAIN_YEARS,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 4,
    }
else:
    train_dataloader_cfg = {
        "dataset": {
            "name": "ERA5SampledDataset",
            "file_path": cfg.TRAIN_FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "years": cfg.TRAIN_YEARS,
        },
        "sampler": {
            "name": "DistributedBatchSampler",
            "drop_last": True,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": 4,
    }
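Because ERA5ClimateDataset.__len__ counts one sample per hourly start offset, the number of iterations per epoch (ITERS_PER_EPOCH, taken from len(sup_constraint.data_loader) in the full script) can be estimated in advance. The sketch assumes each <year>.h5 file holds every hour of its year, a single device, drop_last=True as configured above, and a hypothetical batch size of 16; the helper names are illustrative.

```python
import calendar
import math
from typing import Sequence

HOURS_PER_WEEK = 24 * 7


def hours_in_years(years: Sequence[int]) -> int:
    """Hours in the given years: 8784 for leap years, 8760 otherwise."""
    return sum((366 if calendar.isleap(year) else 365) * 24 for year in years)


def iters_per_epoch(total_hours: int, batch_size: int, sq_length: int = 6,
                    drop_last: bool = True) -> int:
    # Sample count as in ERA5ClimateDataset.__len__
    num_samples = total_hours - 2 * sq_length * HOURS_PER_WEEK + 1
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


hours = hours_in_years([2016, 2017, 2018])  # default training years
print(hours, iters_per_epoch(hours, batch_size=16))  # 26304 1518
```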

The supervised constraint is defined as follows:

examples/climateformer/main.py
# set constraint
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.MSELoss(),
    name="Sup",
)
constraint = {sup_constraint.name: sup_constraint}

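3.2.3 Validator Construction

During training, the model is evaluated on the validation set at a fixed interval of epochs, which requires building a validator with SupervisedValidator.

The validation dataloader is configured as follows: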

examples/climateformer/main.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5ClimateDataset",
        "file_path": cfg.VALID_FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.MODEL.output_keys,
        "training": False,
        "size": (cfg.IMG_H, cfg.IMG_W),
        "years": cfg.EVAL_YEARS,
    },
    "batch_size": cfg.EVAL.batch_size,
}

The supervised validator is defined as follows:

examples/climateformer/main.py
# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.MSELoss(),
    metric={
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "MSE": ppsci.metric.MSE(keep_batch=True),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}
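Both inputs and labels are z-scored by the dataset, so the MAE and MSE reported here are in normalized units rather than physical ones. The NumPy sketch below computes per-sample MAE and MSE over all non-batch axes. Our reading of keep_batch=True is that the metric is kept per sample instead of being averaged over the batch; that is an assumption to confirm against the ppsci.metric documentation, and per_sample_errors is a hypothetical helper.

```python
import numpy as np


def per_sample_errors(pred: np.ndarray, label: np.ndarray):
    """Return (MAE, MSE), each of shape [batch], averaged over all other axes."""
    diff = (pred - label).reshape(pred.shape[0], -1)
    return np.abs(diff).mean(axis=1), (diff ** 2).mean(axis=1)


pred = np.zeros((2, 3))
label = np.array([[1.0, 1.0, 1.0], [2.0, 0.0, -2.0]])
mae, mse = per_sample_errors(pred, label)
print(mae, mse)
```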

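3.2.4 Learning Rate and Optimizer

This case uses the Adam optimizer with a learning rate of 1e-3 decayed by a cosine schedule whose settings come from cfg.TRAIN.lr_scheduler. In PaddleScience code: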

examples/climateformer/main.py
# init optimizer and lr scheduler
lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)
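For reference, the textbook cosine annealing rule decays the learning rate from its base value to a floor eta_min along half a cosine period. The sketch below implements that formula with the 1e-3 base rate mentioned above. It is not ppsci's Cosine class, which takes its settings from cfg.TRAIN.lr_scheduler (for example epochs and iters_per_epoch, and possibly a warmup phase), so treat it only as an illustration of the curve's shape.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-3,
              eta_min: float = 0.0) -> float:
    """Cosine annealing: base_lr at step 0, eta_min at step total_steps."""
    progress = min(step, total_steps) / total_steps
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


total = 100  # e.g. epochs * iters_per_epoch for a short hypothetical run
print([round(cosine_lr(s, total), 6) for s in (0, 25, 50, 75, 100)])
```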

3.2.5 Training

With everything above in place, pass the instantiated objects to ppsci.solver.Solver and start training.

examples/climateformer/main.py
# initialize solver
solver = ppsci.solver.Solver(
    model=model,
    constraint=constraint,
    output_dir=cfg.output_dir,
    optimizer=optimizer,
    epochs=cfg.TRAIN.epochs,
    iters_per_epoch=ITERS_PER_EPOCH,
    log_freq=cfg.log_freq,
    eval_during_train=cfg.TRAIN.eval_during_train,
    eval_freq=cfg.TRAIN.eval_freq,
    validator=validator,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# train model
solver.train()

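3.2.6 Evaluation During Training

Setting the eval_during_train argument of ppsci.solver.Solver makes the solver evaluate the model on the validation set during training and automatically save the parameters that perform best on it.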

examples/climateformer/main.py
eval_during_train=cfg.TRAIN.eval_during_train,
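Conceptually, saving the best model only requires remembering the lowest validation error seen so far and writing a checkpoint whenever it improves. The class below is a hypothetical sketch of that bookkeeping with made-up metric values, not ppsci's implementation; which metric the Solver uses to rank checkpoints is not stated in this document, so ranking by "MSE" here is an assumption.

```python
from typing import Dict, Optional


class BestTracker:
    """Track the epoch with the lowest value of one validation metric."""

    def __init__(self, key: str = "MSE") -> None:
        self.key = key  # assumed ranking metric; lower is better
        self.best_value: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, metrics: Dict[str, float]) -> bool:
        """Return True when this epoch improves on the best value so far."""
        value = metrics[self.key]
        if self.best_value is None or value < self.best_value:
            self.best_value, self.best_epoch = value, epoch
            return True  # a trainer would write its "best" checkpoint here
        return False


tracker = BestTracker()
for epoch, mse in enumerate([0.52, 0.47, 0.49, 0.45], start=1):
    if tracker.update(epoch, {"MSE": mse}):
        print(f"epoch {epoch}: new best MSE {mse}")
```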

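3.3 Model Evaluation

3.3.1 Validator Construction

The test dataloader is configured as follows. It reuses the validation settings cfg.VALID_FILE_PATH and cfg.EVAL_YEARS, so the evaluated year is whichever year cfg.EVAL_YEARS specifies: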

examples/climateformer/main.py
# set eval dataloader config
eval_dataloader_cfg = {
    "dataset": {
        "name": "ERA5ClimateDataset",
        "file_path": cfg.VALID_FILE_PATH,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.MODEL.output_keys,
        "training": False,
        "size": (cfg.IMG_H, cfg.IMG_W),
        "years": cfg.EVAL_YEARS,
    },
    "batch_size": cfg.EVAL.batch_size,
}

The supervised validator is defined as follows:

examples/climateformer/main.py
# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.MSELoss(),
    metric={
        "MAE": ppsci.metric.MAE(keep_batch=True),
        "MSE": ppsci.metric.MSE(keep_batch=True),
    },
    name="Sup_Validator",
)
validator = {sup_validator.name: sup_validator}

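As with the validator used during training, the evaluation metrics here are MAE and MSE.

3.3.2 Loading the Model and Evaluating

Build the model. The pretrained parameters are loaded from cfg.EVAL.pretrained_model_path, which is passed to the solver below.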

examples/climateformer/main.py
# set model
model = ppsci.arch.Climateformer(**cfg.MODEL)

Instantiate ppsci.solver.Solver and start evaluation.

examples/climateformer/main.py
# initialize solver
solver = ppsci.solver.Solver(
    model,
    output_dir=cfg.output_dir,
    log_freq=cfg.log_freq,
    validator=validator,
    pretrained_model_path=cfg.EVAL.pretrained_model_path,
    compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
    eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
)
# evaluate
solver.eval()

4. Complete Code

Dataset interface:

ppsci/data/dataset/era5climate_dataset.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numbers
import os
import random
import time
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

try:
    import h5py
except ModuleNotFoundError:
    pass
try:
    import xarray as xr
except ModuleNotFoundError:
    pass
import numpy as np
import paddle
from paddle import io
from paddle import vision


class ERA5ClimateDataset(io.Dataset):
    """ERA5 dataset for multi-meteorological-element climate prediction (r, t, u, v).

    Args:
        file_path (str): Dataset directory containing one {year}.h5 file per
            year plus mean.nc and std.nc.
        input_keys (Tuple[str, ...]): Input dict keys, e.g. ("input",).
        label_keys (Tuple[str, ...]): Label dict keys, e.g. ("output",).
        size (Tuple[int, int]): Crop size (height, width).
        weight_dict (Optional[Dict[str, float]]): Weight dictionary. Defaults to None.
        transforms (Optional[vision.Compose]): Optional transforms. Defaults to None.
        training (bool): If in training mode (2016-2018). Else validation mode (2019).
        stride (int): Stride for sampling. Defaults to 1.
        sq_length (int): Sequence length for input and output. Defaults to 6.
        years (Optional[List[str]]): List of years to load. Defaults to None (use default years).
    """

    batch_index: bool = False

    def __init__(
        self,
        file_path: str,
        input_keys: Tuple[str, ...],
        label_keys: Tuple[str, ...],
        size: Tuple[int, ...],
        weight_dict: Optional[Dict[str, float]] = None,
        transforms: Optional[vision.Compose] = None,
        training: bool = True,
        stride: int = 1,
        sq_length: int = 6,
        years: Optional[List[str]] = None,
    ):
        super().__init__()
        self.file_path = file_path
        self.input_keys = input_keys
        self.label_keys = label_keys
        self.size = size
        self.training = training
        self.sq_length = sq_length
        self.transforms = transforms
        self.stride = stride
        self.group_size = 24 * 7  # 168 hours per week

        mean_file_path = os.path.join(self.file_path, "mean.nc")
        std_file_path = os.path.join(self.file_path, "std.nc")

        mean_ds = xr.open_dataset(mean_file_path)
        std_ds = xr.open_dataset(std_file_path)

        self.mean = mean_ds["mean"].values.reshape(-1, 1, 1)
        self.std = std_ds["std"].values.reshape(-1, 1, 1)

        print("Start loading all hourly data from the HDF5 file...")
        start_time = time.time()

        if self.training:
            years = ["2016", "2017", "2018"] if years is None else years
        else:
            years = ["2019"] if years is None else years

        all_hourly_data = []
        for year in years:
            h5_filepath = os.path.join(self.file_path, f"{year}.h5")
            if not os.path.exists(h5_filepath):
                raise FileNotFoundError(f"h5 file not found: {h5_filepath}")

            print(f"Loading {h5_filepath}...")
            with h5py.File(h5_filepath, "r") as hf:
                all_hourly_data.append(hf["data"][:])

        self.data_hourly = np.concatenate(all_hourly_data, axis=0)

        end_time = time.time()
        print("Data loaded!")
        print(
            f"Total hours: {self.data_hourly.shape[0]}, Shape: {self.data_hourly.shape}"
        )
        print(f"Estimated memory usage: {self.data_hourly.nbytes / 1e9:.2f} GB")
        print(f"Loading time: {end_time - start_time:.2f} seconds.")

        self.weight_dict = {} if weight_dict is None else weight_dict
        if weight_dict is not None:
            self.weight_dict = {key: 1.0 for key in self.label_keys}
            self.weight_dict.update(weight_dict)

    def __len__(self):
        group_size = 24 * 7  # 7 days of hourly data
        span = 2 * self.sq_length * group_size
        return self.data_hourly.shape[0] - span + 1

    def __getitem__(self, global_idx):
        x_start_hour = global_idx
        x_end_hour = x_start_hour + self.sq_length * self.group_size

        y_start_hour = x_end_hour
        y_end_hour = y_start_hour + self.sq_length * self.group_size

        x_hourly = self.data_hourly[x_start_hour:x_end_hour]
        y_hourly = self.data_hourly[y_start_hour:y_end_hour]

        x_weekly_groups = x_hourly.reshape(
            self.sq_length, self.group_size, *x_hourly.shape[1:]
        )
        y_weekly_groups = y_hourly.reshape(
            self.sq_length, self.group_size, *y_hourly.shape[1:]
        )

        x = np.mean(x_weekly_groups, axis=1)  # x.shape: (sq_length, 12, H, W)
        y = np.mean(y_weekly_groups, axis=1)  # y.shape: (sq_length, 12, H, W)

        x = (x - self.mean) / self.std
        y = (y - self.mean) / self.std

        x, y = self._random_crop(x, y)

        input_item = {self.input_keys[0]: x.astype(np.float32)}
        label_item = {self.label_keys[0]: y.astype(np.float32)}

        weight_shape = [1] * len(next(iter(label_item.values())).shape)
        weight_item = {
            key: np.full(weight_shape, value, paddle.get_default_dtype())
            for key, value in self.weight_dict.items()
        }

        if self.transforms is not None:
            input_item, label_item, weight_item = self.transforms(
                input_item, label_item, weight_item
            )

        return input_item, label_item, weight_item

    def _random_crop(self, x, y):
        if isinstance(self.size, numbers.Number):
            self.size = (int(self.size), int(self.size))

        th, tw = self.size
        h, w = y.shape[-2], y.shape[-1]  # Get the original height and width from y

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        x_cropped = x[..., y1 : y1 + th, x1 : x1 + tw]
        y_cropped = y[..., y1 : y1 + th, x1 : x1 + tw]

        return x_cropped, y_cropped

Model architecture:

ppsci/arch/climateformer.py
from typing import Optional
from typing import Tuple

import numpy as np
from paddle import nn

from ppsci.arch import base


def stride_generator(N, reverse=False):
    strides = [1, 2] * 10
    if reverse:
        return list(reversed(strides[:N]))
    else:
        return strides[:N]


class ConvSC(nn.Layer):
    def __init__(self, C_in: int, C_out: int, stride: int, transpose: bool = False):
        super(ConvSC, self).__init__()
        if stride == 1:
            transpose = False
        if not transpose:
            self.conv = nn.Conv2D(
                C_in,
                C_out,
                kernel_size=3,
                stride=stride,
                padding=1,
                weight_attr=nn.initializer.KaimingNormal(),
            )
        else:
            self.conv = nn.Conv2DTranspose(
                C_in,
                C_out,
                kernel_size=3,
                stride=stride,
                padding=1,
                output_padding=stride // 2,
                weight_attr=nn.initializer.KaimingNormal(),
            )
        self.norm = nn.GroupNorm(2, C_out)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        y = self.conv(x)
        y = self.act(self.norm(y))
        return y


class OverlapPatchEmbed(nn.Layer):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 7,
        stride: int = 4,
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(perm=[0, 2, 1])
        x = self.norm(x)

        return x, H, W


class DWConv(nn.Layer):
    def __init__(self, dim: int = 768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2D(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(perm=[0, 2, 1]).reshape([B, C, H, W])
        x = self.dwconv(x)
        x = x.flatten(2).transpose(perm=[0, 2, 1])

        return x


class Mlp(nn.Layer):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: nn.Layer = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: Optional[int] = None,
        qk_scale: Optional[int] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        sr_ratio: float = 1.0,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(axis=-1)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2D(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = (
            self.q(x)
            .reshape([B, N, self.num_heads, C // self.num_heads])
            .transpose(perm=[0, 2, 1, 3])
        )

        if self.sr_ratio > 1:
            x_ = x.transpose(perm=[0, 2, 1]).reshape([B, C, H, W])
            x_ = self.sr(x_).reshape([B, C, -1]).transpose(perm=[0, 2, 1])
            x_ = self.norm(x_)
            kv = (
                self.kv(x_)
                .reshape([B, -1, 2, self.num_heads, C // self.num_heads])
                .transpose(perm=[2, 0, 3, 1, 4])
            )
        else:
            kv = (
                self.kv(x)
                .reshape([B, -1, 2, self.num_heads, C // self.num_heads])
                .transpose(perm=[2, 0, 3, 1, 4])
            )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(perm=[0, 1, 3, 2])) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(perm=[0, 2, 1, 3]).reshape([B, N, C])
        if self.sr_ratio > 1:
            # self.norm is only created when sr_ratio > 1
            x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Layer):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: Optional[int] = None,
        qk_scale: Optional[int] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: nn.Layer = nn.GELU,
        norm_layer: nn.Layer = nn.LayerNorm,
        sr_ratio: float = 1.0,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class Encoder(nn.Layer):
    def __init__(self, C_in: int, C_hid: int, N_S: int):
        super().__init__()
        strides = stride_generator(N_S)

        self.enc0 = ConvSC(C_in, C_hid, stride=strides[0])
        self.enc1 = OverlapPatchEmbed(
            img_size=256, patch_size=7, stride=4, in_chans=C_hid, embed_dim=C_hid
        )
        self.enc2 = Block(
            dim=C_hid,
            num_heads=1,
            mlp_ratio=4,
            qkv_bias=None,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            sr_ratio=8,
        )
        self.norm1 = nn.LayerNorm(C_hid)

    def forward(self, x):
        B = x.shape[0]
        latent = []
        x = self.enc0(x)
        latent.append(x)
        x, H, W = self.enc1(x)
        x = self.enc2(x, H, W)
        x = self.norm1(x)
        x = x.reshape([B, H, W, -1]).transpose(perm=[0, 3, 1, 2]).contiguous()
        latent.append(x)

        return latent


class MidXnet(nn.Layer):
    def __init__(
        self,
        channel_in: int,
        channel_hid: int,
        N_T: int,
        incep_ker: Tuple[int, ...] = (3, 5, 7, 11),
        groups: int = 8,
    ):
        super().__init__()

        self.N_T = N_T
        dpr = [x.item() for x in np.linspace(0, 0.1, N_T)]
        enc_layers = []
        for i in range(N_T):
            enc_layers.append(
                Block(
                    dim=channel_in,
                    num_heads=4,
                    mlp_ratio=4,
                    qkv_bias=None,
                    qk_scale=None,
                    drop=0.0,
                    attn_drop=0.0,
                    drop_path=dpr[i],
                    norm_layer=nn.LayerNorm,
                    sr_ratio=8,
                )
            )

        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        # B TC H W

        x = x.reshape([B, T * C, H, W])
        # B HW TC
        x = x.flatten(2).transpose(perm=[0, 2, 1])

        # encoder
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z, H, W)

        return z


# MultiDecoder
class Decoder(nn.Layer):
    def __init__(self, C_hid: int, C_out: int, N_S: int):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)

        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True),
        )
        self.readout = nn.Conv2D(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        for i in range(0, len(self.dec)):
            hid = self.dec[i](hid)
        Y = self.readout(hid)
        return Y


class Climateformer(base.Arch):
    """
    Climateformer is a class that represents a Spatial-Temporal Transformer model designed for climate prediction with multiple meteorological variables.

    Args:
        input_keys (Tuple[str, ...]): A tuple of input keys.
        output_keys (Tuple[str, ...]): A tuple of output keys.
        shape_in (Tuple[int, ...]): The shape of the input data (T, C, H, W), where
            T is the number of time steps, C is the number of channels,
            H and W are the spatial dimensions.
        hid_S (int): The number of hidden channels in the spatial encoder.
        hid_T (int): Passed to MidXnet as channel_hid; not used by the current MidXnet.
        N_S (int): Length of the stride schedule from stride_generator; sets the
            number of ConvSC layers in the decoder (the encoder uses only the first stride).
        N_T (int): The number of Transformer blocks in the temporal module (MidXnet).
        incep_ker (Tuple[int, ...]): Passed to MidXnet; not used by the current implementation.
        groups (int): Passed to MidXnet; not used by the current implementation.
        num_classes (int): The number of predicted meteorological variables.

    Examples:
        >>> import paddle
        >>> import ppsci
        >>> model = ppsci.arch.Climateformer(
        ...     input_keys=("input",),
        ...     output_keys=("output",),
        ...     shape_in=(6, 12, 192, 256),
        ...     hid_S=64,
        ...     hid_T=256,
        ...     N_S=4,
        ...     N_T=4,
        ...     incep_ker=(3, 5, 7, 11),
        ...     groups=8,
        ...     num_classes=4,
        ... )
        >>> input_dict = {"input": paddle.rand([8, 6, 12, 192, 256])}
        >>> output_dict = model(input_dict)
        >>> print(output_dict["output"].shape)
        [8, 6, 4, 192, 256]
    """

    def __init__(
        self,
        input_keys: Tuple[str, ...],
        output_keys: Tuple[str, ...],
        shape_in: Tuple[int, ...],
        hid_S: int = 64,
        hid_T: int = 256,
        N_S: int = 4,
        N_T: int = 4,
        incep_ker: Tuple[int, ...] = (3, 5, 7, 11),
        groups: int = 8,
        num_classes: int = 12,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.num_classes = num_classes

        T, C, H, W = shape_in
        self.enc = Encoder(C, hid_S, N_S)
        self.hid1 = MidXnet(T * hid_S, hid_T // 2, N_T, incep_ker, groups)
        self.dec = Decoder(T * hid_S, T * self.num_classes, N_S)

    def forward(self, x):
        if self._input_transform is not None:
            x = self._input_transform(x)

        # keep the input dict: _output_transform expects (input_dict, output_dict)
        input_dict = x
        x = self.concat_to_tensor(x, self.input_keys)

        B, T, C, H, W = x.shape
        x = x.reshape([B * T, C, H, W])

        # encoded
        embed = self.enc(x)
        _, C_4, H_4, W_4 = embed[-1].shape

        # translator
        z = embed[-1].reshape([B, T, C_4, H_4, W_4])
        hid = self.hid1(z)
        hid = hid.transpose(perm=[0, 2, 1]).reshape([B, -1, H_4, W_4])

        # decoded
        y = self.dec(hid, embed[0])
        y = y.reshape([B, T, self.num_classes, H, W])

        y = self.split_to_dict(y, self.output_keys)
        if self._output_transform is not None:
            y = self._output_transform(input_dict, y)

        return y  # {self.output_keys[0]: Y}

Training script:

examples/climateformer/main.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hydra
import utils as utils
from omegaconf import DictConfig

import ppsci


def train(cfg: DictConfig):
    # set train dataloader config
    if not cfg.USE_SAMPLED_DATA:
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5ClimateDataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.input_keys,
                "label_keys": cfg.MODEL.output_keys,
                "size": (cfg.IMG_H, cfg.IMG_W),
                "years": cfg.TRAIN_YEARS,
            },
            "sampler": {
                "name": "BatchSampler",
                "drop_last": True,
                "shuffle": True,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 4,
        }
    else:
        train_dataloader_cfg = {
            "dataset": {
                "name": "ERA5SampledDataset",
                "file_path": cfg.TRAIN_FILE_PATH,
                "input_keys": cfg.MODEL.input_keys,
                "label_keys": cfg.MODEL.output_keys,
                "years": cfg.TRAIN_YEARS,
            },
            "sampler": {
                "name": "DistributedBatchSampler",
                "drop_last": True,
                "shuffle": True,
            },
            "batch_size": cfg.TRAIN.batch_size,
            "num_workers": 4,
        }

    # set constraint
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.MSELoss(),
        name="Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5ClimateDataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "training": False,
            "size": (cfg.IMG_H, cfg.IMG_W),
            "years": cfg.EVAL_YEARS,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "MSE": ppsci.metric.MSE(keep_batch=True),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.Climateformer(**cfg.MODEL)

    # init optimizer and lr scheduler
    lr_scheduler_cfg = dict(cfg.TRAIN.lr_scheduler)
    lr_scheduler_cfg.update({"iters_per_epoch": ITERS_PER_EPOCH})
    lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine(**lr_scheduler_cfg)()

    optimizer = ppsci.optimizer.Adam(lr_scheduler)(model)

    # initialize solver
    solver = ppsci.solver.Solver(
        model=model,
        constraint=constraint,
        output_dir=cfg.output_dir,
        optimizer=optimizer,
        epochs=cfg.TRAIN.epochs,
        iters_per_epoch=ITERS_PER_EPOCH,
        log_freq=cfg.log_freq,
        eval_during_train=cfg.TRAIN.eval_during_train,
        eval_freq=cfg.TRAIN.eval_freq,
        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()
    # evaluate after finished training
    solver.eval()


def evaluate(cfg: DictConfig):
    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "ERA5ClimateDataset",
            "file_path": cfg.VALID_FILE_PATH,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "training": False,
            "size": (cfg.IMG_H, cfg.IMG_W),
            "years": cfg.EVAL_YEARS,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.MSELoss(),
        metric={
            "MAE": ppsci.metric.MAE(keep_batch=True),
            "MSE": ppsci.metric.MSE(keep_batch=True),
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set model
    model = ppsci.arch.Climateformer(**cfg.MODEL)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        output_dir=cfg.output_dir,
        log_freq=cfg.log_freq,
        validator=validator,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # evaluate
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="climateformer.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")


if __name__ == "__main__":
    main()
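Both `train` and `evaluate` build a `SupervisedValidator` with `MAE(keep_batch=True)` and `MSE(keep_batch=True)`. As we understand the ppsci option, `keep_batch=True` keeps one metric value per sample instead of averaging over the whole batch.

The NumPy sketch below illustrates that per-sample reduction on arrays shaped `[B, T, C, H, W]`. It is a simplified illustration, not ppsci's implementation; `mae_per_sample` and `mse_per_sample` are hypothetical names.

```python
import numpy as np


def mae_per_sample(pred, label):
    """Mean absolute error over every axis except the batch axis -> shape [B]."""
    axes = tuple(range(1, pred.ndim))
    return np.abs(pred - label).mean(axis=axes)


def mse_per_sample(pred, label):
    """Mean squared error over every axis except the batch axis -> shape [B]."""
    axes = tuple(range(1, pred.ndim))
    return ((pred - label) ** 2).mean(axis=axes)


# Two samples on a tiny grid: sample 0 is off by +1 everywhere, sample 1 by -3.
label = np.zeros((2, 6, 12, 4, 4), dtype=np.float32)
pred = label.copy()
pred[0] += 1.0
pred[1] -= 3.0

print(mae_per_sample(pred, label))  # [1. 3.]
print(mse_per_sample(pred, label))  # [1. 9.]
```

Keeping the batch axis shows which samples are hardest to forecast. Because every sample has the same number of elements, the mean of the per-sample values still equals the ordinary batch-averaged metric.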

Configuration file:

examples/climateformer/conf/climateformer.yaml
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

hydra:
  run:
    # dynamic output directory according to running time and override name
    dir: outputs_climateformer/${now:%Y-%m-%d}/${now:%H-%M-%S}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 1024
output_dir: ${hydra:run.dir}
log_freq: 20 # 20

# set training hyper-parameters
SQ_LEN: 6
IMG_H: 192
IMG_W: 256
USE_SAMPLED_DATA: false

# data years settings
TRAIN_YEARS: ["2018"]
EVAL_YEARS: ["2018"]

# set train data path
TRAIN_FILE_PATH: /data/ERA5/
DATA_MEAN_PATH: /data/ERA5/mean.nc
DATA_STD_PATH: /data/ERA5/std.nc

# set evaluate data path
VALID_FILE_PATH: /data/ERA5/

# model settings
MODEL:
  input_keys: ["input"]
  output_keys: ["output"]
  shape_in:
    - 6
    - 12
    - ${IMG_H}
    - ${IMG_W}

# training settings
TRAIN:
  epochs: 20  # 150
  save_freq: 5  # 20
  eval_during_train: true
  eval_freq: 5  # 20
  lr_scheduler:
    epochs: ${TRAIN.epochs}
    learning_rate: 0.001
    by_epoch: true
  batch_size: 8 # 16
  pretrained_model_path: null
  checkpoint_path: null

# evaluation settings
EVAL:
  pretrained_model_path: null
  compute_metric_by_batch: true
  eval_with_no_grad: true
  batch_size: 8 # 16
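The `lr_scheduler` block above sets `learning_rate: 0.001`, `epochs: ${TRAIN.epochs}` (20) and `by_epoch: true`. `train()` passes these values, together with `iters_per_epoch`, to `ppsci.optimizer.lr_scheduler.Cosine`.

The pure-Python sketch below evaluates the standard cosine-annealing formula once per epoch, which is what `by_epoch: true` suggests. The learning rate would then stay constant for the `ITERS_PER_EPOCH` steps of each epoch. The sketch assumes a final learning rate (`eta_min`) of 0 and no warm-up, which may not match ppsci's defaults exactly. `iters_per_epoch` is a hypothetical helper that mirrors how `drop_last=True` determines the data-loader length.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, total_epochs=20, eta_min=0.0):
    """Standard cosine annealing, evaluated once per epoch (by_epoch=True style)."""
    progress = epoch / total_epochs
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


def iters_per_epoch(num_samples, batch_size, drop_last=True):
    """Batches per epoch on one device; drop_last=True discards a final partial batch."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1


schedule = [cosine_lr(e) for e in range(20)]
print(f"{schedule[0]:.6f}")     # 0.001000
print(f"{schedule[10]:.6f}")    # 0.000500
print(f"{schedule[19]:.6f}")    # 0.000006
print(iters_per_epoch(100, 8))  # 12 (hypothetical 100 samples, batch_size 8)
```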

5. Results

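The figure below compares Climateformer's predictions with the ground truth for temperature at the 1000 hPa pressure level. The horizontal axis shows successive forecast time steps spaced one week apart. Each forecast predicts the weekly-mean values for the next 6 weeks.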
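The results description uses one-week steps and 6-week forecasts of weekly means, consistent with `SQ_LEN: 6` and the first entry of `shape_in`. A minimal NumPy sketch of how such samples could be formed takes two steps:

1. average daily fields into weekly means;
2. slide a window that pairs 6 input weeks with the following 6 target weeks.

The daily input, the helper names (`weekly_means`, `make_windows`) and the input/target pairing are all assumptions for illustration. The actual logic of `ERA5ClimateDataset` is not shown in this document.

```python
import numpy as np


def weekly_means(daily):
    """Average a daily series [D, C, H, W] into weekly means [D // 7, C, H, W].

    Trailing days that do not fill a whole week are dropped.
    """
    n_weeks = daily.shape[0] // 7
    trimmed = daily[: n_weeks * 7]
    return trimmed.reshape(n_weeks, 7, *daily.shape[1:]).mean(axis=1)


def make_windows(weekly, seq_len=6):
    """Pair seq_len input weeks with the following seq_len target weeks (hypothetical pairing)."""
    inputs, targets = [], []
    for start in range(weekly.shape[0] - 2 * seq_len + 1):
        inputs.append(weekly[start : start + seq_len])
        targets.append(weekly[start + seq_len : start + 2 * seq_len])
    return np.stack(inputs), np.stack(targets)


# Hypothetical daily data: 365 days, 12 variables, coarse 8x8 grid; each field holds its day index.
days = np.arange(365, dtype=np.float32)[:, None, None, None]
daily = days * np.ones((1, 12, 8, 8), dtype=np.float32)

weekly = weekly_means(daily)
x, y = make_windows(weekly)
print(weekly.shape, x.shape, y.shape)  # (52, 12, 8, 8) (41, 6, 12, 8, 8) (41, 6, 12, 8, 8)
```

Each window's target starts exactly where its input ends. Consecutive windows are shifted by one week, which matches the one-week spacing on the figure's horizontal axis.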

Figure: Climateformer predictions ("Pred") vs. ground truth ("GT")