画像の半教師あり学習について整理した

概要

勉強会で画像の半教師あり学習について取り上げられるたびに、あれ、これ似たやつなかったっけ?と混乱するので、整理してみました。同じような内容のネット記事や資料はありますが、自分のために記載します。

教師あり学習とは

教師あり学習 (Semi-supervised learning: SSL) とはラベル付きデータとラベルなしデータで学習を行う方法。ラベルなしデータを活用してモデルのパフォーマンスをあげます。

MixMatch

MixMatchという半教師あり学習アルゴリズムについて記載します。画像の分類問題を想定しています。

f:id:YukoIshizaki:20200504151303p:plain:w500

  1. ラベル付きデータ    \mathcal{X} に対しData Augmentationで変換したデータ   \hat{\mathcal{x}} \in   \hat{\mathcal{X}} を作る
  2. ラベルなしデータ   \mathcal{U} に対し K 種類のData Augmentationで変換したデータ   \hat{\mathcal{u}} \in   \hat{\mathcal{U}} を作る
  3. ラベルなしデータで予測値 (モデルの出力値) を取得し、予測値平均   \overline{q} を計算する. ( 1 サンプル  K 個の予測値がでるので )
  4. Sharpen関数を用いて、予測値   \overline{q} の分布の温度を下げた値  {q} を取得する.
     Sharpen(\overline{q},T)=\frac{\displaystyle {\overline{q}}_i^{\frac{1}{T}}}{ \sum_{j=1}{{\overline{q}}_j^{\frac{1}{T}}} }

     Tは温度(ハイパーパラメータ)
     
  5.  {q} を推測ラベルとしたデータ  \hat{\mathcal{U}} とラベル付きデータ  \hat{\mathcal{X}} を合わせて  {\mathcal{W}} のデータセットを作る
  6.  \hat{\mathcal{X}} {\mathcal{W}} のMixUpで  {\mathcal{X'}} を作る
  7.  \hat{\mathcal{U}} {\mathcal{W}} のMixUpで  {\mathcal{U'}} を作る
  8.  {\mathcal{X'}} のデータは、クロスエントロピー誤差を Loss関数  {\mathcal{L}}_{\mathcal{X}}とする
  9.  {\mathcal{U'}} のデータは、平均二乗誤差を Loss関数  {\mathcal{L}}_{\mathcal{U}}とする
  10. モデル全体の誤差は、 {\mathcal{L}} = {\mathcal{L}}_{\mathcal{X}} +{\mathcal \lambda_{\mathcal{U}}} {\mathcal{L}}_{\mathcal{U}}として学習する ( {\mathcal \lambda_{\mathcal{U}}} はハイパーパラメータ)

f:id:YukoIshizaki:20200504160552p:plain:w500

[論文] MixMatch: A Holistic Approach to Semi-Supervised Learning
[1905.02249] MixMatch: A Holistic Approach to Semi-Supervised Learning

MixUp

MixMatchで使われるMixUpにいついて記載します。
Data Augmentation の一種です。

  1. レーニングデータからランダムに2つのサンプル、データとラベルを取り出す  (X_i, y_i), (X_j, y_j)
  2. データもラベルも以下のように混ぜて新しいデータを作成する
     \tilde{X} = \lambda X_i + (1 - \lambda) x_j
     \tilde{y} = \lambda y_i + (1 - \lambda) y_j
    (   \lambda はベータ分布からサンプリングした値)

[論文] mixup: Beyond Empirical Risk Minimization
[1710.09412] mixup: Beyond Empirical Risk Minimization

ReMixMatch

ReMixMatchは、MixMatchをさらに2つの新しいテクニックで改良した、半教師あり学習の方法です。

Distribution Alignment

1 つ目のテクニックは、ラベルなしデータの推論ラベル (モデルの出力値) の分布を調整する Distribution Alignment です.

MixMatchの中で、ラベルなしデータの推論ラベル   \overline{q} にSharpen関数を用いる処理がありますが、その直前に以下の式で推論ラベル分布をデータセットのラベルの分布と同じになるように正規化します。

 \tilde{q} = Normalize  \left( \frac{\displaystyle q \times p(y)}{ \displaystyle \tilde{p}(y)}  \right)

 Normalize(x) = \displaystyle \frac{x_i}{\displaystyle \sum_{j}{x_j}}

下の図と対応させると、 q がラベルなしデータの推測ラベル (label guess) で、 {p}(y) が真のラベル (Ground-Truth labels) で、 \tilde{p}(y) がラベルなしデータの推測ラベルの移動平均(Model predictions). 移動平均は直前の128バッチの推測ラベル.

f:id:YukoIshizaki:20200504230334p:plain:w350

Augmentation Anchoring

2 つ目のテクニックは、Data Augmentationの強さの調整、Augmentation Anchoring です。

推論ラベルを取得するための、モデルのinputとなるラベルなしデータに対しては、弱いData Augmentationをかけます。下の図で言うと緑の部分。 実際に学習で使うデータ(MixUpで使うデータ)は、強い K 種類のData Augmentationで変換したデータです。下の図で言うと青の部分。

f:id:YukoIshizaki:20200504233748p:plain:w350

また、この強いData Augmentationには、CTAugmentという手法を使います。

CTAugment

Data Augmentation の変換種類はランダムで決めますが、その強さ (変換用パラメータの値) が学習中に動的に調整されます。

  1. Data Augmentationにおける変換用パラメータの値をそれぞれ n 個にビニングします。
  2. ビニングされたパラメータの値に対応する各 weight をベクトル  m として表します
  3.  m の weight は学習前に 1 で初期化されます
  4. Data Augmentationのタイミングで、2種類の変換が選ばれます。変換用パラメータの値 はこの m を使って選ばれるのですが、weight が0.8を下回ったものに対応する変換用パラメータの値は使われません。それ以外のweightがカテゴリカル分布に変換されて、選ばれます。
  5. weightの更新の方法は、モデルの予測とラベルがどの程度一致するかを以下の式で表し、その一致度を使って更新します。
    一致度  \omega = 1 -  \frac{1}{2L}  \sum{|prediction - label|}
    更新式  m_i = \rho m_i  + (1- \rho) \omega  (  \rho = 0.99 )


[論文] ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring:
[1911.09785] ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring

FixMatch

FixMatchはPseudo-LabelとConsistency Regularizationを使った半教師あり学習です。

Pseudo-Label

弱いData Augmentationで変換した画像をモデルのinputにして、出力の中で確信度の一番高いラベルでハードラベリングし、Pseudo-Label  \hat{q_b} (疑似ラベル)とします。

Consistency Regularization

Consistency Regularizationとはラベルなしデータの画像にノイズを加えても、モデルの出力値が変わらないようにする方法です。一般的には以下のような項をLoss関数に付け加えます。

 ||P_{model} (y, Augment(x); \theta )- P_{model} (y, Augment(x) ; \theta )||_2^2

FixMatchでは、Pseudo-Label と強い Data Augmentation をかけたデータのモデル出力値を使ってConsistency Regularization を行います。

f:id:YukoIshizaki:20200505111030p:plain:w500
全体的なアルゴリズムは以下のとおりです。

  1. ラベルありデータのLoss関数  {\mathscr{l}}_s は通常のクロスエントロピーLoss関数  H です。

     {\mathscr{l}}_s =   \displaystyle\frac{1}{B} \displaystyle \sum_{b=1}^{B}{H(p_b, p_{model}(y| \alpha(x_b) ) )}

      \alpha は弱い Data Augmentation
     p_b はラベル
     
  2. ラベルありデータのLoss関数  {\mathscr{l}}_uを以下のようにします。

      {\mathscr{l}}_u =  \displaystyle\frac{1}{\mu B}  \sum_{b=1}^{mu B}\mathbb{1} \left( max(q_b) \geq  \tau \right) {H(q_b, p_{model}( \hat{q_b}| {\mathcal{A}} ({\mathcal{u}}_b) ) )}

     {\mathcal{A}} は強いData Augmentation
     \hat{q_b}は上記で記載した方法で決定された疑似ラベル
     \tau閾値
     
  3. 全体のLoss関数は  {\mathscr{l}}_s+ \lambda_{u}{\mathscr{l}}_u となります。

[論文] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
[2001.07685] FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

その他

教師あり学習で気になっていた関連事項について少し調べました。

VAT

VAT(Virtual Adversarial Training)とは、ラベル分布を滑らかにすることによって、半教師あり学習正則化を行う方法です。

  1. ラベルなしデータ x のモデルの出力値 (推測ラベル)と、そのデータに摂動 rを加えたデータのモデルの出力値 の 2 つの分布の差異を以下のように表し、正則化項とします。

     LDS = D \bigl[  p(y|x;  \hat{\theta}), p(y|x + r_{qadv};  \theta) {\bigr]}

    D は非負の差異の値を返す関数でクロスエントロピーなどです
    ・ 摂動  r_{qadv} は2つの分布の差異がもっとも大きくなる摂動です。(ただし  ||r||_2 \leq \epsilon \epsilon はハイパラ )
     
  2. 上記の  LDS を使ってLoss関数を以下のように定めます。

     {\mathscr{l}}(D_l, \theta) + \alpha  \displaystyle \frac{1}{N_l + N_{ul}}  \displaystyle \sum_{x_* \in D_l, D_{ul}} LDS(x*, \theta)

     {\mathscr{l}}(D_l, \theta)はラベルつきデータのクロスエントロピーLoss。
     N_l, N_{ul} はそれぞれ、ラベル付きデータ数とラベルなしデータ数


f:id:YukoIshizaki:20200505151401p:plain:w500

この図は、半教師あり学習でVATを用いた時のモデルの予測値(上段)と、LDSの値 (下段)のシミュレーションです。緑とピンクがそれぞれのラベルの値で、グレーがラベルが付いていないデータです。
境界部分が徐々に良くなっているのがわかります。

[論文] Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
[1704.03976] Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning

UDA

UDA (Unsupervised Data Augmentation) は、ラベルなしデータに Data Augmentation をかけたデータ  \hat{x} のモデル出力値  p_{\theta}(y|{\hat{x}}) と、変換しなかったラベルなしデータ  x のモデル出力値  p_{\tilde{\theta}}(y|{x}) をなるべく同じにして学習する方法です。FixMatchのアイディアの元になった手法です。


f:id:YukoIshizaki:20200505220733p:plain:w550

最終的な目的関数は以下の通りで、図のSupervised Cross-entropy Loss が 1 項目で、 図の Unsupervised Consistency Loss が 2 項目です。

 \min_{\theta} {\mathcal{J}}(\theta) =   \mathbb{E}_{x, y^{*} \in L}   {\bigl[}  - \log p_{\theta} (y^{*} \mid x) {\bigr]} + \lambda \mathbb{E} _{x \in U} \mathbb{E} _{\hat{x} \sim q(\hat{x} \mid x) }  {\bigl[}   {\mathcal{D}}_{KL}  ( p_{\tilde{\theta}}(y  \mid  {x}) ||   p_{{\theta}}(y \mid  { \hat{x}}) {\bigr]}

[論文] Unsupervised Data Augmentation for Consistency Training
[1904.12848] Unsupervised Data Augmentation for Consistency Training

RandAugment

RandAugmentとは、Data Augmentationの 1 つで自動で変換をかける手法です。
計算コストが低いことが特徴です。

  1. Data Augmentationの種類は以下のK(=14)種類からランダムに選ばれます。f:id:YukoIshizaki:20200505180522p:plain:w400
  2. 強さの探索に関しては、パラメータの値をそれぞれ  0\sim10 の整数にスケーリングしておき、全てData Augmentationの変換で同じ値(スケーリング後の値が同じ)を使います。値の決め方は、ランダム・固定・線形増加・上限が増加していくランダムサンプリング、という4種類で実験し、どれも精度に大差なしとのこと。

[論文] RandAugment: Practical automated data augmentation with a reduced search space
[1909.13719] RandAugment: Practical automated data augmentation with a reduced search space

終わり

自分なりに半教師あり学習について整理してみました。間違いがあれば、コメントでご指摘いただけたら嬉しいです。
書いてる途中で気づいたのですが、kaggleのコンペではSSLはあまり使われなさそうですね...