# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from maxframe import opcodes
from maxframe.core import ENTITY_TYPE, ExecutableTuple
from maxframe.serialization.serializables import (
AnyField,
BoolField,
FieldTypes,
Int32Field,
ListField,
StringField,
)
from maxframe.tensor.core import TENSOR_TYPE, TensorOrder
from maxframe.tensor.datasource import tensor as astensor
from maxframe.tensor.operators import TensorOperator, TensorOperatorMixin
from maxframe.tensor.utils import validate_axis, validate_order
from maxframe.typing_ import EntityType
from maxframe.utils import check_unexpected_kwargs
class TensorPartition(TensorOperatorMixin, TensorOperator):
_op_type_ = opcodes.PARTITION
kth = AnyField("kth")
axis = Int32Field("axis")
kind = StringField("kind")
order = ListField("order", FieldTypes.string)
need_align = BoolField("need_align")
return_value = BoolField("return_value")
return_indices = BoolField("return_indices")
@classmethod
def _set_inputs(cls, op: "TensorPartition", inputs: List[EntityType]):
super()._set_inputs(op, inputs)
if len(op._inputs) > 1:
op.kth = op._inputs[1]
@property
def psrs_kinds(self):
# to keep compatibility with PSRS
# remember when merging data in PSRSShuffle(reduce),
# we don't need sort, thus set psrs_kinds[2] to None
return ["quicksort", "mergesort", None]
@property
def output_limit(self):
return int(bool(self.return_value)) + int(bool(self.return_indices))
def __call__(self, a, kth):
inputs = [a]
if isinstance(kth, TENSOR_TYPE):
inputs.append(kth)
kws = []
if self.return_value:
kws.append(
{
"shape": a.shape,
"order": a.order,
"type": "sorted",
"dtype": a.dtype,
}
)
if self.return_indices:
kws.append(
{
"shape": a.shape,
"order": TensorOrder.C_ORDER,
"type": "argsort",
"dtype": np.dtype(np.int64),
}
)
ret = self.new_tensors(inputs, kws=kws)
if len(kws) == 1:
return ret[0]
return ExecutableTuple(ret)
def _check_kth_dtype(dtype):
if not np.issubdtype(dtype, np.integer):
raise TypeError("Partition index must be integer")
def _validate_kth_value(kth, size):
kth = np.where(kth < 0, kth + size, kth)
if np.any((kth < 0) | (kth >= size)):
invalid_kth = next(k for k in kth if k < 0 or k >= size)
raise ValueError(f"kth(={invalid_kth}) out of bounds ({size})")
return kth
def _validate_partition_arguments(a, kth, axis, kind, order, kw):
a = astensor(a)
if axis is None:
a = a.flatten()
axis = 0
else:
axis = validate_axis(a.ndim, axis)
if isinstance(kth, ENTITY_TYPE):
kth = astensor(kth)
_check_kth_dtype(kth.dtype)
else:
kth = np.atleast_1d(kth)
kth = _validate_kth_value(kth, a.shape[axis])
if kth.ndim > 1:
raise ValueError("object too deep for desired array")
if kind != "introselect":
raise ValueError(f"{kind} is an unrecognized kind of select")
# if a is structure type and order is not None
order = validate_order(a.dtype, order)
need_align = kw.pop("need_align", None)
check_unexpected_kwargs(kw)
return a, kth, axis, kind, order, need_align
[docs]
def partition(a, kth, axis=-1, kind="introselect", order=None, **kw):
r"""
Return a partitioned copy of a tensor.
Creates a copy of the tensor with its elements rearranged in such a
way that the value of the element in k-th position is in the
position it would be in a sorted tensor. All elements smaller than
the k-th element are moved before this element and all equal or
greater are moved behind it. The ordering of the elements in the two
partitions is undefined.
Parameters
----------
a : array_like
Tensor to be sorted.
kth : int or sequence of ints
Element index to partition by. The k-th value of the element
will be in its final sorted position and all smaller elements
will be moved before it and all equal or greater elements behind
it. The order of all elements in the partitions is undefined. If
provided with a sequence of k-th it will partition all elements
indexed by k-th of them into their sorted position at once.
axis : int or None, optional
Axis along which to sort. If None, the tensor is flattened before
sorting. The default is -1, which sorts along the last axis.
kind : {'introselect'}, optional
Selection algorithm. Default is 'introselect'.
order : str or list of str, optional
When `a` is a tensor with fields defined, this argument
specifies which fields to compare first, second, etc. A single
field can be specified as a string. Not all fields need be
specified, but unspecified fields will still be used, in the
order in which they come up in the dtype, to break ties.
Returns
-------
partitioned_tensor : Tensor
Tensor of the same type and shape as `a`.
See Also
--------
Tensor.partition : Method to sort a tensor in-place.
argpartition : Indirect partition.
sort : Full sorting
Notes
-----
The various selection algorithms are characterized by their average
speed, worst case performance, work space size, and whether they are
stable. A stable sort keeps items with the same key in the same
relative order. The available algorithms have the following
properties:
================= ======= ============= ============ =======
kind speed worst case work space stable
================= ======= ============= ============ =======
'introselect' 1 O(n) 0 no
================= ======= ============= ============ =======
All the partition algorithms make temporary copies of the data when
partitioning along any but the last axis. Consequently,
partitioning along the last axis is faster and uses less space than
partitioning along any other axis.
The sort order for complex numbers is lexicographic. If both the
real and imaginary parts are non-nan then the order is determined by
the real parts except when they are equal, in which case the order
is determined by the imaginary parts.
Examples
--------
>>> import maxframe.tensor as mt
>>> a = mt.array([3, 4, 2, 1])
>>> mt.partition(a, 3).execute()
array([2, 1, 3, 4])
>>> mt.partition(a, (1, 3)).execute()
array([1, 2, 3, 4])
"""
return_indices = kw.pop("return_index", False)
a, kth, axis, kind, order, need_align = _validate_partition_arguments(
a, kth, axis, kind, order, kw
)
op = TensorPartition(
kth=kth,
axis=axis,
kind=kind,
order=order,
need_align=need_align,
return_value=True,
return_indices=return_indices,
dtype=a.dtype,
gpu=a.op.gpu,
)
return op(a, kth)