maxframe.tensor.digitize#
- maxframe.tensor.digitize(x, bins, right=False)[源代码]#
返回输入张量中每个值所属的区间索引。
返回的每个索引
i满足:若 bins 单调递增,则bins[i-1] <= x < bins[i];若 bins 单调递减,则bins[i-1] > x >= bins[i]。如果 x 中的值超出了 bins 的范围,则根据情况返回 0 或len(bins)。如果 right 为 True,则右区间为闭区间,使得索引i满足:若 bins 单调递增,则bins[i-1] < x <= bins[i];若 bins 单调递减,则bins[i-1] >= x > bins[i]。- 参数:
x (array_like) -- 待分箱的输入张量。
bins (array_like) -- 区间数组。必须是一维且单调的。
right (bool, optional) -- 表示区间是否包含右边界或左边界。默认行为是 (right==False),表示区间不包含右边界。此时左区间端点是开区间,即对于单调递增的区间,bins[i-1] <= x < bins[i] 是默认行为。
- 返回:
out -- 索引输出张量,形状与 x 相同。
- 返回类型:
Tensor of ints
- 抛出:
ValueError -- 如果 bins 不是单调的。
TypeError -- 如果输入的类型是复数。
备注
如果 x 中的值超出了区间范围,使用 digitize 返回的索引去访问 bins 会导致 IndexError。
mt.digitize 是基于 mt.searchsorted 实现的。这意味着使用二分查找来对值进行分箱,在区间数量较大时比以前的线性搜索性能更好。它还取消了输入数组必须是一维的要求。
示例
>>> import maxframe.tensor as mt
>>> x = mt.array([0.2, 6.4, 3.0, 1.6]) >>> bins = mt.array([0.0, 1.0, 2.5, 4.0, 10.0]) >>> inds = mt.digitize(x, bins) >>> inds.execute() array([1, 4, 3, 2])
>>> x = mt.array([1.2, 10.0, 12.4, 15.5, 20.]) >>> bins = mt.array([0, 5, 10, 15, 20]) >>> mt.digitize(x,bins,right=True).execute() array([1, 2, 3, 4, 4]) >>> mt.digitize(x,bins,right=False).execute() array([1, 3, 3, 4, 5])