画像の半教師あり学習について整理した
概要
勉強会で画像の半教師あり学習について取り上げられるたびに、あれ、これ似たやつなかったっけ?と混乱するので、整理してみました。同じような内容のネット記事や資料はありますが、自分のために記載します。
半教師あり学習とは
半教師あり学習 (Semi-supervised learning: SSL) とはラベル付きデータとラベルなしデータで学習を行う方法。ラベルなしデータを活用してモデルのパフォーマンスをあげます。
MixMatch
MixMatchという半教師あり学習のアルゴリズムについて記載します。画像の分類問題を想定しています。
- ラベル付きデータ に対しData Augmentationで変換したデータ を作る
- ラベルなしデータ に対し 種類のData Augmentationで変換したデータ を作る
- ラベルなしデータで予測値 (モデルの出力値) を取得し、予測値平均 を計算する. ( 1 サンプル 個の予測値がでるので )
- Sharpen関数を用いて、予測値 の分布の温度を下げた値 を取得する.
・ は温度(ハイパーパラメータ)
- を推測ラベルとしたデータ とラベル付きデータ を合わせて のデータセットを作る
- と のMixUpで を作る
- と のMixUpで を作る
- のデータは、クロスエントロピー誤差を Loss関数 とする
- のデータは、平均二乗誤差を Loss関数 とする
- モデル全体の誤差は、として学習する ( はハイパーパラメータ)
[論文] MixMatch: A Holistic Approach to Semi-Supervised Learning
[1905.02249] MixMatch: A Holistic Approach to Semi-Supervised Learning
MixUp
MixMatchで使われるMixUpにいついて記載します。
Data Augmentation の一種です。
- トレーニングデータからランダムに2つのサンプル、データとラベルを取り出す
- データもラベルも以下のように混ぜて新しいデータを作成する
( はベータ分布からサンプリングした値)
[論文] mixup: Beyond Empirical Risk Minimization
[1710.09412] mixup: Beyond Empirical Risk Minimization
ReMixMatch
ReMixMatchは、MixMatchをさらに2つの新しいテクニックで改良した、半教師あり学習の方法です。
Distribution Alignment
1 つ目のテクニックは、ラベルなしデータの推論ラベル (モデルの出力値) の分布を調整する Distribution Alignment です.
MixMatchの中で、ラベルなしデータの推論ラベル にSharpen関数を用いる処理がありますが、その直前に以下の式で推論ラベル分布をデータセットのラベルの分布と同じになるように正規化します。
下の図と対応させると、 がラベルなしデータの推測ラベル (label guess) で、 が真のラベル (Ground-Truth labels) で、 がラベルなしデータの推測ラベルの移動平均(Model predictions). 移動平均は直前の128バッチの推測ラベル.
Augmentation Anchoring
2 つ目のテクニックは、Data Augmentationの強さの調整、Augmentation Anchoring です。
推論ラベルを取得するための、モデルのinputとなるラベルなしデータに対しては、弱いData Augmentationをかけます。下の図で言うと緑の部分。 実際に学習で使うデータ(MixUpで使うデータ)は、強い 種類のData Augmentationで変換したデータです。下の図で言うと青の部分。
また、この強いData Augmentationには、CTAugmentという手法を使います。
CTAugment
Data Augmentation の変換種類はランダムで決めますが、その強さ (変換用パラメータの値) が学習中に動的に調整されます。
- Data Augmentationにおける変換用パラメータの値をそれぞれ n 個にビニングします。
- ビニングされたパラメータの値に対応する各 weight をベクトル として表します
- の weight は学習前に 1 で初期化されます
- Data Augmentationのタイミングで、2種類の変換が選ばれます。変換用パラメータの値 はこの を使って選ばれるのですが、weight が0.8を下回ったものに対応する変換用パラメータの値は使われません。それ以外のweightがカテゴリカル分布に変換されて、選ばれます。
- weightの更新の方法は、モデルの予測とラベルがどの程度一致するかを以下の式で表し、その一致度を使って更新します。
一致度
更新式 ( = 0.99 )
[論文] ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring:
[1911.09785] ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring
FixMatch
FixMatchはPseudo-LabelとConsistency Regularizationを使った半教師あり学習です。
Pseudo-Label
弱いData Augmentationで変換した画像をモデルのinputにして、出力の中で確信度の一番高いラベルでハードラベリングし、Pseudo-Label (疑似ラベル)とします。
Consistency Regularization
Consistency Regularizationとはラベルなしデータの画像にノイズを加えても、モデルの出力値が変わらないようにする方法です。一般的には以下のような項をLoss関数に付け加えます。
FixMatchでは、Pseudo-Label と強い Data Augmentation をかけたデータのモデル出力値を使ってConsistency Regularization を行います。
全体的なアルゴリズムは以下のとおりです。
- ラベルありデータのLoss関数 は通常のクロスエントロピーLoss関数 です。
・ は弱い Data Augmentation
・ はラベル
- ラベルありデータのLoss関数 を以下のようにします。
・ は強いData Augmentation
・ は上記で記載した方法で決定された疑似ラベル
・ は閾値
- 全体のLoss関数は となります。
[論文] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
[2001.07685] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
その他
半教師あり学習で気になっていた関連事項について少し調べました。
VAT
VAT(Virtual Adversarial Training)とは、ラベル分布を滑らかにすることによって、半教師あり学習で正則化を行う方法です。
- ラベルなしデータ のモデルの出力値 (推測ラベル)と、そのデータに摂動を加えたデータのモデルの出力値 の 2 つの分布の差異を以下のように表し、正則化項とします。
・ は非負の差異の値を返す関数でクロスエントロピーなどです
・ 摂動 は2つの分布の差異がもっとも大きくなる摂動です。(ただし で はハイパラ )
- 上記の を使ってLoss関数を以下のように定めます。
・ はラベルつきデータのクロスエントロピーLoss。
・ はそれぞれ、ラベル付きデータ数とラベルなしデータ数
この図は、半教師あり学習でVATを用いた時のモデルの予測値(上段)と、LDSの値 (下段)のシミュレーションです。緑とピンクがそれぞれのラベルの値で、グレーがラベルが付いていないデータです。
境界部分が徐々に良くなっているのがわかります。
[論文] Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
[1704.03976] Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
UDA
UDA (Unsupervised Data Augmentation) は、ラベルなしデータに Data Augmentation をかけたデータ のモデル出力値 と、変換しなかったラベルなしデータ のモデル出力値 をなるべく同じにして学習する方法です。FixMatchのアイディアの元になった手法です。
最終的な目的関数は以下の通りで、図のSupervised Cross-entropy Loss が 1 項目で、 図の Unsupervised Consistency Loss が 2 項目です。
[論文] Unsupervised Data Augmentation for Consistency Training
[1904.12848] Unsupervised Data Augmentation for Consistency Training
RandAugment
RandAugmentとは、Data Augmentationの 1 つで自動で変換をかける手法です。
計算コストが低いことが特徴です。
- Data Augmentationの種類は以下のK(=14)種類からランダムに選ばれます。
- 強さの探索に関しては、パラメータの値をそれぞれ の整数にスケーリングしておき、全てData Augmentationの変換で同じ値(スケーリング後の値が同じ)を使います。値の決め方は、ランダム・固定・線形増加・上限が増加していくランダムサンプリング、という4種類で実験し、どれも精度に大差なしとのこと。
[論文] RandAugment: Practical automated data augmentation with a reduced search space
[1909.13719] RandAugment: Practical automated data augmentation with a reduced search space