かえるのプログラミングブログ

プログラミングでつまずいたところとその解決策などを書いていきます。

kaggle カーネルで commit せずに output file を DL する.

こんばんは、かえるるるです。

既知とは思いますが、表題の内容を自分用にメモを残しておきます。

内容

kaggle kernel 上で、

df.to_csv('/path/to/file/df.csv, index=False')

などのコマンドで出力したファイルは、右上の commit ボタンを押して、 カーネルの実行が完了した後の preview 画面の output 欄から DL できることは広く知られていると思います。

実は、以下の写真のコマンド を実行し、out セルに出てくるリンクを押下すると、commit ボタンを押さないままの編集画面上からも出力 file のDL が可能になります。

f:id:kaeru_nantoka_py:20200128032031p:plain

参考

https://www.kaggle.com/getting-started/58426

以上、ありがとうございました。

自然数の順番になっているリストが欲しい

こんばんは、かえるるるです。

今回は、

os.listdir() などで

['file11.txt', 'file2.txt', 'file4.txt', 'file5.txt', 'file8.txt', 'file3.txt', 'file1.txt', 'file13.txt', 'file14.txt', 'file6.txt', 'file10.txt', 'file12.txt', 'file9.txt', 'file7.txt', 'file0.txt']

のようになって返ってきたファイルたちを、

['file0.txt', 'file1.txt', 'file2.txt', 'file3.txt', 'file4.txt', 'file5.txt', 'file6.txt', 'file7.txt', 'file8.txt', 'file9.txt', 'file10.txt', 'file11.txt', 'file12.txt', 'file13.txt', 'file14.txt']

のように自然数の順番で使用したい場面に出くわして、少しだけ困ったのでその解決策を記しておこうと思った次第です。

0. 準備

['file11.txt', 'file2.txt', 'file4.txt', 'file5.txt', 'file8.txt', 'file3.txt', 'file1.txt', 'file13.txt', 'file14.txt', 'file6.txt', 'file10.txt', 'file12.txt', 'file9.txt', 'file7.txt', 'file0.txt']

このようなリストを擬似的に作成するコードはこちらです。

import random
import re
from natsort import natsorted

lst = [str(i) for i in range(15)]

random.shuffle(lst)  # 順番をバラバラにする。

file_list = [f'file{i}.txt' for i in lst]

print(file_list)

1. sorted(list)

まず、無秩序に並んでいるリストに一定の秩序を与えたい時、 sorted() 関数の使用を思いつくと思います。

print(sorted(file_list))

['file0.txt', 'file1.txt', 'file10.txt', 'file11.txt', 'file12.txt', 'file13.txt', 'file14.txt', 'file2.txt', 'file3.txt', 'file4.txt', 'file5.txt', 'file6.txt', 'file7.txt', 'file8.txt', 'file9.txt']

このように、0, 1, 10, 11, ..., 2, 20, 21, ... というような秩序だってはいますがこれじゃねえんだよ。。というリストが返ってきます。

2. sorted(list, key=function)

SNS で呟くと、sorted には key を渡せまっせ。とアドバイスをいただきました。ありがとうございます。

key にはどこを参照して並び替えたいのかを返す関数を渡せば良いみたいです。

'file' と '.txt' を除いた、数字の部分を参照して欲しいので以下のような helper 関数を作成しました。

def return_check_item(x):
    x = re.sub(f'^{"file"}', '', x)
    x = re.sub(f'{".txt"}$', '', x)
    return int(x)

そして、

print(sorted(file_list, key=return_check_item))

['file0.txt', 'file1.txt', 'file2.txt', 'file3.txt', 'file4.txt', 'file5.txt', 'file6.txt', 'file7.txt', 'file8.txt', 'file9.txt', 'file10.txt', 'file11.txt', 'file12.txt', 'file13.txt', 'file14.txt']

欲しい順番になりました。

3. natsorted(list)

sorted にhelper 関数を渡さなくてもよしなにしてくれる関数がありました。

pip3 install natsort=='5.5.0'


from natsort import natsorted

print(natsorted(file_list))

