Source code for maxframe.dataframe.merge.concat

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import pandas as pd

from ... import opcodes
from ...core import ENTITY_TYPE, OutputType
from ...serialization.serializables import (
    AnyField,
    BoolField,
    FieldTypes,
    ListField,
    StringField,
)
from ...utils import lazy_import
from ..core import DataFrame, Series
from ..operators import SERIES_TYPE, DataFrameOperator, DataFrameOperatorMixin
from ..utils import build_empty_df, build_empty_series, parse_index, validate_axis

cudf = lazy_import("cudf")


class DataFrameConcat(DataFrameOperator, DataFrameOperatorMixin):
    _op_type_ = opcodes.CONCATENATE

    axis = AnyField("axis", default=None)
    join = StringField("join", default=None)
    ignore_index = BoolField("ignore_index", default=None)
    keys = ListField("keys", default=None)
    levels = ListField("levels", default=None)
    names = ListField("names", default=None)
    verify_integrity = BoolField("verify_integrity", default=None)
    sort = BoolField("sort", default=None)
    copy_ = BoolField("copy", default=None)

    def __init__(self, copy=None, output_types=None, **kw):
        super().__init__(copy_=copy, _output_types=output_types, **kw)

    @property
    def level(self):
        return self.levels

    @property
    def name(self):
        return self.names

    @classmethod
    def _concat_index(cls, df_or_series_list: Union[List[DataFrame], List[Series]]):
        concat_index = None
        all_indexes_have_value = all(
            obj.index_value.has_value() for obj in df_or_series_list
        )

        def _concat(prev_index: pd.Index, cur_index: pd.Index):
            if prev_index is None:
                return cur_index

            if (
                all_indexes_have_value
                and isinstance(prev_index, pd.RangeIndex)
                and isinstance(cur_index, pd.RangeIndex)
            ):
                # Appending two RangeIndexes may materialize a huge amount of
                # data: e.g. appending pd.RangeIndex(10_000) to another
                # pd.RangeIndex(10_000) yields an Int64Index holding all
                # 20_000 values; for details see GH#1647.
                prev_stop = prev_index.start + prev_index.size * prev_index.step
                cur_start = cur_index.start
                if prev_stop == cur_start and prev_index.step == cur_index.step:
                    # contiguous RangeIndexes still append into a RangeIndex
                    return prev_index.append(cur_index)
                else:
                    # otherwise return an empty index to avoid materializing
                    # all the labels
                    return pd.Index([], dtype=prev_index.dtype)
            elif isinstance(prev_index, pd.RangeIndex):
                return pd.Index([], dtype=prev_index.dtype).append(cur_index)
            elif isinstance(cur_index, pd.RangeIndex):
                return prev_index.append(pd.Index([], dtype=cur_index.dtype))
            return prev_index.append(cur_index)

        for obj in df_or_series_list:
            concat_index = _concat(concat_index, obj.index_value.to_pandas())

        return concat_index
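    # Illustrative sketch (not part of the original module): why _concat_index
    # guards against non-contiguous RangeIndex pairs. Appending contiguous
    # ranges stays a RangeIndex, while non-contiguous ranges force pandas to
    # materialize every label (Int64Index on older pandas):
    #
    #   >>> import pandas as pd
    #   >>> pd.RangeIndex(3).append(pd.RangeIndex(3, 6))
    #   RangeIndex(start=0, stop=6, step=1)
    #   >>> pd.RangeIndex(3).append(pd.RangeIndex(10, 13))
    #   Index([0, 1, 2, 10, 11, 12], dtype='int64')
    #
    # For indexes with millions of rows the second form would allocate the
    # full label array just to build metadata, hence the empty-Index fallback.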

    def _call_series(self, objs):
        if self.axis == 0:
            row_length = 0
            for series in objs:
                row_length += series.shape[0]
            if self.ignore_index:
                idx_length = 0 if pd.isna(row_length) else row_length
                index_value = parse_index(pd.RangeIndex(idx_length))
            else:
                index = self._concat_index(objs)
                index_value = parse_index(index, objs)
            obj_names = {obj.name for obj in objs}
            return self.new_series(
                objs,
                shape=(row_length,),
                dtype=objs[0].dtype,
                index_value=index_value,
                name=objs[0].name if len(obj_names) == 1 else None,
            )
        else:
            col_length = 0
            columns = []
            dtypes = dict()
            undefined_name = 0
            for series in objs:
                if series.name is None:
                    # unnamed series take the next integer counter as their
                    # column label, matching pandas' axis=1 behavior
                    dtypes[undefined_name] = series.dtype
                    columns.append(undefined_name)
                    undefined_name += 1
                else:
                    dtypes[series.name] = series.dtype
                    columns.append(series.name)
                col_length += 1
            if self.ignore_index or undefined_name == len(objs):
                columns_value = parse_index(pd.RangeIndex(col_length))
            else:
                columns_value = parse_index(pd.Index(columns), store_data=True)

            shape = (objs[0].shape[0], col_length)
            return self.new_dataframe(
                objs,
                shape=shape,
                dtypes=pd.Series(dtypes),
                index_value=objs[0].index_value,
                columns_value=columns_value,
            )
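    # Illustrative sketch (not part of the original module): the column labels
    # inferred above for axis=1 mirror plain pandas, where each unnamed series
    # receives the next integer counter and named series keep their names:
    #
    #   >>> import pandas as pd
    #   >>> pd.concat([pd.Series([1]), pd.Series([2], name="x")], axis=1).columns
    #   Index([0, 'x'], dtype='object')
    #
    # Only when every input is unnamed (or ignore_index is set) does the
    # result fall back to a RangeIndex of columns.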

    def _call_dataframes(self, objs):
        if self.axis == 0:
            row_length = 0
            empty_dfs = []
            for df in objs:
                row_length += df.shape[0]
                if df.ndim == 2:
                    empty_dfs.append(build_empty_df(df.dtypes))
                else:
                    empty_dfs.append(build_empty_series(df.dtype, name=df.name))

            empty_result = pd.concat(empty_dfs, join=self.join, sort=self.sort)
            shape = (row_length, empty_result.shape[1])
            columns_value = parse_index(empty_result.columns, store_data=True)

            if self.join == "inner":
                objs = [o[list(empty_result.columns)] for o in objs]

            if self.ignore_index:
                idx_length = 0 if pd.isna(row_length) else row_length
                index_value = parse_index(pd.RangeIndex(idx_length))
            else:
                index = self._concat_index(objs)
                index_value = parse_index(index, objs)

            new_objs = []
            for obj in objs:
                if obj.ndim != 2:
                    # series: lift to a one-column frame aligned with the
                    # result columns
                    new_obj = obj.to_frame().reindex(columns=empty_result.dtypes.index)
                else:
                    # dataframe: reindex only when the column order differs
                    if list(obj.dtypes.index) != list(empty_result.dtypes.index):
                        new_obj = obj.reindex(columns=empty_result.dtypes.index)
                    else:
                        new_obj = obj
                new_objs.append(new_obj)

            return self.new_dataframe(
                new_objs,
                shape=shape,
                dtypes=empty_result.dtypes,
                index_value=index_value,
                columns_value=columns_value,
            )
        else:
            col_length = 0
            empty_dfs = []
            for df in objs:
                if df.ndim == 2:
                    # DataFrame
                    col_length += df.shape[1]
                    empty_dfs.append(build_empty_df(df.dtypes))
                else:
                    # Series
                    col_length += 1
                    empty_dfs.append(build_empty_series(df.dtype, name=df.name))

            empty_result = pd.concat(empty_dfs, join=self.join, axis=1, sort=True)
            if self.ignore_index:
                columns_value = parse_index(pd.RangeIndex(col_length))
            else:
                columns_value = parse_index(
                    pd.Index(empty_result.columns), store_data=True
                )

            if self.ignore_index or len({o.index_value.key for o in objs}) == 1:
                new_objs = [obj if obj.ndim == 2 else obj.to_frame() for obj in objs]
            else:  # pragma: no cover
                raise NotImplementedError(
                    "Concatenating dataframes with different indexes is not supported"
                )

            shape = (objs[0].shape[0], col_length)
            return self.new_dataframe(
                new_objs,
                shape=shape,
                dtypes=empty_result.dtypes,
                index_value=objs[0].index_value,
                columns_value=columns_value,
            )
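    # Illustrative sketch (not part of the original module): the metadata in
    # _call_dataframes is derived by concatenating *empty* frames carrying
    # only the input dtypes, so pandas itself resolves the joined columns and
    # result dtypes without touching any data:
    #
    #   >>> import pandas as pd
    #   >>> a = pd.DataFrame({"a": pd.Series(dtype="int64"),
    #   ...                   "b": pd.Series(dtype="float64")})
    #   >>> b = pd.DataFrame({"b": pd.Series(dtype="float64"),
    #   ...                   "c": pd.Series(dtype="object")})
    #   >>> pd.concat([a, b], join="inner").dtypes
    #   b    float64
    #   dtype: object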

    def __call__(self, objs):
        if all(isinstance(obj, SERIES_TYPE) for obj in objs):
            self.output_types = [OutputType.series]
            return self._call_series(objs)
        else:
            self.output_types = [OutputType.dataframe]
            return self._call_dataframes(objs)


class GroupByConcat(DataFrameOperator, DataFrameOperatorMixin):
    _op_type_ = opcodes.GROUPBY_CONCAT

    _groups = ListField("groups", FieldTypes.key)
    _groupby_params = AnyField("groupby_params")

    def __init__(self, groups=None, groupby_params=None, output_types=None, **kw):
        super().__init__(
            _groups=groups,
            _groupby_params=groupby_params,
            _output_types=output_types,
            **kw
        )

    @property
    def groups(self):
        return self._groups

    @property
    def groupby_params(self):
        return self._groupby_params

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        inputs_iter = iter(self._inputs)

        new_groups = []
        for _ in self._groups:
            new_groups.append(next(inputs_iter))
        self._groups = new_groups

        if isinstance(self._groupby_params["by"], list):
            by = []
            for v in self._groupby_params["by"]:
                if isinstance(v, ENTITY_TYPE):
                    by.append(next(inputs_iter))
                else:
                    by.append(v)
            self._groupby_params["by"] = by

    @classmethod
    def execute(cls, ctx, op):
        input_data = [ctx[inp.key] for inp in op.groups]
        obj = pd.concat([d.obj for d in input_data])

        params = op.groupby_params.copy()
        if isinstance(params["by"], list):
            by = []
            for v in params["by"]:
                if isinstance(v, ENTITY_TYPE):
                    by.append(ctx[v.key])
                else:
                    by.append(v)
            params["by"] = by
        selection = params.pop("selection", None)

        result = obj.groupby(**params)
        if selection:
            result = result[selection]

        ctx[op.outputs[0].key] = result
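    # Illustrative sketch (not part of the original module): execute() above
    # mirrors this plain-pandas pattern of rebuilding a groupby over
    # concatenated pieces, then re-applying any column selection:
    #
    #   >>> import pandas as pd
    #   >>> pieces = [pd.DataFrame({"k": [1], "v": [10]}),
    #   ...           pd.DataFrame({"k": [1], "v": [20]})]
    #   >>> pd.concat(pieces).groupby("k")["v"].sum()
    #   k
    #   1    30
    #   Name: v, dtype: int64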


def concat(
    objs,
    axis=0,
    join="outer",
    ignore_index=False,
    keys=None,
    levels=None,
    names=None,
    verify_integrity=False,
    sort=False,
    copy=True,
):
    if isinstance(objs, dict):  # pragma: no cover
        keys = list(objs.keys())
        objs = list(objs.values())
    if not isinstance(objs, (list, tuple)):  # pragma: no cover
        raise TypeError(
            "first argument must be an iterable of dataframe or series objects"
        )
    axis = validate_axis(axis)
    if axis == 1 and join == "inner":  # pragma: no cover
        raise NotImplementedError(
            "inner join is not supported when `axis=1` is specified"
        )
    if verify_integrity or sort or keys:  # pragma: no cover
        raise NotImplementedError(
            "verify_integrity, sort and keys arguments are not supported yet"
        )
    op = DataFrameConcat(
        axis=axis,
        join=join,
        ignore_index=ignore_index,
        keys=keys,
        levels=levels,
        names=names,
        verify_integrity=verify_integrity,
        sort=sort,
        copy=copy,
    )
    return op(objs)
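
# Illustrative usage sketch (not part of the original module), assuming the
# Mars-style `maxframe.dataframe` entry point and a configured MaxCompute
# session for the deferred execute() call; the session setup is an assumption:
#
#   >>> import maxframe.dataframe as md
#   >>> df1 = md.DataFrame({"a": [1, 2]})
#   >>> df2 = md.DataFrame({"a": [3, 4]})
#   >>> md.concat([df1, df2], ignore_index=True).execute()
#      a
#   0  1
#   1  2
#   2  3
#   3  4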