Source code for maxframe.tensor.misc.repeat

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numbers import Integral
from typing import List

import numpy as np

from ... import opcodes
from ...serialization.serializables import AnyField, Int32Field
from ...typing_ import EntityType
from ..core import Tensor, TensorOrder
from ..datasource import tensor as astensor
from ..operators import TensorHasInput, TensorOperatorMixin
from ..utils import broadcast_shape
from .ravel import ravel


class TensorRepeat(TensorHasInput, TensorOperatorMixin):
    """Operator that repeats the elements of a tensor along one axis.

    ``repeats`` is either kept on the operator as a plain value (an integer
    or a NumPy array) or, when it is itself a ``Tensor``, passed as the
    second input of the operator.
    """

    _op_type_ = opcodes.REPEAT

    # Repetition count(s): an int, a NumPy array, or a Tensor (restored from
    # the second input by ``_set_inputs``).
    repeats = AnyField("repeats", default=None)
    # Axis to repeat along; ``None`` means the input is flattened first.
    axis = Int32Field("axis", default=None)

    def __init__(self, sparse=False, **kw):
        super().__init__(sparse=sparse, **kw)

    @classmethod
    def _set_inputs(cls, op: "TensorRepeat", inputs: List[EntityType]):
        """Bind inputs; a second input, when present, is the ``repeats`` tensor."""
        super()._set_inputs(op, inputs)
        if len(inputs) > 1:
            op.repeats = inputs[1]

    def __call__(self, a, repeats):
        """Build the output tensor for repeating ``a`` by ``repeats``.

        Parameters
        ----------
        a : array_like
            Input data; converted with ``astensor`` and flattened with
            ``ravel`` when ``self.axis`` is ``None``.
        repeats : int, array_like of ints or Tensor
            Number of repetitions. A single-element array-like is reduced
            to a plain int.

        Returns
        -------
        Tensor
            New tensor whose shape equals that of ``a`` except along the
            repeat axis. The size along that axis is ``nan`` when
            ``repeats`` is a ``Tensor``.

        Raises
        ------
        IndexError
            If the axis is out of bounds for ``a``.
        ValueError
            If a non-scalar ``repeats`` is not 1-d.
        """
        axis = self.axis
        a = astensor(a)
        if axis is None:
            a = ravel(a)

        ax = axis or 0
        ndim = len(a.shape)
        if not -ndim <= ax < ndim:
            raise IndexError(
                f"axis {axis} is out of bounds for tensor of dimension {ndim}"
            )
        # Normalize a negative axis so the slices below address the intended
        # dimension; ``shape[-1:0]`` would otherwise be empty and
        # ``shape[ax + 1:]`` would start from the first dimension.
        # ``self.axis`` itself is left exactly as the caller supplied it.
        ax %= ndim
        dim = a.shape[ax]

        if isinstance(repeats, Integral):
            size = dim * repeats
        elif isinstance(repeats, Tensor):
            # The per-element counts are not known here, so the resulting
            # length along the axis is marked as unknown.
            size = np.nan
        else:
            repeats = np.asarray(repeats)
            if repeats.size == 1:
                # Flatten before indexing so 0-d inputs work as well.
                repeats = int(repeats.reshape(-1)[0])
                size = repeats * dim
            elif dim == 1:
                # A single element repeated by several counts is the same as
                # repeating it once by their sum.
                size = repeats = int(repeats.sum())
            else:
                size = int(repeats.sum())

        if not isinstance(repeats, Integral):
            if repeats.ndim != 1:
                raise ValueError("repeats should be 1-d tensor")
            # Called only for validation: the result is discarded.
            broadcast_shape(repeats.shape, a.shape[ax : ax + 1])

        shape = a.shape[:ax] + (size,) + a.shape[ax + 1 :]
        self.dtype = a.dtype
        self.sparse = a.issparse()

        inputs = [a]
        if isinstance(repeats, Tensor):
            inputs.append(repeats)
        else:
            self.repeats = repeats

        return self.new_tensor(inputs, shape, order=TensorOrder.C_ORDER)


def repeat(a, repeats, axis=None):
    """
    Repeat elements of a tensor.

    Parameters
    ----------
    a : array_like
        Input tensor.
    repeats : int or tensor of ints
        The number of repetitions for each element. `repeats` is
        broadcasted to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values. By default, use the
        flattened input tensor, and return a flat output tensor.

    Returns
    -------
    repeated_tensor : Tensor
        Output tensor which has the same shape as `a`, except along
        the given axis.

    See Also
    --------
    tile : Tile a tensor.

    Examples
    --------
    >>> import maxframe.tensor as mt

    >>> mt.repeat(3, 4).execute()
    array([3, 3, 3, 3])
    >>> x = mt.array([[1,2],[3,4]])
    >>> mt.repeat(x, 2).execute()
    array([1, 1, 2, 2, 3, 3, 4, 4])
    >>> mt.repeat(x, 3, axis=1).execute()
    array([[1, 1, 1, 2, 2, 2],
           [3, 3, 3, 4, 4, 4]])
    >>> mt.repeat(x, [1, 2], axis=0).execute()
    array([[1, 2],
           [3, 4],
           [3, 4]])
    """
    # Build the operator for the requested axis and apply it directly.
    return TensorRepeat(axis=axis)(a, repeats)