かえるのプログラミングブログ

プログラミングでつまずいたところとその解決策などを書いていきます。

PyTorch lightening で Titanic 問題解いてみた。

こんばんは、kaerururu (@kaeru_nantoka) です。

今回は、PyTorch lightening を使ってみたいと思ったので 以前公開したカーネル をベースに PyTorch lightening に書き換えたものを公開したのでその紹介をします。


目次


本ポストのモチベーション

公式のGitHub中の人の medium 記事 が充実しており、学習まではつまづくようなところはないと思います。

私も、学習までは既存の自分のコードをうまくテンプレートにはめ込むだけで行けました。しかし、本家 Tutorial 記事をはじめとした記事には推論部分まで書いておらず 「学習はできたけど predict できないぞ。。というか重みどこ?」とひとり沼にはまりそして無事自己解決したので、まとめようと思った次第です。


ソースコード

kaggle の 公開カーネルとして公開しております。 (https://www.kaggle.com/kaerunantoka/pytorch-lightening-for-titanic?scriptVersionId=19673008)

スコアが低いのは、パラメータ調整していないのと前処理を雑にしているからだと思います。もうちょっと行けるはず。。


ポイント

Entity Embedding とは、カテゴリ変数を同じベクトル空間上の座標にマッピングする手法でラベルエンコーディングよりカテゴリ変数同士の意味的な関係も表現できます。

[実装と説明]

cat_dims = [int(train[col].nunique()) for col in categorical_features]
emb_dims = [(x, min(50, (x + 1) // 2)) for x in cat_dims]

# def __init__() 内
# Embedding layers
self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
                                     for x, y in emb_dims])

# def forward() 内
if self.no_of_embs != 0:
   x = [emb_layer(cat_data[:, i])
                  for i,emb_layer in enumerate(self.emb_layers)]
   x = torch.cat(x, 1)
   x = self.emb_dropout_layer(x)

上のように、nn.Embedding を利用し、カテゴリ変数を埋め込みます埋め込みの次元はカテゴリ変数のユニークな値の数の半分か、多すぎる場合は 50 に固定しています。

  • Predict

もう少しスッキリかけるかもしれないですが、とりあえず重みの保存場所とロード方法を docs のこのページ で確認することができたので、推論処理は普段私が使用している predict 関数を使用しました。

[実装と説明]

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filepath='/path/to/store/weights.ckpt',
    save_best_only=True,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)

このように Trainer 関数に checkpoint_callback を渡してあげると、

!ls ../input/weights.ckpt/

# _ckpt_epoch_5.ckpt 

のように、checkpoint_callback で指定した filepath に checkpoint が保存されます。

こちらを、

checkpoint = torch.load('../input/weights.ckpt/_ckpt_epoch_5.ckpt')

model.load_state_dict(checkpoint['state_dict'])

のようにロードし、その中の 'state_dict' を load_state_dict の引数に渡せば OK です。


まとめ

  • PyTorch lightening を使用してテーブルデーNLP の実装を、触れ込み通りスッキリ書くことができた。

  • 重みの保存場所、ロードの仕方がわかって推論もできた。

  • カテゴリ変数の前処理の仕方を少し変えた。

これをきっかけに PyTorch lightening を使って画像コンペ用のオレオレコードを整備していこうかなと思っています。

以上、ありがとうございました!