跳转至

Data.batch_transform(批预处理) 模块

ppsci.data.process.batch_transform

FunctionalBatchTransform

Functional batch transform class, which allows using a custom batch transform function given by transform_func for special cases.

Parameters:

Name Type Description Default
transform_func Callable

Function of batch data transform.

required

Examples:

>>> import numpy as np
>>> import ppsci
>>> from typing import Dict, List, Optional, Tuple
>>> def batch_transform_func(
...     data_list: List[
...         Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
...     ],
... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
...     input_dicts, label_dicts, weight_dicts = zip(*data_list)
...
...     for input_dict in input_dicts:
...         for key in input_dict:
...             input_dict[key] = input_dict[key] * 2
...
...     for label_dict in label_dicts:
...         for key in label_dict:
...             label_dict[key] = label_dict[key] + 1.0
...
...     return list(zip(input_dicts, label_dicts, weight_dicts))
...
>>> # Create a FunctionalBatchTransform object with the batch_transform_func function
>>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
>>> # Define some sample data, labels, and weights
>>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
>>> transformed_data = transform(data)
>>> for sample in transformed_data:
...     print(sample)
({'x': 2}, {'y': 3.0}, None)
({'x': 22}, {'y': 23.0}, None)
Source code in ppsci/data/process/batch_transform/preprocess.py
class FunctionalBatchTransform:
    """Functional batch transform class, which allows using a custom batch transform function given by ``transform_func`` for special cases.

    Calling an instance forwards the whole list of samples to ``transform_func``
    and returns its result unchanged.

    Args:
        transform_func (Callable): Function of batch data transform. It receives
            the list of samples and returns the transformed list of samples.

    Examples:
        >>> import numpy as np
        >>> import ppsci
        >>> from typing import Dict, List, Optional, Tuple
        >>> def batch_transform_func(
        ...     data_list: List[
        ...         Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]
        ...     ],
        ... ) -> List[Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[Dict[str, np.ndarray]]]]:
        ...     input_dicts, label_dicts, weight_dicts = zip(*data_list)
        ...
        ...     for input_dict in input_dicts:
        ...         for key in input_dict:
        ...             input_dict[key] = input_dict[key] * 2
        ...
        ...     for label_dict in label_dicts:
        ...         for key in label_dict:
        ...             label_dict[key] = label_dict[key] + 1.0
        ...
        ...     return list(zip(input_dicts, label_dicts, weight_dicts))
        ...
        >>> # Create a FunctionalBatchTransform object with the batch_transform_func function
        >>> transform = ppsci.data.batch_transform.FunctionalBatchTransform(batch_transform_func)
        >>> # Define some sample data, labels, and weights
        >>> data = [({'x': 1}, {'y': 2}, None), ({'x': 11}, {'y': 22}, None)]
        >>> transformed_data = transform(data)
        >>> for sample in transformed_data:
        ...     print(sample)
        ({'x': 2}, {'y': 3.0}, None)
        ({'x': 22}, {'y': 23.0}, None)
    """

    def __init__(
        self,
        transform_func: Callable[[List[Any]], List[Any]],
    ):
        """Store the user-supplied batch transform function.

        Args:
            transform_func (Callable[[List[Any]], List[Any]]): Function applied
                to the whole list of samples on each call.
        """
        self.transform_func = transform_func

    def __call__(
        self,
        data_list: List[Tuple[Optional[Dict[str, np.ndarray]], ...]],
    ) -> List[Tuple[Optional[Dict[str, np.ndarray]], ...]]:
        """Apply ``transform_func`` to the list of samples.

        Args:
            data_list (List[Tuple[Optional[Dict[str, np.ndarray]], ...]]): Samples,
                each a tuple of (possibly ``None``) dicts, e.g. (input, label, weight).

        Returns:
            List[Tuple[Optional[Dict[str, np.ndarray]], ...]]: Whatever
            ``transform_func`` returns for ``data_list``.
        """
        return self.transform_func(data_list)

build_batch_transforms(cfg, collate_fn)

Source code in ppsci/data/process/batch_transform/__init__.py
def build_batch_transforms(cfg, collate_fn: Optional[Callable]):
    """Build a collate function that runs batch transforms before collating.

    ``cfg`` is deep-copied and passed to ``build_transforms``. Presumably the
    copy keeps the caller's config from being mutated; verify against
    ``build_transforms``. The result is used as a callable over the list of
    samples.

    Args:
        cfg: Batch transform config forwarded (as a copy) to ``build_transforms``.
        collate_fn (Optional[Callable]): Function that collates the transformed
            samples. ``default_collate_fn`` is used when it is ``None``.

    Returns:
        Callable: A function that takes a list of samples, applies the batch
        transforms to them, and returns the collated result.
    """
    transforms_fn: Callable[[List[Any]], List[Any]] = build_transforms(
        copy.deepcopy(cfg)
    )
    final_collate = default_collate_fn if collate_fn is None else collate_fn

    def collate_fn_batch_transforms(batch: List[Any]):
        # Transform the separate samples first, then merge them into batched data.
        return final_collate(transforms_fn(batch))

    return collate_fn_batch_transforms

default_collate_fn(batch)

Default collate function for the paddle dataloader.

NOTE: This default_collate_fn differs from the official default_collate_fn in that it additionally handles samples that are None or pgl.Graph.

ref: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/io/dataloader/collate.py#L25

Parameters:

Name Type Description Default
batch List[Any]

Batch of samples to be collated.

required

Returns:

Name Type Description
Any Any

Collated batch data.

Source code in ppsci/data/process/batch_transform/__init__.py
def default_collate_fn(batch: List[Any]) -> Any:
    """Collate a list of samples into batched data for the paddle dataloader.

    NOTE: Unlike Paddle's official ``default_collate_fn``, this version also
    handles samples that are ``None``, ``pgl.Graph`` or ``GraphGridMesh``.

    ref: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/io/dataloader/collate.py#L25

    The branch is chosen from the type of the first sample, ``batch[0]``, so
    ``batch`` is assumed to be non-empty:

    * ``None``: returns ``None``.
    * ``np.ndarray``: stacked with ``np.stack`` along a new axis 0.
    * ``paddle.Tensor``: stacked with ``paddle.stack`` along a new axis 0.
    * number: converted with ``np.array``.
    * ``str`` / ``bytes``: the batch is returned as-is.
    * mapping: dict keyed like the first sample, values collated recursively.
    * sequence: list of recursively collated fields. All samples must have the
      same length.
    * ``pgl.Graph``: a single graph whose arrays are concatenated over the batch.
    * ``GraphGridMesh``: the first sample itself, after ``tensor()`` is called.

    Args:
        batch (List[Any]): Batch of samples to be collated.

    Returns:
        Any: Collated batch data.

    Raises:
        RuntimeError: If sequence samples have differing numbers of fields.
        TypeError: If the type of the first sample is not supported.
    """
    first = batch[0]
    if first is None:
        return None
    if isinstance(first, np.ndarray):
        return np.stack(batch, axis=0)
    if isinstance(first, (paddle.Tensor, paddle.framework.core.eager.Tensor)):
        return paddle.stack(batch, axis=0)
    if isinstance(first, numbers.Number):
        return np.array(batch)
    if isinstance(first, (str, bytes)):
        # Checked before Sequence because str/bytes are themselves sequences.
        return batch
    if isinstance(first, Mapping):
        return {
            key: default_collate_fn([item[key] for item in batch]) for key in first
        }
    if isinstance(first, Sequence):
        num_fields = len(first)
        if any(len(item) != num_fields for item in batch):
            raise RuntimeError("Fields number not same among samples in a batch")
        # Each column from zip(*batch) is a tuple holding one field of every sample.
        return [default_collate_fn(column) for column in zip(*batch)]

    # Compare type names instead of using isinstance() so that pgl (and the
    # atmospheric dataset module) need not be importable.
    type_repr = str(type(first))
    if type_repr == "<class 'pgl.graph.Graph'>":
        # NOTE(review): num_nodes/edges come from the first sample only, and
        # edge_index is concatenated without offsetting node indices by the
        # preceding graphs' node counts. This looks correct only for batch
        # size 1 or for pre-offset indices. TODO confirm.
        merged = pgl.Graph(num_nodes=first.num_nodes, edges=first.edges)
        for attr_name, concat_axis in (
            ("x", 0),
            ("y", 0),
            ("edge_index", 1),
            ("edge_attr", 0),
            ("pos", 0),
        ):
            setattr(
                merged,
                attr_name,
                np.concatenate(
                    [getattr(g, attr_name) for g in batch], axis=concat_axis
                ),
            )
        # Optional attributes, present only on some graph samples.
        for attr_name in ("aoa", "mach_or_reynolds"):
            if hasattr(first, attr_name):
                setattr(
                    merged,
                    attr_name,
                    np.concatenate([getattr(g, attr_name) for g in batch]),
                )
        merged.tensor()
        merged.shape = [len(batch)]
        return merged
    if type_repr == "<class 'ppsci.data.dataset.atmospheric_dataset.GraphGridMesh'>":
        # Only the first sample is kept; presumably batch size is 1 here. TODO confirm.
        first.tensor()
        first.shape = [1]
        return first
    raise TypeError(
        "batch data can only contains: paddle.Tensor, numpy.ndarray, "
        f"dict, list, number, None, pgl.Graph, GraphGridMesh, but got {type(first)}"
    )