['file0.txt', 'file1.txt', 'file2.txt', 'file3.txt', 'file4.txt', 'file5.txt', 'file6.txt', 'file7.txt', 'file8.txt', 'file9.txt', 'file10.txt', 'file11.txt', 'file12.txt', 'file13.txt', 'file14.txt']

実務利用であれば、helper 関数を書いた方がいい気がしますが覚えておいて損はないでしょう。

4. 0 padding

SNS で呟いたら 0 埋めしたらええやんとアドバイスをいただきました。 やってみましょう。

['file11.txt', 'file2.txt', 'file4.txt', 'file5.txt', 'file8.txt', 'file3.txt', 'file1.txt', 'file13.txt', 'file14.txt', 'file6.txt', 'file10.txt', 'file12.txt', 'file9.txt', 'file7.txt', 'file0.txt']

このように名前がすでについているものの使用を想定しているので、rename 関数を用意します。 ファイル名文字列を file 数字 .txt に分割し、数字のところに zfill 関数で任意の桁数になるまで 0で埋めたものを代入します。

def rename(x):
    prefix = "file"
    suffix = ".txt"

    x = re.sub(f'^{prefix}', '', x)
    x = re.sub(f'{suffix}$', '', x)
    
    return prefix+f"{str(x).zfill(5)}"+suffix
padded_lst = [rename(i) for i in file_list]
padded_lst

['file00007.txt',
 'file00000.txt',
 'file00014.txt',
 'file00006.txt',
 'file00009.txt',
 'file00010.txt',
 'file00013.txt',
 'file00011.txt',
 'file00012.txt',
 'file00005.txt',
 'file00004.txt',
 'file00008.txt',
 'file00003.txt',
 'file00001.txt',
 'file00002.txt']
sorted(padded_lst)

['file00000.txt',
 'file00001.txt',
 'file00002.txt',
 'file00003.txt',
 'file00004.txt',
 'file00005.txt',
 'file00006.txt',
 'file00007.txt',
 'file00008.txt',
 'file00009.txt',
 'file00010.txt',
 'file00011.txt',
 'file00012.txt',
 'file00013.txt',
 'file00014.txt']

簡単でした。

まとめ

リストについて少し詳しくなった。 SNS すごい。 ありがとうございました。

参考にしたサイトなど

[python] sorted のkeyってなんぞ - Qiita

OS - [python] os.listdir()の出力を数字順にsortして、それらを順番にpd.read_csvで読み込みたい|teratail

Pythonで文字列・数値をゼロ埋め(ゼロパディング) | note.nkmk.me

私がよく使う kaggle api command まとめた

こんばんは。kaerururu です。

今回は個人的にざっと調べてみても出てこなかった kaggle api command の使い方のチートシートのようなものを作ったので共有致します。

GitHub のリンク

https://github.com/osuossu8/Utils/blob/master/kaggle_api_usage.py

( i ) 解説など

セットアップやコンペデータの DL などはカレーさんのブログ他、Qiita 記事などに詳しく書かれているので参照されたし。

public のユーザーデータセットの DL についても kaggle datasets download -w (www.kaggle.com 以下をコピペ) で DL できるので問題ないと思います。

個人的にわからなかったのは dataset にファイルを api 経由でアップロードすることだったので少し解説します。

例えば、fuga.txt というファイルをアップロードしたいケース。

まず、hoge/fuga.txt となるように hoge というディレクトリを用意します。

そして、

kaggle datasets init -p hoge

を実行すると hoge/dataset-metadata.json というファイルが生成されます。

以下のようになっているので、

{
 "title": "INSERT_TITLE_HERE",
 "id": "kaerunantoka/INSERT_SLUG_HERE",
 "licenses": [
  {
   "name": "CCO-1.0"
  }
 ]
}

このように、 title と id を任意の文字列で書き換えます。

{
 "title": "hogefuga", # The dataset title must be between 6 and 50 character
 "id": "kaerunantoka/hogefuga",
 "licenses": [
  {
   "name": "CCO-1.0"
  }
 ]
}

title は 6-50 文字の範囲で書かないと怒られます。

最後に、

kaggle datasets create -p hoge

を実行すると、 https://www.kaggle.com/username/(入力した title)

