Source code for maxframe.learn.contrib.lightgbm._train

# Copyright 1999-2026 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import logging
from typing import List

from maxframe import opcodes
from maxframe.core import OutputType
from maxframe.core.operator import ObjectOperator, ObjectOperatorMixin
from maxframe.learn.contrib.lightgbm.core import (
    Booster,
    BoosterData,
    LGBMScikitLearnBase,
)
from maxframe.learn.contrib.models import to_remote_model
from maxframe.learn.contrib.utils import TrainingCallback
from maxframe.serialization.serializables import (
    AnyField,
    BoolField,
    DictField,
    FieldTypes,
    FunctionField,
    Int32Field,
    KeyField,
    ListField,
    StringField,
)
from maxframe.typing_ import EntityType

# Module-level logger; not referenced by any of the code shown in this file.
logger = logging.getLogger(__name__)


class LGBMTrain(ObjectOperator, ObjectOperatorMixin):
    """
    Operator that trains a LightGBM booster.

    The operator's inputs are, in this order:

    - ``train_set``;
    - every entry of ``valid_sets``, only when ``valid_sets`` is non-empty;
    - ``init_model``, only when it is given.

    The first output is an object tileable of class ``Booster``. When
    ``valid_sets`` is non-empty, a second object output is declared as well
    (presumably holding the evaluation results; confirm against the
    executor).
    """

    _op_type_ = opcodes.LGBM_TRAIN
    _op_module_ = "learn.contrib.lightgbm"

    params = DictField("params", key_type=FieldTypes.string, default_factory=dict)
    train_set = KeyField("train_set", default=None)
    num_boost_round = Int32Field("num_boost_round", default=None)
    valid_sets = ListField("valid_sets", FieldTypes.key, default_factory=list)
    valid_names = ListField("valid_names", FieldTypes.string, default_factory=list)
    feval = FunctionField("feval", default=None)
    init_model = KeyField("init_model", default=None)
    feature_name = AnyField("feature_name", default=None)
    categorical_feature = AnyField("categorical_feature", default=None)
    keep_training_booster = BoolField("keep_training_booster", default=False)
    callbacks = ListField(
        "callbacks",
        field_type=FunctionField.field_type,
        default=None,
        on_serialize=TrainingCallback.from_local,
    )
    tree_learner = StringField("tree_learner", default=None)
    timeout = Int32Field("timeout", default=None)
    # indicates the shape of the data predicted by the model
    num_class = Int32Field("num_class", default=None)

    def __init__(self, gpu=None, **kw):
        """
        Create the operator.

        Output types default to a single object output. A second object
        output is added when validation sets are present.
        """
        super().__init__(gpu=gpu, **kw)
        if self.output_types is None:
            self.output_types = [OutputType.object]
        if self.has_evals_result and len(self.output_types) < self.output_limit:
            # Build a new list rather than appending in place, so a list
            # supplied by the caller is never mutated. The length guard also
            # avoids adding a third output when both were already supplied.
            self.output_types = list(self.output_types) + [OutputType.object]

    @classmethod
    def _set_inputs(cls, op: "LGBMTrain", inputs: List[EntityType]):
        """
        Rebind the operator's fields from its inputs.

        Inputs are consumed in the same order ``__call__`` builds them:
        train set, then validation sets, then the optional initial model.
        """
        super()._set_inputs(op, inputs)
        input_it = iter(op.inputs)
        op.train_set = next(input_it)
        op.valid_sets = list(itertools.islice(input_it, len(op.valid_sets or [])))
        if op.init_model is not None:
            op.init_model = next(input_it)

    def __call__(self, evals_result):
        """
        Create the output tileables and return the first one (the booster).

        ``evals_result`` is forwarded to ``new_tileables``.
        """
        inputs = [self.train_set]
        if self.has_evals_result:
            inputs.extend(self.valid_sets)
        if self.init_model is not None:
            inputs.append(self.init_model)
        kws = [{"object_class": Booster}, {}]
        return self.new_tileables(inputs, kws=kws, evals_result=evals_result)[0]

    @property
    def output_limit(self):
        """Number of outputs: 2 when validation sets are present, otherwise 1."""
        return 2 if self.has_evals_result else 1

    def has_custom_code(self) -> bool:
        """
        Tell whether any callback may carry user-defined code.

        Callbacks that are not ``TrainingCallback`` instances always count as
        custom code. ``TrainingCallback`` instances are asked via their own
        ``has_custom_code``.
        """
        if not self.callbacks:
            return False
        return any(
            not isinstance(cb, TrainingCallback) or cb.has_custom_code()
            for cb in self.callbacks
        )

    @property
    def has_evals_result(self) -> bool:
        """True when at least one validation set was given."""
        return bool(self.valid_sets)


