Language Interpretability Tool (LIT) の紹介
概要
Google Researchが、言語解釈ツール Language Interpretability Tool (LIT) を紹介する論文を出しました。NLPモデルが期待どおりに動作しない場合に、何が問題かを解明するために役立つツールだと記載されていて、便利そうだと思い試しに動かしてみたので、LITの簡単な紹介を記載します。
インストール
GitHubのreadmeの記載のとおり、condaで環境を作ります。
LITはフロントがTypeScript、バックエンドがPythonで作られています。
git clone https://github.com/PAIR-code/lit.git ~/lit cd ~/lit conda env create -f environment.yml conda activate lit-nlp conda install cudnn cupti # optional, for GPU support conda install -c pytorch pytorch # optional, for PyTorch cd ~/lit/lit_nlp/client yarn && yarn build
LITの起動
インスタンスの起動
インスタンス起動用のスクリプトをPythonで作り、ポート指定で起動します。
LITのGitHubリポジトリにいくつかサンプルが用意されているので、まず quickstart_sst_demo を起動してみました。このサンプルは、はじめにfine tuningが走るので、起動までにGPUで5分ほどかかります。
python -m lit_nlp.examples.quickstart_sst_demo --port=5432 [optional --args]
起動したら、http://localhost:5432 にアクセス。
インスタンス起動用のスクリプト作成
インスタンス起動用のスクリプトを作成するには、LITのDatasetクラスのサブクラス、Modelクラスのサブクラスを作成し、そのインスタンスを設定したdictをdev_server.Serverの引数に設定してあげる必要があります。
from lit_nlp import dev_server models = {'foo': FooModel(...)} datasets = {'baz': BazDataset(...)} server = lit_nlp.dev_server.Server(models, datasets, port=4321) server.serve()
Design Overview
Datasetクラス
- lit_dataset.Datasetのサブクラスとして作成.
- spec 関数を定義します. データセットのカラムに対して、適切なType情報を設定して、dictで返す必要があります.
- self._examplesにデータセットをdictのリストとして設定する必要があります
class BazDataset(lit_dataset.Dataset): LABELS = ['0', '1'] def __init__(self, path, split: str): self._examples = [] for ex in load_tfds(path, split=split): self._examples.append({ 'sentence': ex['sentence'].decode('utf-8'), 'label': self.LABELS[ex['label']], }) def spec(self): return { 'sentence': lit_types.TextSegment(), 'label': lit_types.CategoryLabel(vocab=self.LABELS) }
Modelクラス
- lit_model.Modelのサブクラスとして作成
- input_spec 関数を定義します. インプットのカラムに対して、適切なType情報を設定して、dictで返す必要があります
- output_spec 関数を定義します. モデルの予測と追加情報に対して、適切なType情報を設定して、dictで返す必要があります. 追加情報にはAttention HeadsやEmbeddingsなどが設定でき、LITでのComponentsのインプットになります.
- predict / predict_minibatch 関数を定義します. input_specの記載と同一のTypeの入力データに対し、前処理を施してからモデルで予測した値を返す必要があります.
class FooModel(Model): LABELS = ['0', '1'] def __init__(self, model_path, **kw): self._model = _load_model(model_path, **kw) def predict(self, inputs: List[Input]) -> Iterable[Preds]: examples = [self._model.convert_dict_input(d) for d in inputs] return self._model.predict(examples) def input_spec(self): return { 'sentence': lit_types.TextSegment(), 'label': lit_types.CategoryLabel(vocab=self.LABELS, required=False) } def output_spec(self): return { 'probas': lit_types.MulticlassPreds(vocab=self.LABELS, parent='label'), }
LITの機能
Embeddings
EmbeddingsをPCA・UMAPで次元削減した結果を3Dで描画されるエリアです。
インタラクティブに探索できるので、あるクラスだけで絞り込みして他のサンプルと距離が離れているデータを見つけたりするのに使えそう。予測が正しいものを青・間違っているものを赤にした label 1のデータを可視化してみました。
Prediction Score
モデルの予測が表示されるエリアです。
2値分類タスクにおいて、閾値を変動させると Metricsの値やConfusion Matrixの値も連動して変化します。データを選択して詳細をData Tableで見ることもできます。
マスクされた単語予測のタスクにおいて、どの単語をマスクするとモデルは何を予測するのかが確率と共に表示されます。
Span Labelingなどの構造化予測のタスクにおいて、ラベルと予測されたタグが表示されます。
Explanations
モデルの判断根拠の説明に用いられる手法であるLIMEや勾配を用いたヒートマップや、Attention Headの可視化が表示されたエリアです。Data Tableで選択されたデータが表示されます。
Datapoint Generator
以下のアルゴリズムでデータを新しく生成することができます。
- Scrambler: 単語をランダムに並び替える
- Backtranslation: 逆翻訳
- Word replacer: 文字の置き換え
- Hotflip : 分類タスクにいおいて、予測に最も影響をあたえるトークンを、反対の影響をあたえるトークンに変更
Performance
評価指標、混同行列が表示されるエリアです。
予測値のTypeによって自動で表示される指標は決まりますが、カスタマイズすることも可能です. 複数モデルを起動スクリプトで設定しておくとモデルごとに評価指標が表示されます。
データをlabelで切り分けて評価指標を表示させたり、Data Tableで選択したデータだけの評価指標を表示させることができます。
動画
公式から 3 分間のDemo動画が公開されています.