跳转至

UNetFormer

Note

  1. 运行之前,建议快速了解一下数据集数据读取方式
  2. 将[Vaihingen数据集]下载到data目录中对应的子目录(如data/vaihingen/train_images)。
  3. 运行tools/vaihingen_patch_split.py处理原数据集,得到可供训练的数据。

文件数据集结构如下

airs
├── unetformer(code)
├── model_weights (save the model weights trained on ISPRS vaihingen)
├── fig_results (save the masks predicted by models)
├── lightning_logs (CSV format training logs)
├── data
   ├── vaihingen
      ├── train_images (original)
      ├── train_masks (original)
      ├── test_images (original)
      ├── test_masks (original)
      ├── test_masks_eroded (original)
      ├── train (processed)
      ├── test (processed)
# 将[Vaihingen数据集]下载到`data`目录中对应的子目录(如`data/vaihingen/train_images`)
# 创建训练数据集
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/train_images" --mask-dir "data/vaihingen/train_masks" --output-img-dir "data/vaihingen/train/images_1024" --output-mask-dir "data/vaihingen/train/masks_1024" --mode "train" --split-size 1024 --stride 512
# 创建测试数据集
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks_eroded" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024" --mode "val" --split-size 1024 --stride 1024 --eroded
# 创建masks_1024_rgb可视化数据集
python tools/vaihingen_patch_split.py --img-dir "data/vaihingen/test_images" --mask-dir "data/vaihingen/test_masks" --output-img-dir "data/vaihingen/test/images_1024" --output-mask-dir "data/vaihingen/test/masks_1024_rgb" --mode "val" --split-size 1024 --stride 1024 --gt
# 模型训练
python train_supervision.py -c config/vaihingen/unetformer.py
# 下载处理好的[Vaihingen测试数据集](https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip),并解压。
wget -c https://paddle-org.bj.bcebos.com/paddlescience/datasets/unetformer/test.zip -P ./data/vaihingen/
unzip -q ./data/vaihingen/test.zip -d data/vaihingen/
# 下载预训练模型文件
wget -c https://paddle-org.bj.bcebos.com/paddlescience/models/unetformer/unetformer-r18-512-crop-ms-e105_epoch0_best.pdparams -P ./model_weights/vaihingen/unetformer-r18-512-crop-ms-e105/
python vaihingen_test.py -c config/vaihingen/unetformer.py -o fig_results/vaihingen/unetformer --rgb

1. 背景简介

遥感城市场景图像的语义分割在众多实际应用中具有广泛需求,例如土地覆盖制图、城市变化检测、环境保护和经济评估等领域。在深度学习技术快速发展的推动下,卷积神经网络(CNN)多年来一直主导着语义分割领域。CNN采用分层特征表示方式,展现出强大的局部信息提取能力。然而卷积层的局部特性限制了网络捕获全局上下文信息的能力。近年来,作为计算机视觉领域的热点研究方向,Transformer架构在全局信息建模方面展现出巨大潜力,显著提升了图像分类、目标检测特别是语义分割等视觉相关任务的性能。

本文提出了一种基于Transformer的解码器架构,构建了类UNet结构的Transformer网络(UNetFormer),用于实时城市场景分割。为实现高效分割,UNetFormer选择轻量级ResNet18作为编码器,并在解码器中开发了高效的全局-局部注意力机制,以同时建模全局和局部信息。本文提出的基于Transformer的解码器与Swin Transformer编码器结合后,在Vaihingen数据集上也取得了当前最佳性能(91.3% F1分数和84.1% mIoU)。

2. 模型原理

本段落仅简单介绍模型原理,具体细节请阅读UNetFormer: A UNet-like Transformer for Efficient Semantic Segmentation of Remote Sensing Urban Scene Imagery

2.1 模型结构

UNetFormer是一种基于transformer的解码器的深度学习网络,下图显示了模型的整体结构。

UNetFormer1

  • ResBlock是resnet18网络的各个模块。

  • GLTB由全局-局部注意、MLP、两个batchnorm层和两个加和操作组成。

2.2 损失函数

判别器的损失函数由两部分组成,主损失函数\(\mathcal{L}_{\text {p }}\)为SoftCrossEntropyLoss交叉熵损失函数\(\mathcal{L}_{c e}\)和DiceLoss损失函数\(\mathcal{L}_{\text {dice }}\)。其表达式为:

\[ \mathcal{L}_{c e}=-\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} y_{k}^{(n)} \log \hat{y}_{k}^{(n)} \]
\[ \mathcal{L}_{\text {dice }}=1-\frac{2}{N} \sum_{n=1}^{N} \sum_{k=1}^{K} \frac{\hat{y}_{k}^{(n)} y_{k}^{(n)}}{\hat{y}_{k}^{(n)}+y_{k}^{(n)}} \]
\[ \mathcal{L}_{\text {p }}=\mathcal{L}_{c e}+\mathcal{L}_{\text {dice }} \]

