# Source code for maxframe.learn.contrib.xgboost.train

# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import OrderedDict
from typing import List

from maxframe import opcodes
from maxframe.core import ENTITY_TYPE, EntityData, OutputType
from maxframe.core.operator import ObjectOperator, ObjectOperatorMixin
from maxframe.learn.contrib.models import to_remote_model
from maxframe.learn.contrib.utils import TrainingCallback
from maxframe.learn.contrib.xgboost.core import Booster, BoosterData, XGBScikitLearnBase
from maxframe.learn.contrib.xgboost.dmatrix import ToDMatrix, to_dmatrix
from maxframe.serialization.serializables import (
    AnyField,
    BoolField,
    DictField,
    FieldTypes,
    FunctionField,
    Int16Field,
    Int64Field,
    KeyField,
    ListField,
)

logger = logging.getLogger(__name__)


def _on_serialize_evals(evals_val):
    if evals_val is None:
        return None
    return [list(x) for x in evals_val]


class XGBTrain(ObjectOperator, ObjectOperatorMixin):
    """Operator describing an XGBoost training task.

    Inputs are, in order: the training DMatrix, the DMatrix of every entry in
    ``evals`` (if any) and the existing model ``xgb_model`` (if any). The first
    output is declared with ``object_class=Booster``; a second object output is
    added when evaluation results are requested.
    """

    _op_type_ = opcodes.XGBOOST_TRAIN

    params = DictField("params", key_type=FieldTypes.string, default=None)
    dtrain = KeyField("dtrain", default=None)
    # list of (DMatrix, name) pairs; tuples are turned into lists on serialization
    evals = ListField("evals", on_serialize=_on_serialize_evals, default=None)
    obj = FunctionField("obj", default=None)
    feval = FunctionField("feval", default=None)
    maximize = BoolField("maximize", default=None)
    early_stopping_rounds = Int64Field("early_stopping_rounds", default=None)
    verbose_eval = AnyField("verbose_eval", default=None)
    xgb_model = KeyField("xgb_model", default=None)
    callbacks = ListField(
        "callbacks",
        field_type=FunctionField.field_type,
        default=None,
        on_serialize=TrainingCallback.from_local,
    )
    custom_metric = FunctionField("custom_metric", default=None)
    num_boost_round = Int64Field("num_boost_round", default=10)
    num_class = Int64Field("num_class", default=None)
    # set when the caller supplied an ``evals_result`` container
    _has_evals_result = BoolField("has_evals_result", default=False)
    output_ndim = Int16Field("output_ndim", default=None)

    def __init__(self, gpu=None, **kw):
        # a non-None ``evals_result`` argument requests the extra output
        if kw.get("evals_result") is not None:
            kw["_has_evals_result"] = True
        super().__init__(gpu=gpu, **kw)
        if self.output_types is None:
            self.output_types = [OutputType.object]
        if self.has_evals_result:
            self.output_types.append(OutputType.object)

    def has_custom_code(self) -> bool:
        """Tell whether any callback carries user-defined code.

        Returns ``False`` without callbacks. Otherwise returns ``True`` as soon
        as one callback is not a ``TrainingCallback`` instance or reports
        custom code through its own ``has_custom_code()``.
        """
        if not self.callbacks:
            return False
        return any(
            not isinstance(cb, TrainingCallback) or cb.has_custom_code()
            for cb in self.callbacks
        )

    @classmethod
    def _set_inputs(cls, op: "XGBTrain", inputs: List[EntityData]):
        """Rebind ``dtrain``, the eval DMatrix keys and ``xgb_model`` to ``op._inputs``.

        Inputs are consumed in the same order in which ``__call__`` declares
        them, so the conditions used here must mirror the ones used there.
        """
        super()._set_inputs(op, inputs)
        input_it = iter(op._inputs)
        op.dtrain = next(input_it)
        if op.evals:
            evals_dict = OrderedDict(op.evals)
            new_evals_dict = OrderedDict()
            for val in evals_dict.values():
                new_key = next(input_it)
                new_evals_dict[new_key] = val
            op.evals = list(new_evals_dict.items())
        # compare against None (as ``__call__`` does) instead of relying on
        # the truth value of the model object
        if op.xgb_model is not None:
            op.xgb_model = next(input_it)

    def __call__(self, evals_result):
        """Create the output tileables and return the first one.

        Parameters
        ----------
        evals_result
            Object forwarded to ``new_tileables`` as the ``evals_result``
            keyword argument.
        """
        inputs = [self.dtrain]
        # Only eval entries contribute inputs. Checking ``self.evals`` directly
        # avoids iterating ``None`` when an evals_result container was given
        # without any evals, and matches the condition in ``_set_inputs``.
        if self.evals:
            inputs.extend(e[0] for e in self.evals)
        if self.xgb_model is not None:
            inputs.append(self.xgb_model)
        kws = [{"object_class": Booster}, {}]
        return self.new_tileables(inputs, kws=kws, evals_result=evals_result)[0]

    @property
    def output_limit(self):
        """Number of outputs: 2 when evaluation results are produced, else 1."""
        return 2 if self.has_evals_result else 1

    @property
    def has_evals_result(self) -> bool:
        """Whether an evals_result container was supplied or evals are present."""
        # coerce to a real bool so the raw ``evals`` list / None never leaks out
        return bool(self._has_evals_result or self.evals)