のパスに fuga.txt がアップロードされます。

( i i ) まとめ

画像コンペやテキストコンペなどのコードコンペで何かとデータセットのアップロードの必要性が多い今日この頃、command でデータのアップロードができると非常にストレスフリーなので活用していただけると幸いです。

( i i i ) 参考

・導入部分(日本語)

http://www.currypurin.com/entry/2018/kaggle-api

・dataset のアップロード関連 (英語)

https://codelabs.developers.google.com/codelabs/upload-update-data-kaggle-api/index.html?index=..%2F..index#0

・本家本元 (英語)

https://github.com/Kaggle/kaggle-api/tree/84895cea61af708a24f0e8ad8307d570e82d8097

2019 年を振り返る

ごきげんよう。かえるるるです🐸

もう2019年もあと数日で終了しますので、いくつかの観点で振り返ってみようと思います。2019年は私の人生にとって激動の一年と言っても過言ではありません。それくらい濃い一年でした。


1. kaggle

私の2019年を語る上で kaggle は欠かせません。後述します転職も kaggle での活動がなければ業務未経験・情報系の学位なしで現職に就くことはできなかったでしょう。

こちらは、コンペ、kaggle meetup tokyo、kaggle days tokyo の三つの観点で振り返っていきます。

i ) コンペ

solo :

・ ~2月 Quora Insincere Questions Classification 🥈

・ Jigsaw Unintended Bias in Toxicity Classification 347/3165

・ APTOS 2019 Blindness Detection 1199/2931

・ Severstal: Steel Defect Detection 554/2431

team :

・~4月 PetFinder.my Adoption Prediction 🏅

・~6月 Freesound Audio Tagging 2019 🏅

・12月 Kaggle Days Tokyo

ソロでの戦績は2月に quora コンペでソロ銀メダルを何とか獲得できたこと以外ではパッとしないです。

Jigsaw コンペは kaggle NLP コンペにおいてBERT が実質デファクトスタンダードになった印象的なコンペでした。

過去コンペなどから従来手法などを漁り、最後まで頑張ったのですが、戦略ミスでした。(適切な loss関数を設定した single BERT-Learge で銀メダルスコアが出せたとか)

眼球コンペと鉄コンペは画像コンペでした。 眼球コンペは分類タスクでSE-ResNeXt や Efficient Net でごにょごにょする augmentation をごにょごにょするなど簡単に実験できそうなことをやって満足して離脱。鉄コンペは segmentation タスクということで segmentation_models.pytorch やスターターカーネルをいじってそれっぽいスコアを出せたものの、そもそものsegmentation の基礎的な部分の知識が抜け落ちていたため細かいチューニングなどができず離脱。この辺りから kaggle 上の計算資源が絞られてきたので、もっと色々実験しておくんだったなあと後悔しています。

一方チーム戦においては、3つのコンペにおいて最高のメンバーで参加することができました。

ソロ参加だと業務ガーとかスコアが思うように伸びずモチベーションマネジメント的な側面で途中離脱しがちですが、チームだと時間を捻出して kaggle に取り組めました。

チーム戦だと強いチームメイトの取り組み方やコード資産を参照できたり、自分の理解の足りていないところに対してアドバイスをもらえたりするなど学びが多いです。

あとシンプルに楽しいです。pet コンペのチームメイト[wodori メンバー]とは忘年会もしましたし、チームの時に使用していた slack は未だに活発で huggingface が資金調達したよねみたいな tech 系の雑談をしたりして盛り上がっていますし、freesound のチームメイトとも、決起会をしたり、打ち上げに焼肉を食べたりなどしました。

i i ) kaggle meetup tokyo

2018年12月の会に聴講者として初参加し、2019年の会では wodori での winner solution を kaggler の皆さんの前でお話しするという機会をいただきました。自分のパートは当時の自分の知識でできることを全部やったぞくらいのものでしたが、団体ではありますがイベント登壇としては初めての経験をさせてもらえました。

i i i ) kaggle days tokyo

東京で公式の全世界 kaggle オフ会が開催されるということで2日とも参加いたしました。 1日目のプレゼンテーションパート、2日めのコンペティションパートいずれも楽しかったです。