其中N、K分别表示样本数量和类别数量。\(y^{(n)}\)\(\hat{y}^{(n)}\)表示标签的one-hot编码和相应的softmax输出,\(\mathrm{n} \in[1, \ldots, \mathrm{n}]\)

为了更好的结合,我们选择交叉熵函数作为辅助损失函数\({L}_{a u x}\),并且乘以系数\(\alpha\)总损失函数其表达式为:

\[ \mathcal{L}=\mathcal{L}_{p}+\alpha \times \mathcal{L}_{a u x} \]

其中,\(\alpha\)默认为0.4。

3. 模型构建

以下我们讲解释用PaddleScience构建UnetFormer的关键部分。

3.1 数据集介绍

数据集采用了ISPRS开源的Vaihingen数据集。

ISPRS提供了城市分类和三维建筑重建测试项目的两个最先进的机载图像数据集。该数据集采用了由高分辨率正交照片和相应的密集图像匹配技术产生的数字地表模型(DSM)。这两个数据集区域都涵盖了城市场景。Vaihingen是一个相对较小的村庄,有许多独立的建筑和小的多层建筑,该数据集包含33幅不同大小的遥感图像,每幅图像都是从一个更大的顶层正射影像图片提取的,图像选择的过程避免了出现没有数据的情况。顶层影像和DSM的空间分辨率为9 cm。遥感图像格式为8位TIFF文件,由近红外、红色和绿色3个波段组成。DSM是单波段的TIFF文件,灰度等级(对应于DSM高度)为32位浮点值编码。

image-vaihingen

每个数据集已手动分类为6个最常见的土地覆盖类别。

①不透水面 (RGB: 255, 255, 255)

②建筑物(RGB: 0, 0, 255)

③低矮植被 (RGB: 0, 255, 255)

④树木 (RGB: 0, 255, 0)

⑤汽车(RGB: 255, 255, 0)

⑥背景 (RGB: 255, 0, 0)

背景类包括水体和与其他已定义类别不同的物体(例如容器、网球场、游泳池),这些物体通常属于城市场景中的不感兴趣的语义对象。

3.2 构建dataset API

由于一份数据集由33个超大遥感图片组成组成。为了方便训练,我们自定义一个图像分割程序,将原始图片分割为1024×1024大小的可训练图片,程序代码具体信息在GeoSeg/tools/vaihingen_patch_split.py中可以看到。

3.3 模型构建

本案例的模型搭建代码如下

参数配置如下:

max_epoch = 105
ignore_index = len(CLASSES)
train_batch_size = 8
val_batch_size = 8
lr = 0.0006
weight_decay = 0.01
backbone_lr = 6e-05
backbone_weight_decay = 0.01
num_classes = len(CLASSES)
classes = CLASSES
weights_name = "unetformer-r18-512-crop-ms-e105"
weights_path = "model_weights/vaihingen/{}".format(weights_name)
test_weights_name = "unetformer-r18-512-crop-ms-e105_epoch0_best"
log_name = "vaihingen/{}".format(weights_name)
monitor = "val_F1"
monitor_mode = "max"
save_top_k = 1
save_last = True
check_val_every_n_epoch = 1
pretrained_ckpt_path = None
gpus = "auto"
resume_ckpt_path = None
net = UNetFormer(num_classes=num_classes)
loss = UnetFormerLoss(ignore_index=ignore_index)
use_aux_loss = True

3.4 loss函数

UNetFormer的损失函数由SoftCrossEntropyLoss交叉熵损失函数和DiceLoss损失函数组成

3.4.1 SoftCrossEntropyLoss