def _get_lgbm_booster(init_model):
    """
    Normalize ``init_model`` into a MaxFrame booster usable by ``LGBMTrain``.

    Parameters
    ----------
    init_model
        One of:

        - a MaxFrame ``Booster`` / ``BoosterData``, returned as is;
        - a MaxFrame LightGBM estimator (``LGBMScikitLearnBase``) or a local
          ``lightgbm.LGBMModel``, whose ``booster_`` is used;
        - a local ``lightgbm.Booster``, converted with ``to_remote_model``;
        - a path (``str`` or ``os.PathLike``) to a saved LightGBM model
          file. It is loaded locally with ``lightgbm.Booster(model_file=...)``
          and then converted like a local booster.

    Returns
    -------
    Booster or BoosterData

    Raises
    ------
    ValueError
        If ``init_model`` has an unsupported type.
    """
    import os

    # MaxFrame-side objects need no local lightgbm installation, so they are
    # handled before lightgbm is imported.
    if isinstance(init_model, LGBMScikitLearnBase):
        init_model = init_model.booster_
    if isinstance(init_model, (Booster, BoosterData)):
        return init_model

    import lightgbm

    if isinstance(init_model, lightgbm.LGBMModel):
        init_model = init_model.booster_
    elif isinstance(init_model, (str, os.PathLike)):
        # lightgbm.train also accepts a model file name as init_model
        init_model = lightgbm.Booster(model_file=os.fspath(init_model))

    if isinstance(init_model, lightgbm.Booster):
        return to_remote_model(init_model, model_cls=Booster)
    raise ValueError(f"Cannot use {type(init_model)} as init_model")


def train(
    params,
    train_set,
    num_boost_round=100,
    valid_sets=None,
    valid_names=None,
    feval=None,
    init_model=None,
    keep_training_booster=False,
    callbacks=None,
    num_class=2,
    evals_result=None,
    **kwargs,
):
    """
    Train a LightGBM booster with MaxFrame.

    Parameters
    ----------
    params : dict
        LightGBM training parameters.
    train_set
        Training dataset.
    num_boost_round : int, default 100
        Number of boosting rounds.
    valid_sets : list, optional
        Datasets evaluated during training. A single dataset, not wrapped in
        a list, is accepted and treated as a one-element list. Tuples are
        converted to lists.
    valid_names : list of str, optional
        Names of ``valid_sets``. A single string is accepted and treated as a
        one-element list. When both ``valid_sets`` and ``valid_names`` are
        given, they must have the same length.
    feval : callable, optional
        Custom evaluation function.
    init_model : optional
        Model to continue training from. It is normalized by
        ``_get_lgbm_booster``.
    keep_training_booster : bool, default False
        Forwarded to the training operator.
    callbacks : list, optional
        Training callbacks, forwarded to the training operator.
    num_class : int, default 2
        Indicates the shape of the data predicted by the model.
    evals_result : dict, optional
        Dictionary handed to the training operator for evaluation results. A
        new empty dict is used when omitted.
    **kwargs
        ``timeout`` (default 120; unit not visible here), ``session`` and
        ``run_kwargs`` are consumed by this function. All other keyword
        arguments are forwarded to ``LGBMTrain``.

    Returns
    -------
    The booster tileable. When ``valid_sets`` is non-empty it has already been
    executed with ``session`` and ``run_kwargs``.

    Raises
    ------
    TypeError
        If ``valid_names`` is not a list (or tuple) of strings or a string.
    ValueError
        If ``valid_names`` and ``valid_sets`` have different lengths, or if
        ``init_model`` has an unsupported type.
    """
    evals_result = evals_result if evals_result is not None else dict()
    timeout = kwargs.pop("timeout", 120)
    session = kwargs.pop("session", None)
    # an explicit ``run_kwargs=None`` is treated as "not given", so the
    # ``**run_kwargs`` expansion below cannot fail on None
    run_kwargs = kwargs.pop("run_kwargs", None) or dict()

    # Mirror lightgbm.train, which accepts a single validation set and a
    # single name string in addition to lists.
    if valid_sets is not None and not isinstance(valid_sets, (list, tuple)):
        valid_sets = [valid_sets]
    if isinstance(valid_names, str):
        valid_names = [valid_names]
    elif isinstance(valid_names, tuple):
        valid_names = list(valid_names)

    if valid_sets and valid_names:
        if not isinstance(valid_names, list):
            raise TypeError("valid_names must be a list of strings")
        if len(valid_names) != len(valid_sets):
            raise ValueError("size of valid_names must match size of valid_sets")
    if valid_sets is not None:
        valid_sets = list(valid_sets)

    if init_model is not None:
        init_model = _get_lgbm_booster(init_model)

    data = LGBMTrain(
        params=params,
        train_set=train_set,
        num_boost_round=num_boost_round,
        valid_sets=valid_sets,
        valid_names=valid_names,
        feval=feval,
        init_model=init_model,
        keep_training_booster=keep_training_booster,
        callbacks=callbacks,
        num_class=num_class,
        evals_result=evals_result,
        timeout=timeout,
        **kwargs,
    )(evals_result)
    if valid_sets:
        # executed eagerly when validation sets are present
        data.execute(session=session, **run_kwargs)
    return data