一方で反省が一つ。もっと、海外からの参加者の方とコミュニケーションをとるべきでした。共通の趣味[kaggle]があり、わざわざ日本にいらっしゃる方ということなのでこれほど会話がしやすいシチュエーションもないでしょう。他の勉強会でもそうですが、新しい人に話しかけるというのを心がけていきたいなと思いました。


2. 転職

2019年4月より、自然言語処理に力を入れており、日本語版の ELMo やBERT のモデルの公開などもしている会社であるストックマークというところにお世話になっています。

2017年に証券会社に営業職として新卒入社して以来すでに 3社目です。

・~8月まで

BERT の初期検証や現在 toB 向けに展開しているサービスの新機能の基礎検証

・9月~11月くらいにかけて

新規リリースするプロダクトに配置換えになり、バッチ処理や、Django を用いたオンラインでの機械学習処理や、言語解析用の社内ツールの実装

・最近

新プロダクトにお客さんが付いてきて、顧客からいただいたテキストデータを解析して分類モデルを作ったり、それをプロダクト運用できるところまで作り込んだり、実際にお客さんのところに足を運び、生の声を聞く

などを実務でやってきました。

特に担当するサービスが変わってからは、kaggle で身につけた実装を知識をうまく活用できている感じがあるほか、いい感じのモデルできました!!・・だけではダメで、自分でサービスに組み込んでお客さんが扱えるところまでの作り込みまでやっており、ソフトウェアエンジニアとして広く成長できている実感もあり、業務面での満足度は高いです。一方でサービスローンチが近いこともあり、仕事が忙しく、平日にkaggle に割ける時間が少なかったりするのは%&@#$

入社して1年もせずにフロントエンド以外は概ねできるようになったのはベンチャーガリガリやれているおかげでしょうか。


3. まとめ

他にも寿司や spaggle やブログや登壇などの観点でも書けそうでしたがとりあえずここまで。

2019年は大きな飛躍の年でしたが、これも先人たちが引いてくれた高速道路のおかげです。

2020年はこれまで斜め読みしていた数式や理論面の補強、オフラインコンペでも後悔せずに力を出し切れるようなパイプラインの整備、ソロゴールドを目標に地に足つけてやっていきたいです。

以上、かえるるるでした。

BERT の事前学習タスク NSP と SOP の精度差を日本語の公開コーパスを用いて簡単に検証した。

こんにちはかえるるるです。

