# Source code for maxframe.tensor.misc.repeat

# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numbers import Integral
from typing import List

import numpy as np

from maxframe import opcodes
from maxframe.serialization.serializables import AnyField, Int32Field
from maxframe.tensor.core import Tensor, TensorOrder
from maxframe.tensor.datasource import tensor as astensor
from maxframe.tensor.misc.ravel import ravel
from maxframe.tensor.operators import TensorHasInput, TensorOperatorMixin
from maxframe.tensor.utils import broadcast_shape
from maxframe.typing_ import EntityType


class TensorRepeat(TensorHasInput, TensorOperatorMixin):
    """
    Operator that repeats elements of a tensor, mirroring ``numpy.repeat``.

    ``repeats`` is either stored directly on the operator (an ``int`` or a
    NumPy array built from the caller's value) or, when it is a ``Tensor``,
    carried as the operator's second input and re-bound in ``_set_inputs``.
    ``axis`` is ``None`` when the input is flattened before repeating.
    """

    _op_type_ = opcodes.REPEAT

    # int / ndarray / Tensor of repeat counts; a Tensor value is also passed
    # as ``inputs[1]`` so it participates in the graph.
    repeats = AnyField("repeats", default=None)
    # Axis to repeat along (normalized to non-negative in ``__call__``);
    # ``None`` means the input is raveled first.
    axis = Int32Field("axis", default=None)

    def __init__(self, sparse=False, **kw):
        super().__init__(sparse=sparse, **kw)

    @classmethod
    def _set_inputs(cls, op: "TensorRepeat", inputs: List[EntityType]):
        super()._set_inputs(op, inputs)
        # A tensor-valued ``repeats`` travels as the second input; keep the
        # attribute pointing at the (possibly replaced) input entity.
        if len(inputs) > 1:
            op.repeats = inputs[1]

    def __call__(self, a, repeats):
        """
        Build the output tensor metadata for ``repeat(a, repeats, axis)``.

        Parameters
        ----------
        a : array_like
            Input data; converted with ``astensor`` and raveled when
            ``self.axis`` is ``None``.
        repeats : int, array_like or Tensor
            Repeat counts. Size-1 array-likes collapse to a plain ``int``.

        Returns
        -------
        Tensor
            New tensor whose length along the repeat axis is the total
            repeat count (``nan`` when ``repeats`` is a Tensor).

        Raises
        ------
        ValueError
            If ``axis`` is out of bounds for ``a``, if any repeat count is
            negative, or if array-like ``repeats`` is not 1-d.
        """
        a = astensor(a)
        if self.axis is None:
            a = ravel(a)
            ax = 0
        else:
            ndim = len(a.shape)
            ax = self.axis
            if not -ndim <= ax < ndim:
                raise ValueError(
                    f"axis {ax} is out of bounds for tensor of dimension {ndim}"
                )
            # Normalize negative axes: the slicing below (``shape[ax + 1:]``)
            # is only correct for non-negative indices.
            if ax < 0:
                ax += ndim
            self.axis = ax

        if isinstance(repeats, Integral):
            if repeats < 0:
                raise ValueError("negative dimensions are not allowed")
            size = a.shape[ax] * repeats
        else:
            if isinstance(repeats, Tensor):
                # Per-element counts are only known at execution time.
                size = np.nan
            else:
                repeats = np.asarray(repeats)
                if repeats.size > 0 and (repeats < 0).any():
                    raise ValueError("negative dimensions are not allowed")
                if repeats.size == 1:
                    # ``.item()`` also handles 0-d arrays, unlike ``repeats[0]``.
                    repeats = int(repeats.item())
                    size = repeats * a.shape[ax]
                elif a.shape[ax] == 1:
                    # A length-1 axis: all counts apply to the single element,
                    # so collapse them into one scalar repeat count.
                    size = repeats = int(repeats.sum())
                else:
                    size = int(repeats.sum())
            if not isinstance(repeats, Integral):
                if repeats.ndim != 1:
                    raise ValueError("repeats should be 1-d tensor")
                # Validates that repeats is compatible with the axis length.
                broadcast_shape(repeats.shape, a.shape[ax : ax + 1])

        shape = a.shape[:ax] + (size,) + a.shape[ax + 1 :]
        self.dtype = a.dtype
        self.sparse = a.issparse()

        inputs = [a]
        if isinstance(repeats, Tensor):
            inputs.append(repeats)
        else:
            self.repeats = repeats

        return self.new_tensor(inputs, shape, order=TensorOrder.C_ORDER)


def repeat(a, repeats, axis=None):
    """
    Repeat elements of a tensor.

    Parameters
    ----------
    a : array_like
        Input tensor.
    repeats : int or tensor of ints
        The number of repetitions for each element. `repeats` is broadcasted
        to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values. By default, use the
        flattened input tensor, and return a flat output tensor.

    Returns
    -------
    repeated_tensor : Tensor
        Output tensor which has the same shape as `a`, except along
        the given axis.

    See Also
    --------
    tile : Tile a tensor.

    Examples
    --------
    >>> import maxframe.tensor as mt
    >>> mt.repeat(3, 4).execute()
    array([3, 3, 3, 3])
    >>> x = mt.array([[1,2],[3,4]])
    >>> mt.repeat(x, 2).execute()
    array([1, 1, 2, 2, 3, 3, 4, 4])
    >>> mt.repeat(x, 3, axis=1).execute()
    array([[1, 1, 1, 2, 2, 2],
           [3, 3, 3, 4, 4, 4]])
    >>> mt.repeat(x, [1, 2], axis=0).execute()
    array([[1, 2],
           [3, 4],
           [3, 4]])
    """
    op = TensorRepeat(axis=axis)
    return op(a, repeats)