Source code for maxframe.learn.contrib.xgboost.callback

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Sequence, Union

from ....serialization.serializables import (
    AnyField,
    BoolField,
    Float32Field,
    Int32Field,
    StringField,
)
from ....udf import BuiltinFunction
from ..utils import TrainingCallback

try:
    from xgboost.callback import EarlyStopping as _EarlyStopping
    from xgboost.callback import LearningRateScheduler as _LearningRateScheduler
except ImportError:
    _LearningRateScheduler = _EarlyStopping = None


class XGBTrainingCallback(TrainingCallback):
    _local_to_remote = {}

    @classmethod
    def from_local(cls, callback_obj):
        cls._load_local_to_remote_mapping(globals())
        return super().from_local(callback_obj)


[docs] class LearningRateScheduler(XGBTrainingCallback): _local_cls = _LearningRateScheduler learning_rates = AnyField("learning_rates", default=None)
[docs] def __init__( self, learning_rates: Union[Callable[[int], float], Sequence[float]], **kw ) -> None: super().__init__(learning_rates=learning_rates, **kw)
def has_custom_code(self) -> bool: return not isinstance(self.learning_rates, (tuple, list, BuiltinFunction))
[docs] class EarlyStopping(XGBTrainingCallback): _local_cls = _EarlyStopping rounds = Int32Field("rounds") metric_name = StringField("metric_name", default=None) data_name = StringField("data_name", default=None) maximize = BoolField("maximize", default=None) save_best = BoolField("save_best", default=None) min_delta = Float32Field("min_delta", default=None)
[docs] def __init__( self, *, rounds: int, metric_name: Optional[str] = None, data_name: Optional[str] = None, maximize: Optional[bool] = None, save_best: Optional[bool] = False, min_delta: float = 0.0, **kw ) -> None: super().__init__( rounds=rounds, metric_name=metric_name, data_name=data_name, maximize=maximize, save_best=save_best, min_delta=min_delta, **kw )