# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from numbers import Integral
from typing import List
import numpy as np
from maxframe import opcodes
from maxframe.serialization.serializables import AnyField, Int32Field
from maxframe.tensor.core import Tensor, TensorOrder
from maxframe.tensor.datasource import tensor as astensor
from maxframe.tensor.misc.ravel import ravel
from maxframe.tensor.operators import TensorHasInput, TensorOperatorMixin
from maxframe.tensor.utils import broadcast_shape
from maxframe.typing_ import EntityType
class TensorRepeat(TensorHasInput, TensorOperatorMixin):
    """
    Operator repeating the elements of a tensor along one axis.

    ``axis`` of ``None`` means the input is flattened before repeating.
    ``repeats`` holds the repetition counts when they are known up front
    (a Python int or a NumPy array); when the counts are themselves a
    tensor, that tensor is passed as the operator's second input and
    ``_set_inputs`` binds it back to ``repeats``.
    """

    _op_type_ = opcodes.REPEAT

    repeats = AnyField("repeats", default=None)
    axis = Int32Field("axis", default=None)

    def __init__(self, sparse=False, **kw):
        super().__init__(sparse=sparse, **kw)

    @classmethod
    def _set_inputs(cls, op: "TensorRepeat", inputs: List[EntityType]):
        super()._set_inputs(op, inputs)
        # A second input, when present, is the tensor of repeat counts.
        if len(inputs) > 1:
            op.repeats = inputs[1]

    def __call__(self, a, repeats):
        """
        Build the output tensor for ``repeat(a, repeats, axis=self.axis)``.

        Parameters
        ----------
        a : array_like
            Input data; converted with ``astensor`` and flattened when
            ``self.axis`` is None.
        repeats : int, array_like or Tensor
            Repetition counts. A single count applies to every element
            along the axis; a 1-d sequence gives one count per element.

        Returns
        -------
        Tensor
            The repeated tensor. Its length along the repeated axis is NaN
            when ``repeats`` is a tensor (counts unknown until execution).

        Raises
        ------
        IndexError
            If ``self.axis`` is out of range for ``a``.
        ValueError
            If known ``repeats`` contain negative values, or ``repeats`` is
            not 1-d / does not broadcast against the repeated axis.
        """
        a = astensor(a)
        axis = self.axis
        if axis is None:
            # Without an axis the input is flattened and repeated as 1-d.
            a = ravel(a)
            ax = 0
        else:
            ndim = a.ndim
            if not -ndim <= axis < ndim:
                raise IndexError(
                    f"axis {axis} is out of bounds for tensor of dimension {ndim}"
                )
            # Canonicalize negative axes: the shape slicing below requires a
            # non-negative index, and storing it keeps ``self.axis``
            # consistent with the shape computed here.
            ax = axis + ndim if axis < 0 else axis
            self.axis = ax
        axis_len = a.shape[ax]

        if isinstance(repeats, Integral):
            if repeats < 0:
                raise ValueError("repeats may not be negative")
            size = axis_len * repeats
        elif isinstance(repeats, Tensor):
            # Counts are data-dependent, so the output length is unknown.
            size = np.nan
            if repeats.ndim != 1:
                raise ValueError("repeats should be 1-d tensor")
            broadcast_shape(repeats.shape, a.shape[ax : ax + 1])
        else:
            repeats = np.asarray(repeats)
            if np.any(repeats < 0):
                raise ValueError("repeats may not be negative")
            if repeats.size == 1:
                # A single count applies to the whole axis. ``.item()`` also
                # works for 0-d arrays, which ``repeats[0]`` cannot index.
                repeats = int(repeats.item())
                size = repeats * axis_len
            elif axis_len == 1:
                # One element repeated sum(repeats) times: collapse to an int.
                size = repeats = int(repeats.sum())
            else:
                size = int(repeats.sum())
                if repeats.ndim != 1:
                    raise ValueError("repeats should be 1-d tensor")
                broadcast_shape(repeats.shape, a.shape[ax : ax + 1])

        shape = a.shape[:ax] + (size,) + a.shape[ax + 1 :]
        self.dtype = a.dtype
        self.sparse = a.issparse()

        inputs = [a]
        if isinstance(repeats, Tensor):
            # Tensor counts travel as an input rather than as a field.
            inputs.append(repeats)
        else:
            self.repeats = repeats
        return self.new_tensor(inputs, shape, order=TensorOrder.C_ORDER)
def repeat(a, repeats, axis=None):
    """
    Repeat elements of a tensor.

    Parameters
    ----------
    a : array_like
        Input tensor.
    repeats : int or tensor of ints
        The number of repetitions for each element. `repeats` is broadcasted
        to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values. By default, use the
        flattened input tensor, and return a flat output tensor.

    Returns
    -------
    repeated_tensor : Tensor
        Output tensor which has the same shape as `a`, except along
        the given axis.

    See Also
    --------
    tile : Tile a tensor.

    Examples
    --------
    >>> import maxframe.tensor as mt
    >>> mt.repeat(3, 4).execute()
    array([3, 3, 3, 3])
    >>> x = mt.array([[1,2],[3,4]])
    >>> mt.repeat(x, 2).execute()
    array([1, 1, 2, 2, 3, 3, 4, 4])
    >>> mt.repeat(x, 3, axis=1).execute()
    array([[1, 1, 1, 2, 2, 2],
           [3, 3, 3, 4, 4, 4]])
    >>> mt.repeat(x, [1, 2], axis=0).execute()
    array([[1, 2],
           [3, 4],
           [3, 4]])
    """
    # All shape inference and validation lives in TensorRepeat.__call__.
    op = TensorRepeat(axis=axis)
    return op(a, repeats)