def _get_xgb_booster(xgb_model):
    """Resolve ``xgb_model`` into a booster object usable as a training input.

    Scikit-learn style wrappers (``XGBScikitLearnBase`` or
    ``xgboost.XGBModel``) are unwrapped through ``get_booster()`` first.
    ``Booster`` / ``BoosterData`` objects are returned unchanged, while a
    native ``xgboost.Booster`` is wrapped via ``to_remote_model``.

    Raises
    ------
    ValueError
        If the (unwrapped) object is none of the supported booster types.
    """
    import xgboost

    wrapper_types = (XGBScikitLearnBase, xgboost.XGBModel)
    candidate = xgb_model
    if isinstance(candidate, wrapper_types):
        candidate = candidate.get_booster()

    if not isinstance(candidate, (Booster, BoosterData)):
        if isinstance(candidate, xgboost.Booster):
            return to_remote_model(candidate, model_cls=Booster)
        raise ValueError(f"Cannot use {type(candidate)} as xgb_model")
    return candidate


def train(
    params,
    dtrain,
    evals=None,
    evals_result=None,
    xgb_model=None,
    num_class=None,
    **kwargs,
):
    """
    Train XGBoost model in MaxFrame manner.

    Parameters are the same as `xgboost.train`. Note that train is an
    eager-execution API if evals is passed, thus the call will be blocked
    until training finished.

    Parameters
    ----------
    params
        Booster parameters, forwarded to the training operator.
    dtrain
        Training data, forwarded to the training operator as is.
    evals
        Optional iterable of ``(data, name)`` pairs. Data that is not already
        the result of a ``ToDMatrix`` operator is converted with
        ``to_dmatrix``.
    evals_result
        Optional container for evaluation results; a new ``dict`` is used
        when omitted.
    xgb_model
        Optional existing model to pass to the training operator. MaxFrame
        entities are wrapped with ``to_remote_model``; other objects are
        resolved through ``_get_xgb_booster``.
    num_class
        Forwarded to the training operator.
    **kwargs
        ``session`` and ``run_kwargs`` are extracted and only used for the
        eager execution triggered by ``evals``; everything else is forwarded
        to the training operator.

    Returns
    -------
    results: Booster

    Raises
    ------
    TypeError
        If the name of an entry in ``evals`` is not a string.
    """
    evals_result = evals_result if evals_result is not None else dict()
    processed_evals = []

    session = kwargs.pop("session", None)
    # An explicit ``run_kwargs=None`` is treated like an omitted argument so
    # that ``**run_kwargs`` below never receives None.
    run_kwargs = kwargs.pop("run_kwargs", None)
    if run_kwargs is None:
        run_kwargs = dict()

    if evals:
        for eval_dmatrix, name in evals:
            if not isinstance(name, str):
                raise TypeError("evals must a list of pairs (DMatrix, string)")
            if hasattr(eval_dmatrix, "op") and isinstance(eval_dmatrix.op, ToDMatrix):
                processed_evals.append((eval_dmatrix, name))
            else:
                processed_evals.append((to_dmatrix(eval_dmatrix), name))

    if isinstance(xgb_model, ENTITY_TYPE):
        xgb_model = to_remote_model(
            xgb_model, model_cls=Booster, extractor=XGBScikitLearnBase._extract_booster
        )
    elif xgb_model is not None:
        xgb_model = _get_xgb_booster(xgb_model)

    data = XGBTrain(
        params=params,
        dtrain=dtrain,
        evals=processed_evals,
        evals_result=evals_result,
        xgb_model=xgb_model,
        num_class=num_class,
        **kwargs,
    )(evals_result)
    # passing evals makes the call eager: block until training has finished
    if evals:
        data.execute(session=session, **run_kwargs)
    return data