【この記事は自然言語処理アドベントカレンダー2019の 13日目の記事です。(https://qiita.com/advent-calendar/2019/nlp)】

2018年に BERT が出現して以来、今日まで BERT, XLNet, RoBERTa, ALBERT, T5 と Transformer ベースのモデルが精度の面で話題になりました。

その精度を担保しているポイントになっているのは事前学習という手法で、そこで学習したパラメータを使用すると汎用的なタスクにおいて高い精度をだせるようになっていることが知られています。

そこで今回は、BERT で使用された NextSentencePrediction と ALBERTの論文 でその代替として紹介された。 SentenceOrderPrediction (SOP) について簡単な検証を行いました。


もくじ

  1. 検証のモチベーション

  2. 検証の結果

  3. 検証の方法

  4. 感想

  5. まとめ


1 検証のモチベーション

ALBERT の論文では BERT のような Transformer ベースのモデルにとって、NSP はタスクとして簡単であったため、そのタスクを内包した上位互換タスクと言え る SOP を事前学習タスクとして解かせた結果、一定の精度向上が得られたと紹介されていました。タスクについての概要を読んだところ頭ではなるほど納得いき ましたが、実際に自分で手を動かしてみようと思った次第です。


2 検証の結果

val acc
NSP 0.972
SOP 0.964

微小ながら、 SOP の方が低い精度でした。


3 検証の方法

(本当はランダムな分散表現を使用し、スクラッチの Transformer モデルを使用するつもりでしたが、見知らぬバグを踏みまくって BertModel を継承したモデルを使用しました。)


4 感想

・BERT の事前学習データセットの作成方法がわかった

・(完装できなかったが、) BERT と Transformer の再現実装を試み、内容の理解が深まった。

・自力実装も継続して取り組んでいきたい


5 まとめ

・ 論文の主張の通り, 解きやすさは SOP > NSP だった。

・ 次は Transformer-Base のモデルを自力実装してこのタスクに取り組みたい。


以上です。ありがとうございました。

kaggle の discussion の upvote downvote 予測をしてみた

おはようございます、かえるるるです。

【 こちらは 「kaggle Advent Calendar 2019」 の6日目の記事です https://qiita.com/advent-calendar/2019/kaggle

皆さま、楽しい kaggle life を送れておりますでしょうか。

kaggle には Competition tier の他にもコミュニティへの貢献度に応じて Notebook, Discussion, Dataset でもメダルを獲得し、その数に応じた tier を獲得することができる制度があり、コンペの精度を競う以外の楽しみ方もできます。

特に discussion はコンペのスコア向上に有益な情報が多く共有されています。

kaggle を始めてそろそろ 1年、もらうばかりではなく自分も何かしらの情報を公開して、コミュニティ貢献したいところです。

しかしながら、豆腐メンタルの私が不得意ながらひねり出した英語力で綴った内容のポストに downvote がついてしまった暁には、悲しみのあまり、夜しか眠れなくなってしまいかねません。

そこで、前置きが長くなってしまいましたが、kaggle の discussion の upvote downvote 予測をしてみたという内容です。 高い精度で予測できれば、ポストする前に自作の model に予測させて、downvote 判定がつきそうなら言い回しを変えるなどの対策が取れるようになり、高い心理的な安全性を得ることができます。

参考までに、こちらの記事に先立って、先日行われた分析コンペLT会 で同じデータセットを使用し、別のタスクを解いてみたのですが、text のみではうまく学習させることができませんでした。(accuracy 51% ほど)

単語自体の意味というよりは、usefil links や数式が多かったりなどの情報や、誰が投稿したかや、そもそもそのスレッドが盛り上がっているかなどの要素が影響しそうだという示唆が得られましたので、本ポストでは、そのような特徴量も使用してみて text のみの場合とどのように差がつくか、見ていこうと思います。

もくじ

  1. 分析コンペ LT 会での LT発表の紹介 -- play with kaggle discussion's text data

  2. 実験の結果

  3. 実験の概要

 ・使用したデータセットについて

 ・前処理

 ・行った特徴量エンジニアリングについて

 ・学習について

  ・BERT -- only text feature

  ・Logistic Regression -- text and meta feature

  1. 考察

  2. まとめ

分析コンペ LT 会での LT発表の紹介

こちらに貼っておりますのでご覧になって(可能であればツッコミを入れて)いただけると嬉しいです :)

www.slideshare.net

実験の結果

まず今回行った実験の結果を示します

\ BERT Logistic Regression
val accuracy 0.5044483985765125 0.6040925266903915
test accuracy 0.5081850533807829 0.599288256227758
特徴量 テキストのみ テキスト + 特徴量エンジニアリング

今回私が行った実験では、テキストのみを入力とした BERT に 10ポイントもの大差をつけて、特徴量エンジニアリングを行いテキストとテキスト以外の特徴量も入力に使用した Logistic Regression のモデルがより upvote と downvote を予測できるという結果になりました。

実験の概要

i ) 使用したデータ

Meta kaggle で提供されているデータを使用しました。 こちらは、kaggle 運営が毎日アップデートしているデータセット群で、我々 kaggler にとっては馴染みの深い、コンペのメタ情報や、discussion の内容、誰から誰への発言かなどのデータがあります。

今回使用する discussion に関連するデータは forum ~ .csv という名前で格納されています。 複数テーブルに分割されてあったので、私の公開カーネル で結合処理 (カラムの説明などないので key は当てずっぽうなので変なところがあるかもです。。) を施して kaggle_discussion_df.csv として使用しました。

以下のようなイメージです。