class SoftCrossEntropyLoss(paddle.nn.Layer):
    """
    Drop-in replacement for nn.CrossEntropyLoss with few additions:
    - Support of label smoothing
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: float = 0.0,
        ignore_index: Optional[int] = -100,
        dim=1,
    ):
        super().__init__()
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dim = dim

    def forward(self, input: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        log_prob = paddle.nn.functional.log_softmax(x=input, axis=self.dim)
        return label_smoothed_nll_loss(
            log_prob,
            target,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )

3.4.2 DiceLoss

class DiceLoss(paddle.nn.Layer):
    """
    Implementation of Dice loss for image segmentation task.
    It supports binary, multiclass and multilabel cases
    """

    def __init__(
        self,
        mode: str = "multiclass",
        classes: List[int] = None,
        log_loss=False,
        from_logits=True,
        smooth: float = 0.0,
        ignore_index=None,
        eps=1e-07,
    ):
        """

        :param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
        :param classes: Optional list of classes that contribute in loss computation;
        By default, all channels are included.
        :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
        :param from_logits: If True assumes input is raw logits
        :param smooth:
        :param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
        :param eps: Small epsilon for numerical stability
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert (
                mode != BINARY_MODE
            ), "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype="int64")
        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.ignore_index = ignore_index
        self.log_loss = log_loss

    def forward(self, y_pred: paddle.Tensor, y_true: paddle.Tensor) -> paddle.Tensor:
        """

        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        assert y_true.shape[0] == y_pred.shape[0]
        if self.from_logits:
            if self.mode == MULTICLASS_MODE:
                y_pred = paddle.nn.functional.log_softmax(y_pred, axis=1).exp()
            else:
                y_pred = paddle.nn.functional.log_sigmoid(x=y_pred).exp()
        bs = y_true.shape[0]
        num_classes = y_pred.shape[1]
        dims = 0, 2
        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)
            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * paddle.cast(mask, dtype="float32")
                y_true = y_true * paddle.cast(mask, dtype="float32")
        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
            if self.ignore_index is not None:
                if self.ignore_index is not None:
                    mask = y_true != self.ignore_index
                    mask = paddle.cast(mask, dtype="float32")
                    y_pred = paddle.cast(
                        y_pred * mask.unsqueeze(axis=1), dtype="float32"
                    )
                    mask_float = paddle.cast(mask, dtype=y_true.dtype)
                    masked_y_true = (y_true * mask_float).astype("int64")
                    y_true = paddle.nn.functional.one_hot(
                        num_classes=num_classes, x=masked_y_true
                    ).astype("int64")
                    mask = paddle.cast(mask, dtype="int64")
                    y_true = y_true.transpose(perm=[0, 2, 1]) * mask.unsqueeze(axis=1)
            else:
                y_true = paddle.nn.functional.one_hot(
                    num_classes=num_classes, x=y_true
                ).astype("int64")
                y_true = y_true.transpose(perm=[0, 2, 1])
        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)
            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * paddle.cast(mask, dtype="float32")
                y_true = y_true * paddle.cast(mask, dtype="float32")
        scores = soft_dice_score(
            y_pred,
            y_true.astype(dtype=y_pred.dtype),
            smooth=self.smooth,
            eps=self.eps,
            dims=dims,
        )
        if self.log_loss:
            loss = -paddle.log(x=scores.clip(min=self.eps))
        else:
            loss = 1.0 - scores
        mask = y_true.sum(axis=dims) > 0
        loss *= mask.astype(loss.dtype)
        if self.classes is not None:
            loss = loss[self.classes]
        return loss.mean()

3.4.2 JointLoss

SoftCrossEntropyLoss和DiceLoss将使用JointLoss进行组合

class JointLoss(paddle.nn.Layer):
    """
    Wrap two loss functions into one. This class computes a weighted sum of two losses.
    """

    def __init__(
        self,
        first: paddle.nn.Layer,
        second: paddle.nn.Layer,
        first_weight=1.0,
        second_weight=1.0,
    ):
        super().__init__()
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)

    def forward(self, *input):
        return self.first(*input) + self.second(*input)

3.4.2 UNetFormerLoss

class UnetFormerLoss(paddle.nn.Layer):
    def __init__(self, ignore_index=255):
        super().__init__()
        self.main_loss = JointLoss(
            SoftCrossEntropyLoss(smooth_factor=0.05, ignore_index=ignore_index),
            DiceLoss(smooth=0.05, ignore_index=ignore_index),
            1.0,
            1.0,
        )
        self.aux_loss = SoftCrossEntropyLoss(
            smooth_factor=0.05, ignore_index=ignore_index
        )

    def forward(self, logits, labels):
        if self.training and len(logits) == 2:
            logit_main, logit_aux = logits
            loss = self.main_loss(logit_main, labels) + 0.4 * self.aux_loss(
                logit_aux, labels
            )
        else:
            loss = self.main_loss(logits, labels)
        return loss

3.5 优化器构建

UNetFormer使用AdamW优化器,可直接调用paddle.optimizer.AdamW构建,代码如下:

layerwise_params = {
    "backbone.*": dict(lr=backbone_lr, weight_decay=backbone_weight_decay)
}
net_params = process_model_params(net, layerwise_params=layerwise_params)
optimizer = paddle.optimizer.AdamW(
    parameters=net_params, learning_rate=lr, weight_decay=weight_decay
)
tmp_lr = paddle.optimizer.lr.CosineAnnealingWarmRestarts(
    T_0=15, T_mult=2, learning_rate=optimizer.get_lr()
)
optimizer.set_lr_scheduler(tmp_lr)
lr_scheduler = tmp_lr

3.6 模型训练

    checkpoint_callback = ModelCheckpoint(
        save_top_k=config.save_top_k,
        monitor=config.monitor,
        save_last=config.save_last,
        mode=config.monitor_mode,
        dirpath=config.weights_path,
        filename=config.weights_name,
    )

    logger = CSVLogger("lightning_logs", name=config.log_name)

    model = Supervision_Train(config)

    if config.pretrained_ckpt_path:
        state_dict = paddle.load(config.pretrained_ckpt_path)
        model.set_state_dict(state_dict)

    paddle.set_device("gpu")

    optimizer, lr_scheduler = model.configure_optimizers()

    train_loader = model.train_dataloader()
    val_loader = model.val_dataloader()

    for epoch in range(config.max_epoch):
        print(f"Epoch {epoch+1}/{config.max_epoch}")
        model.train()
        train_losses = []
        for batch_idx, batch in enumerate(train_loader):
            output = model.training_step(batch, batch_idx)
            loss = output["loss"]
            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            if batch_idx % 10 == 0:
                print(
                    f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}"
                )

        train_log = model.on_train_epoch_end()
        train_log["loss"] = np.mean(train_losses)
        if (epoch + 1) % config.check_val_every_n_epoch == 0:
            model.eval()
            val_losses = []
            for batch_idx, batch in enumerate(val_loader):
                output = model.validation_step(batch, batch_idx)
                val_losses.append(output["loss_val"].item())
            val_log = model.on_validation_epoch_end()
            val_log["loss_val"] = np.mean(val_losses)
            checkpoint_callback.on_validation_epoch_end(None, model, val_log)
            logger.log_metrics(epoch, train_log, val_log)
        if lr_scheduler:
            lr_scheduler.step()
        if config.resume_ckpt_path and epoch == 0:
            state = paddle.load(config.resume_ckpt_path)
            model.set_state_dict(state["model_state_dict"])
            optimizer.set_state_dict(state["optimizer_state_dict"])
            if lr_scheduler and "lr_scheduler_state_dict" in state:
                lr_scheduler.set_state_dict(state["lr_scheduler_state_dict"])
            print(f"Resumed training from checkpoint: {config.resume_ckpt_path}")


if __name__ == "__main__":
    main()

3.7 模型测试

def main():
    seed_everything(42)
    args = get_args()
    config = py2cfg(args.config_path)
    args.output_path.mkdir(exist_ok=True, parents=True)
    model = Supervision_Train.load_from_checkpoint(
        os.path.join(config.weights_path, config.test_weights_name + ".pdparams"),
        config=config,
    )
    model.eval()
    evaluator = Evaluator(num_class=config.num_classes)
    evaluator.reset()
    test_dataset = config.test_dataset
    test_loader = paddle.io.DataLoader(
        dataset=test_dataset,
        batch_size=2,
        num_workers=4,
        drop_last=False,
        shuffle=False,
    )

    results = []
    with paddle.no_grad():
        for batch in tqdm(test_loader):
            images = batch["img"]
            images = images.astype("float32")
            raw_predictions = model(images)

            raw_predictions = paddle.nn.functional.softmax(raw_predictions, axis=1)
            predictions = raw_predictions.argmax(axis=1)

            image_ids = batch["img_id"]
            masks_true = batch["gt_semantic_seg"]

            for i in range(len(image_ids)):
                mask = predictions[i].numpy()
                evaluator.add_batch(pre_image=mask, gt_image=masks_true[i].numpy())
                mask_name = image_ids[i]
                results.append((mask, str(args.output_path / mask_name), args.rgb))

    iou_per_class = evaluator.Intersection_over_Union()
    f1_per_class = evaluator.F1()
    OA = evaluator.OA()

    for class_name, class_iou, class_f1 in zip(
        config.classes, iou_per_class, f1_per_class
    ):
        print(f"F1_{class_name}: {class_f1:.4f}, IOU_{class_name}: {class_iou:.4f}")

    print(
        f"F1: {np.nanmean(f1_per_class[:-1]):.4f}, "
        f"mIOU: {np.nanmean(iou_per_class[:-1]):.4f}, "
        f"OA: {OA:.4f}"
    )

    t0 = time.time()
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(img_writer, results)
    t1 = time.time()
    print(f"Images writing time: {t1 - t0:.2f} seconds")

4. 结果展示

使用Vaihingen数据集的训练结果。

F1 mIOU OA
0.9062 0.8318 0.9283

image-vaihingen1

image-vaihingen2

两张图片对比可以看出模型已经精确地分割出遥感图片中建筑、树木、汽车等物体的轮廓,并且很好地处理了重叠区域。

6. 参考文献