maxframe.tensor.digitize#

maxframe.tensor.digitize(x, bins, right=False)[源代码]#

返回输入张量中每个值所属的区间索引。

返回的每个索引 i 满足:若 bins 单调递增,则 bins[i-1] <= x < bins[i];若 bins 单调递减,则 bins[i-1] > x >= bins[i]。如果 x 中的值超出了 bins 的范围,则根据情况返回 0 或 len(bins)。如果 right 为 True,则右区间为闭区间,使得索引 i 满足:若 bins 单调递增,则 bins[i-1] < x <= bins[i];若 bins 单调递减,则 bins[i-1] >= x > bins[i]

参数:
  • x (array_like) -- 待分箱的输入张量。

  • bins (array_like) -- 区间数组。必须是一维且单调的。

  • right (bool, optional) -- 表示区间是否包含右边界或左边界。默认行为是 (right==False),表示区间不包含右边界。此时左区间端点是开区间,即对于单调递增的区间,bins[i-1] <= x < bins[i] 是默认行为。

返回:

out -- 索引输出张量,形状与 x 相同。

返回类型:

Tensor of ints

抛出:

参见

bincount, histogram, unique, searchsorted

备注

如果 x 中的值超出了区间范围,使用 digitize 返回的索引去访问 bins 会导致 IndexError。

mt.digitize 是基于 mt.searchsorted 实现的。这意味着使用二分查找来对值进行分箱,在区间数量较大时比以前的线性搜索性能更好。它还取消了输入数组必须是一维的要求。

示例

>>> import maxframe.tensor as mt
>>> x = mt.array([0.2, 6.4, 3.0, 1.6])
>>> bins = mt.array([0.0, 1.0, 2.5, 4.0, 10.0])
>>> inds = mt.digitize(x, bins)
>>> inds.execute()
array([1, 4, 3, 2])
>>> x = mt.array([1.2, 10.0, 12.4, 15.5, 20.])
>>> bins = mt.array([0, 5, 10, 15, 20])
>>> mt.digitize(x,bins,right=True).execute()
array([1, 2, 3, 4, 4])
>>> mt.digitize(x,bins,right=False).execute()
array([1, 3, 3, 4, 5])