maxframe.tensor.where#

maxframe.tensor.where(condition, x=None, y=None)[源代码]#

根据 condition 返回 xy 中的元素。

如果只提供了 condition,则返回 condition.nonzero()

参数:
  • condition (array_like, bool) -- 为 True 时,返回 x,否则返回 y

  • x (array_like, optional) -- 用于选择的值。xycondition 需要能广播到相同的形状。

  • y (array_like, optional) -- 用于选择的值。xycondition 需要能广播到相同的形状。

返回:

out -- 如果同时指定了 xy,输出张量在 condition 为 True 的位置包含 x 的元素,其他位置包含 y 的元素。如果只提供了 condition,则返回元组 condition.nonzero(),即 condition 为 True 的索引。

返回类型:

Tensor or tuple of Tensors

参见

nonzero, choose

备注

如果提供了 xy,且输入数组是一维的,则 where 等价于:

[xv if c else yv for (c,xv,yv) in zip(condition,x,y)]

示例

>>> import maxframe.tensor as mt
>>> mt.where([[True, False], [True, True]],
...          [[1, 2], [3, 4]],
...          [[9, 8], [7, 6]]).execute()
array([[1, 8],
       [3, 4]])
>>> mt.where([[0, 1], [1, 0]]).execute()
(array([0, 1]), array([1, 0]))
>>> x = mt.arange(9.).reshape(3, 3)
>>> mt.where( x > 5 ).execute()
(array([2, 2, 2]), array([0, 1, 2]))
>>> mt.where(x < 5, x, -1).execute()               # Note: broadcasting.
array([[ 0.,  1.,  2.],
       [ 3.,  4., -1.],
       [-1., -1., -1.]])

查找 x 中在 goodvalues 中的元素的索引。

>>> goodvalues = [3, 4, 7]
>>> ix = mt.isin(x, goodvalues)
>>> ix.execute()
array([[False, False, False],
       [ True,  True, False],
       [False,  True, False]])
>>> mt.where(ix).execute()
(array([1, 1, 2]), array([0, 1, 1]))