カラム名 概要 使えそう感(実験前)
CompetitionsTitle タイトル -
ForumTitle タイトル 有力な情報がありそうだったら閲覧が増えそう
ForumTopicTitle タイトル 有力な情報がありそうだったら閲覧が増えそう
PostUserId ポストした人 GM はいいこと言いそう
Message コメント 重要
TotalViews 閲覧数 見てる人が多いと。。
Score upvote と downvote の差 今回の目的変数の作成に使用
Medal メダル スコアを加工して目的変数を作成するのでリーク対策で削除
TotalMessages 全メッセージ数(コンペごと) 結合処理間違えたかも
TotalReplies ぶら下がっている投稿数 有力
CompetitionTypeId 不明, 画像コンペ, NLP, tabular と思ったら違った
OnlyAllowKernelSubmissions コードコンペか否か 興味本位で採用
EvaluationAlgorithmAbbreviation 評価指標 複雑なメトリックだと 有益な書き込み増えそう
TotalTeams 全参加チーム数 人が多いと書き込みも増える
TotalCompetitors 全参加者数 同上
TotalSubmissions 全サブ数 同上

i i ) 目的変数とデータセット情報

・データ総数は 208569 件

df = df[df['Score']!=0].reset_index(drop=True)
df['up_or_down'] = np.where(df['Score'] > 0, 1, 0)

down = df[df['up_or_down']==0]
up = df[df['up_or_down']==1].sample(3500, random_state=SEED)

train = pd.concat([down, up], sort=True)

・score 0 が全データの 22% ほどあったので削除し、スコアがついたものだけを使用 ・そのうちマイナスのスコアがついたものが 3500件しかなかったため、正例はランダムダウンサンプルし、 合計 7025 件のデータで検証. ・train : val : test = 4496 : 1124 : 1405 で分割. (stratify=y)

i i i ) 前処理

・'< p >', '< / p >' などのタグ

・' & lt ;', '& gt ;' などの特殊文字

・'span', 'style', 'href' など正規表現芸の力不足ゆえ残ってしまったものたち

・"i'd": 'i would' の変換

・url の削除

を行いました。

i v ) 特徴量エンジニアリング

def make_features(df):
    df = df.fillna(" ")
    # 文字数や単語数 (情報量が多い) 方が upvote 多い?
    df['num_chars'] = df["Message"].apply(len)
    df['num_unique_words'] = df["Message"].apply(lambda comment: len(set(w for w in comment.split())))
    df['num_words'] = df["Message"].apply(lambda x: len(x.split()))
    # CV スコアや LB スコアを載せてるものは upvote が多い?
    df['num_num'] = df["Message"].progress_apply(lambda x: sum(x.count(w) for w in '1234567890'))
    for keyword in ['lb', 'cv', 'cross validation', 'fold']:
        df[keyword + '_num_words'] = df["Message"].progress_apply(lambda x: x.count(keyword))

    # useful links のような url を多く載せてるものも upvote が多い?
    df['num_urls'] = df["Message"].progress_apply(lambda x: get_url_num(x))
    return df

思いついた特徴量をえいやで実装しました。

また、Logistic Regression にテキスト特徴を投入する際、kaggle の NLP コンペのスターターコードによく見られるような、 Tfidf --> SVD で 200次元に圧縮したものを使用しました。

v ) 学習について

・BERT --> BERT learge, 最大系列長 220, BCEWithLogitsLoss

・Logistic Regression --> sklearn.linear_model のものでハイパーパラメータは公開カーネルから拝借してきました

考察

分析コンペLT 会でも考察したように単語の意味のみでは discussion に upvote がつくか downvote がつくかを判別するのが困難であるらしい。

一方で url を含むかどうかやテキストの長さ、また私がコンペに参加する中でキーワードっぽいと決め打った単語の含有数といった簡単な特徴を加えるだけでシンプルな LR が BERT よりも良い精度で予測できるようになりました。

以下は LR の coef です。

こちらのブログ によると、

ただ、プラスであればその特徴量は1の方向に、反対にマイナスなら0の方向に働く。

という事なので、私のお手製の温かみのある特徴量たちは、モデルが 0 (downvote) を予測するのに役に立ったようです。

f:id:kaeru_nantoka_py:20191206042801p:plain

まとめ

