# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
from maxframe import opcodes
from maxframe.core import ENTITY_TYPE, EntityData
from maxframe.serialization.serializables import AnyField, BoolField, Int32Field
from maxframe.tensor.core import Tensor
from maxframe.tensor.datasource import tensor as astensor
from maxframe.tensor.operators import TensorOperator, TensorOperatorMixin
class TensorFillDiagonal(TensorOperator, TensorOperatorMixin):
    """Operator behind :func:`fill_diagonal`.

    Its first input is the tensor whose diagonal is filled. When the fill
    value is itself a tensor, that value is attached as a second input.
    """

    _op_type_ = opcodes.FILL_DIAGONAL

    # value to write on the diagonal: an ndarray, or the tensor input
    val = AnyField("val", default=None)
    # whether the diagonal wraps around for tall matrices
    wrap = BoolField("wrap", default=None)
    # used for chunk
    k = Int32Field("k", default=None)

    @classmethod
    def _set_inputs(cls, op: "TensorFillDiagonal", inputs: List[EntityData]):
        super()._set_inputs(op, inputs)
        # A second input can only be the tensor-valued fill value. Point
        # ``val`` at it, presumably so the field tracks the current input
        # entity whenever inputs are (re)assigned.
        has_tensor_val = len(op._inputs) == 2
        if has_tensor_val:
            op.val = op._inputs[1]

    def __call__(self, a, val=None):
        # the result keeps the shape and memory order of the input tensor
        op_inputs = [a] if val is None else [a, val]
        return self.new_tensor(op_inputs, shape=a.shape, order=a.order)
def fill_diagonal(a, val, wrap=False):
    """Fill the main diagonal of the given tensor of any dimensionality.

    For a tensor `a` with ``a.ndim >= 2``, the diagonal is the list of
    locations with indices ``a[i, ..., i]`` all identical. This function
    modifies the input tensor in-place; it does not return a value.

    Parameters
    ----------
    a : Tensor, at least 2-D.
        Tensor whose diagonal is to be filled; it gets modified in-place.
    val : scalar or array_like
        Value(s) to write on the diagonal. Their type must be compatible
        with that of the tensor `a`. Tensor values are passed to the
        operator as an extra input. Multi-dimensional values are
        flattened first.
    wrap : bool
        For tall matrices in NumPy versions up to 1.6.2, the diagonal
        "wrapped" after N columns. This option restores that behavior.
        It affects only tall matrices.

    Raises
    ------
    TypeError
        If `a` is not a tensor.
    ValueError
        If `a` has fewer than 2 dimensions, or has more than 2 dimensions
        that are not all of equal length.

    See also
    --------
    diag_indices, diag_indices_from

    Notes
    -----
    This functionality can also be obtained via `diag_indices`. The NumPy
    function this mirrors avoids constructing the indices and uses simple
    slicing instead.

    Examples
    --------
    >>> import maxframe.tensor as mt
    >>> a = mt.zeros((3, 3), int)
    >>> mt.fill_diagonal(a, 5)
    >>> a.execute()
    array([[5, 0, 0],
           [0, 5, 0],
           [0, 0, 5]])

    The same function can operate on a 4-D tensor:

    >>> a = mt.zeros((3, 3, 3, 3), int)
    >>> mt.fill_diagonal(a, 4)

    We only show a few blocks for clarity:

    >>> a[0, 0].execute()
    array([[4, 0, 0],
           [0, 0, 0],
           [0, 0, 0]])
    >>> a[1, 1].execute()
    array([[0, 0, 0],
           [0, 4, 0],
           [0, 0, 0]])
    >>> a[2, 2].execute()
    array([[0, 0, 0],
           [0, 0, 0],
           [0, 0, 4]])

    The wrap option affects only tall matrices:

    >>> # tall matrices no wrap
    >>> a = mt.zeros((5, 3), int)
    >>> mt.fill_diagonal(a, 4)
    >>> a.execute()
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [0, 0, 0]])

    >>> # tall matrices wrap
    >>> a = mt.zeros((5, 3), int)
    >>> mt.fill_diagonal(a, 4, wrap=True)
    >>> a.execute()
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [4, 0, 0]])

    >>> # wide matrices
    >>> a = mt.zeros((3, 5), int)
    >>> mt.fill_diagonal(a, 4, wrap=True)
    >>> a.execute()
    array([[4, 0, 0, 0, 0],
           [0, 4, 0, 0, 0],
           [0, 0, 4, 0, 0]])

    The anti-diagonal can be filled by reversing the order of elements
    using either `mt.flipud` or `mt.fliplr`.

    >>> a = mt.zeros((3, 3), int)
    >>> mt.fill_diagonal(mt.fliplr(a), [1,2,3])  # Horizontal flip
    >>> a.execute()
    array([[0, 0, 1],
           [0, 2, 0],
           [3, 0, 0]])
    >>> mt.fill_diagonal(mt.flipud(a), [1,2,3])  # Vertical flip
    >>> a.execute()
    array([[0, 0, 3],
           [0, 2, 0],
           [1, 0, 0]])

    Note that the order in which the diagonal is filled varies depending
    on the flip function.
    """
    if not isinstance(a, Tensor):
        raise TypeError(f"`a` should be a tensor, got {type(a)}")
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    if a.ndim > 2 and len(set(a.shape)) != 1:
        raise ValueError("All dimensions of input must be of equal length")

    # Entity values (tensors and other maxframe objects) are converted to a
    # tensor and become an operator input. Anything else is materialized as
    # an ndarray and stored only on the operator's ``val`` field.
    if isinstance(val, ENTITY_TYPE):
        val = astensor(val)
        is_tensor_val = True
    else:
        val = np.asarray(val)
        is_tensor_val = False
    # Both kinds of value are flattened to 1-D when multi-dimensional.
    if val.ndim > 1:
        val = val.ravel()
    val_input = val if is_tensor_val else None

    op = TensorFillDiagonal(val=val, wrap=wrap, dtype=a.dtype)
    t = op(a, val=val_input)
    # "In-place" semantics: rebind the caller's tensor to the result's data.
    a.data = t.data