tkm2261's blog

研究員(OR屋) → データ分析官 → MLエンジニア → ニートがデータ分析諸々書いてます

細かすぎて伝わらないLightGBM活用法 (callback関数)

皆様tkm2261です。この頃連投が続いてますが、

最近まで参加していたInstacart Market Basket Analysis | Kaggleで色々やったので残しておこうと思います。

このcallback関数は便利ですが、Kaggleなどでヘビーに使う人以外ここまでしないと思うので活躍するかは微妙なところです。。。

LightGBMのtrain関数を読み解く

xgboostもそうですが、lightgbmにもtrain()という関数がありLightGBMユーザはこれを使って学習を実行します。

scikit-learn APIも内部ではこの関数を呼んでいるので同じです。

この引数にcallbacksというのがあり、殆どのユーザは使っていないと思います。(私も今回初)

このcallbacksの活用法が今回のトピックになります。

train関数の実装を見てみると, 192行目ぐらいでBoostingを回しています。

LightGBM/engine.py at master · Microsoft/LightGBM · GitHub

# lightgbm/engine.py: 192
    for i in range_(init_iteration, init_iteration + num_boost_round):
        for cb in callbacks_before_iter:
            cb(callback.CallbackEnv(model=booster,
                                    params=params,
                                    iteration=i,
                                    begin_iteration=init_iteration,
                                    end_iteration=init_iteration + num_boost_round,
                                    evaluation_result_list=None))

        booster.update(fobj=fobj)

        evaluation_result_list = []
        # check evaluation result.
        if valid_sets is not None:
            if is_valid_contain_train:
                evaluation_result_list.extend(booster.eval_train(feval))
            evaluation_result_list.extend(booster.eval_valid(feval))
        try:
            for cb in callbacks_after_iter:
                cb(callback.CallbackEnv(model=booster,
                                        params=params,
                                        iteration=i,
                                        begin_iteration=init_iteration,
                                        end_iteration=init_iteration + num_boost_round,
                                        evaluation_result_list=evaluation_result_list))
        except callback.EarlyStopException as earlyStopException:
            booster.best_iteration = earlyStopException.best_iteration + 1
            evaluation_result_list = earlyStopException.best_score
            break

意外に思われるかもしれませんが、boostingのfor文はpython側で回ってます。

early stoppingもこっちで例外として書いてあるのでcallback関数さえかければループ毎にかなり動的に書くことが出来ます。

例えばearly stoppingの条件をより複雑にしたり、loggingをちゃんと仕込んで普通は標準出力に吐かれて保存が難しい学習の様子をファイルに残すことが出来ます。

train関数に渡したcallback関数はcallbacks_after_iterに渡るのでこちらをこれから見ていきます。

callback関数の仕様

custom objectiveやcustom metricと異なりcallbackの仕様は英語でもドキュメントがありません。

ただ実装は素直なので見ていきます。

文字で説明してもアレなので、先に私のコンペの実装を見せます。

def get_pred_metric(pred, dtrain):
    """予測値無理やり取るmetric
    """
    return 'pred', pred, True

def callback(env):
    """ligthgbm callback関数(instacartコンペ)
   
    :param lightgbm.callback.CallbackEnv env: 学習中のデータ
    """

    # 10回毎に実行
    if (env.iteration + 1) % 10 != 0:
        return

    clf = env.model                 # 学習中モデル
    trn_env = clf.train_set         # 学習データ (ラベルだけが入ってる※後述)
    val_env = clf.valid_sets[0]     # 検証データ (ラベルだけが入ってる※後述)

    # 無理やり予測値をとる
    preds = [ele[2] for ele in clf.eval_train(get_pred_metric) if ele[1] == 'pred'][0]
    # ラベル取得Cython用に型を指定
    labels = trn_env.get_label().astype(np.int)

    # ユーザ毎のしきい値で予測を0-1にしてデータ毎に正解なら1, 不正解なら0を返す関数
    # (list_idxはgroupでグローバルでアクセス)
    res = f1_group_idx(labels, preds, list_idx).astype(np.bool)
   
    # 間違ってるデータのウェイトを上げるて正解のデータのウェイトを下げる
    weight = trn_env.get_weight()
    if weight is None:
        weight = np.ones(preds.shape[0])
    weight[res] *= 0.8
    weight[~res] *= 1.25

    trn_env.set_weight(weight)

    # 検証データのmetricを計算
    preds = [ele[2] for ele in clf.eval_valid(get_pred_metric) if ele[1] == 'pred'][0]
    labels = val_env.get_label().astype(np.int)
    t = time.time()
    res = f1_group(labels, preds, list_idx)
    sc = np.mean(res)

    logger.info('cal [{}] {} {}'.format(env.iteration + 1, sc, time.time() - t))

受け取る引数はひとつ (lightgbm.callback.CallbackEnv)

ただの名前付きタプルです。lightgbmユーザなら名前で大体中身が分かるはずです。

唯一よく知らないのがmodel(Booster)かと思うので、この後見ていきます。

begin_iterationとend_iterationがあるのはinit_model指定して学習を途中からやる場合があるからです。

CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])

LightGBM/callback.py at master · Microsoft/LightGBM · GitHub

データのアクセスはAPI経由

callback関数最大の難関は、lightgbmの学習データをpythonから直接アクセス出来ないことです。

なぜかというと、学習はCでやるためCのオブジェクトでデータを持ってるためです。

メモリ的にPython側でも持つわけには行かないので仕方ありません。

もっと良いやり方があるかもしれませんが、学習データにはcustom metricでアクセスするのが一番楽そうでした。

なのでget_pred_metricみたいな一切計算しないで予測値を返すだけのcustom metricを実装して無理やり取得します。

preds = [ele[2] for ele in clf.eval_train(get_pred_metric) if ele[1] == 'pred'][0]
preds = [ele[2] for ele in clf.eval_valid(get_pred_metric) if ele[1] == 'pred'][0]

Booster.eval_trainとBooster.eval_testで学習データと検証データのそれぞれの予測値がとれます。

データは複数渡せるので、最後に先頭のだけ取ります。

ラベルは普通にget_label()でとれます。

その他の事はBoosterの実装を見てください。

LightGBM/basic.py at master · Microsoft/LightGBM · GitHub

あとは好きに実装

私の場合は、metricの計算が遅いので10回反復毎にしたり、反復毎にデータのウェイトを変えたり、ログに書いたりといった活用をしました。

ただ反復毎にデータのウェイトを変えるのはC側の変更も必要なので注意下さい。

自前early stoppingのやり方

↑のboosting反復の実装みれば分かる通り、lightgbm.callback.EarlyStopExceptionを上げるだけです。

callback関数内で好きに実装してraiseしましょう。

class EarlyStopException(Exception):
    """Exception of early stopping.
    Parameters
    ----------
    best_iteration : int
        The best iteration stopped.
    """
    def __init__(self, best_iteration, best_score):
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score

LightGBM/callback.py at master · Microsoft/LightGBM · GitHub

ただ未検証なので、ちゃんと巻き戻るかとか検証したら教えて下さい。。。

一応書いたけど。。。

ここまで学習を制御すること無いので普通使わないですが、忘れそうなので備忘で記事にしました。