maxframe.tensor.argsort#

maxframe.tensor.argsort(a, axis=-1, kind=None, order=None, *, stable=None, parallel_kind=None, psrs_kinds=None)[源代码]#

返回能够排序张量的索引。

使用 kind 关键字指定的算法沿给定轴执行间接排序。它返回与 a 相同形状的索引张量,该索引沿给定轴按排序顺序索引数据。

参数:
  • a (array_like) -- 要排序的张量。

  • axis (int or None, optional) -- 排序的轴。默认值为 -1(最后一个轴)。如果为 None,则使用展平后的张量。

  • kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) -- 排序算法。默认为 'quicksort'。注意 'stable' 和 'mergesort' 都在底层使用 timsort,且实际实现通常随数据类型而变化。保留 'mergesort' 选项是为了向后兼容。 .. versionchanged:: 1.15.0. 添加了 'stable' 选项。

  • order (str or list of str, optional) -- 当 a 是一个定义了字段的张量时,此参数指定首先比较哪些字段、其次比较哪些字段等。单个字段可以指定为字符串,且并非所有字段都需要指定,但未指定的字段仍会按照它们在 dtype 中出现的顺序用于打破平局。

  • stable (bool, optional) -- 排序稳定性。如果为 True,返回的数组将保持比较相等的 a 值的相对顺序。如果为 FalseNone,则不保证这一点。在内部,此选项会选择 kind='stable'。默认值:None

返回:

index_tensor -- 沿指定 axisa 进行排序的索引张量。如果 a 是一维的,则 a[index_tensor] 会产生排序后的 a。更一般地,np.take_along_axis(a, index_tensor, axis=axis) 总是会生成排序后的 a,无论维度如何。

返回类型:

Tensor, int

参见

sort

描述所使用的排序算法。

lexsort

使用多个键的间接稳定排序。

Tensor.sort

原地排序。

argpartition

间接部分排序。

备注

关于不同排序算法的说明,请参见 sort

示例

一维张量:

>>> import maxframe.tensor as mt
>>> x = mt.array([3, 1, 2])
>>> mt.argsort(x).execute()
array([1, 2, 0])

二维张量:

>>> x = mt.array([[0, 3], [2, 2]])
>>> x.execute()
array([[0, 3],
       [2, 2]])
>>> ind = mt.argsort(x, axis=0)  # sorts along first axis (down)
>>> ind.execute()
array([[0, 1],
       [1, 0]])
#>>> mt.take_along_axis(x, ind, axis=0).execute()  # same as np.sort(x, axis=0)
#array([[0, 2],
#       [2, 3]])
>>> ind = mt.argsort(x, axis=1)  # sorts along last axis (across)
>>> ind.execute()
array([[0, 1],
       [0, 1]])
#>>> mt.take_along_axis(x, ind, axis=1).execute()  # same as np.sort(x, axis=1)
#array([[0, 3],
#       [2, 2]])

N维数组排序元素的索引:

>>> ind = mt.unravel_index(mt.argsort(x, axis=None), x.shape)
>>> ind.execute()
(array([0, 1, 1, 0]), array([0, 0, 1, 1]))
>>> x[ind].execute()  # same as np.sort(x, axis=None)
array([0, 2, 2, 3])

使用键排序:

>>> x = mt.array([(1, 0), (0, 1)], dtype=[('x', '<i4'), ('y', '<i4')])
>>> x.execute()
array([(1, 0), (0, 1)],
      dtype=[('x', '<i4'), ('y', '<i4')])
>>> mt.argsort(x, order=('x','y')).execute()
array([1, 0])
>>> mt.argsort(x, order=('y','x')).execute()
array([0, 1])