・discussion を見ていてもなぜ downvote がついているのだろうと思うような投稿もあるように、単語の持つ意味のみで vote の種類を予測するのは sota sota の実を食べた sota 人間の BERT でも困難なようです。

・一方で我々の日頃の感覚を反映した特徴量エンジニアリングで生み出した特徴量たちは見事にいい感じに決定境界を引くのに役立ってくれました。

・とはいえ精度 60 % では自信を持って upvote がつきそうな投稿に絞ってポストするといったことはできなさそうです。無念。

以上です、ありがとうございました🐸

[追記] ポスト公開後、Twitter などを通して、 「それ Leak してませんか??」のお声をいくつかいただきました。 具体的には、TotalViews, TotalMessages, TotalReplies などは、投稿をするタイミングでは得られない情報であり、今回の目的である 「投稿する時点で予測器にかけてポストするか否か判断する」には適さない処理でした。 今後ともどうぞよろしくお願いいたします。

PyTorch lightening で Titanic 問題解いてみた。

こんばんは、kaerururu (@kaeru_nantoka) です。

今回は、PyTorch lightening を使ってみたいと思ったので 以前公開したカーネル をベースに PyTorch lightening に書き換えたものを公開したのでその紹介をします。


目次


本ポストのモチベーション

公式のGitHub中の人の medium 記事 が充実しており、学習まではつまづくようなところはないと思います。

私も、学習までは既存の自分のコードをうまくテンプレートにはめ込むだけで行けました。しかし、本家 Tutorial 記事をはじめとした記事には推論部分まで書いておらず 「学習はできたけど predict できないぞ。。というか重みどこ?」とひとり沼にはまりそして無事自己解決したので、まとめようと思った次第です。


ソースコード

kaggle の 公開カーネルとして公開しております。 (https://www.kaggle.com/kaerunantoka/pytorch-lightening-for-titanic?scriptVersionId=19673008)

スコアが低いのは、パラメータ調整していないのと前処理を雑にしているからだと思います。もうちょっと行けるはず。。


ポイント

Entity Embedding とは、カテゴリ変数を同じベクトル空間上の座標にマッピングする手法でラベルエンコーディングよりカテゴリ変数同士の意味的な関係も表現できます。

[実装と説明]

cat_dims = [int(train[col].nunique()) for col in categorical_features]
emb_dims = [(x, min(50, (x + 1) // 2)) for x in cat_dims]

# def __init__() 内
# Embedding layers
self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
                                     for x, y in emb_dims])

# def forward() 内
if self.no_of_embs != 0:
   x = [emb_layer(cat_data[:, i])
                  for i,emb_layer in enumerate(self.emb_layers)]
   x = torch.cat(x, 1)
   x = self.emb_dropout_layer(x)

上のように、nn.Embedding を利用し、カテゴリ変数を埋め込みます埋め込みの次元はカテゴリ変数のユニークな値の数の半分か、多すぎる場合は 50 に固定しています。

  • Predict

もう少しスッキリかけるかもしれないですが、とりあえず重みの保存場所とロード方法を docs のこのページ で確認することができたので、推論処理は普段私が使用している predict 関数を使用しました。

[実装と説明]

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filepath='/path/to/store/weights.ckpt',
    save_best_only=True,
    verbose=True,
    monitor='val_loss',
    mode='min'
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)

このように Trainer 関数に checkpoint_callback を渡してあげると、

!ls ../input/weights.ckpt/

# _ckpt_epoch_5.ckpt 

のように、checkpoint_callback で指定した filepath に checkpoint が保存されます。

こちらを、

checkpoint = torch.load('../input/weights.ckpt/_ckpt_epoch_5.ckpt')

model.load_state_dict(checkpoint['state_dict'])

のようにロードし、その中の 'state_dict' を load_state_dict の引数に渡せば OK です。


まとめ

  • PyTorch lightening を使用してテーブルデーNLP の実装を、触れ込み通りスッキリ書くことができた。

  • 重みの保存場所、ロードの仕方がわかって推論もできた。

  • カテゴリ変数の前処理の仕方を少し変えた。

これをきっかけに PyTorch lightening を使って画像コンペ用のオレオレコードを整備していこうかなと思っています。

以上、ありがとうございました!