【Python】制限ボルツマンマシン(RBM)の実装と理論|PyTorchによるGPU化
【Python】制限ボルツマンマシン(RBM)の実装と理論|PyTorchによるGPU化

【Python】制限ボルツマンマシン(RBM)の実装と理論|PyTorchによるGPU化

生成モデルや事前学習で使用される制限ボルツマンマシンの基本的な理論とPythonによる実装を本記事では説明しています。制限ボルツマンマシンの学習は、多くの近似的手法が使用され混乱する方も多いと思いますが、その混乱を可能な限り避けるように説明しました。

\begin&w_^ \leftarrow w_^ + \eta \big( \mathbb_>[v_ h_] – \mathbb_>[v_ h_] \big) \\ &b_^ \leftarrow b_^ + \eta \big(\mathbb_>[v_] – \mathbb_>[v_] \big) \\ &c_^ \leftarrow c_^ + \eta \big( \mathbb_>[h_] – \mathbb_>[ h_] \big) \end

また、CD法で勾配を計算する際は、隠れ変数のサンプル値\(\mathbf^\)を使用するのではなく、条件付き確率\(P(\mathbf \mid \mathbf^, \theta)\)を使用するのが良いと言われています。

\begin&w_^ \leftarrow w_^ + \eta \big( v_^ P(h_=1 \mid \mathbf^, \theta)~ – v_^ P(h_=1 \mid \mathbf^, \theta) \big) \\ &b_^ \leftarrow b_^ + + \eta \big( v_^ – v_^ \big) \\ &c_^ \leftarrow c_^ + \eta \big(P(h_=1 \mid \mathbf^, \theta)~ – P(h_=1 \mid \mathbf^, \theta) \big) \end

パーシステント・コンストラスティブ・ダイバージェンス法(PCD法)

パーシステント・コンストラスティブダイバージェンス法(Persistent Constrastive Divergence)法は、CD法が定常状態からのサンプリングができないという問題を部分的に解決した方法でCD法以上の効率・精度を経験的に確保することができると言われています。

k-PCD法のイメージ

制限ボルツマンマシンの誤差関数

擬似対数尤度(Pseudo-likelihood) $$\text() \equiv \log \prod_^ p(v_ | _) = \sum_^ \log p(v_ | _)$$

制限ボルツマンマシンの場合、\(p(v_|_)\)は、以下のように表せます。

$$p(v_ | _ ) = \sigma \big(F(v_=0, _) ~ – F(v_=1, _) \big) $$
  • \( \tau \)は、各データの要素に関する一様分布を持つ確率変数
  • \(\hat>_ \)は、\(\tau\)で指定されるユニットを反転したもの (e.g. 0 → 1, 1 → 0)
  • Reconstruction error
  • Annealed Importance Sampling
  • Validationデータとtrainingデータの間のaverage自由エネルギー差

PytorchによるRBMの実装

【10分完了】Google Colaboratoryのインストール法・使い方 Google Colabは、GPUを使用できるため、低コストかつ高速で機械学習の実装を行う際に必要不可欠なツールです。本記事では、Google Colabのインストール方法と注意事項を10〜15分程度でまとめました。.

必要なライブラリをインポート import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns from tqdm.notebook import tqdm import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F # MNIST Datasetの取得 from keras.datasets import mnist # deviceの設定 (cpu or gpu) device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

また、 tqdm はプログレスバーを表示するためにインポートしました。

また、 keras.datasets の mnist は、MNIST(手書き数字)データセットを取得するために使用します。

サンプルデータを読み込む (X_train, y_train), (X_test, y_test) = mnist.load_data()

今回は X_train のみを使用するため、 X_train を一次元配列に変換して、値を0と1の離散値に変換します。

X_train = X_train.reshape(60000, 784) # 1次元に変換 X_train = X_train.astype('float32') # float32型に変換 X_train /= 255 # 0.0-1.0に変換 X_train = np.where(X_train

次に、 X_train をpytorchのTensorにへ

# 複数の画像を表示する関数 def check_images(x): images = np.rollaxis(np.rollaxis(x[0:100].reshape(20, -1, 28, 28), 0, 2), 1, 3).reshape(-1, 20 *28) plt.figure(figsize=(10,20)) plt.imshow(images, cmap='gray') plt.grid(False) plt.axis('off') plt.show() # 画像を一枚表示する関数 def check_one_image(x): image = x[0].reshape(28, 28) plt.figure(figsize=(10,20)) plt.imshow(image, cmap='gray') plt.grid(False) plt.axis('off') plt.show() check_images(X_train) データをPyTorch専用のDatasetを変換するコード

ここからは、PyTorchでRBMを実装するために、PyTorch専用のTensorという配列に変換して、 Dataset を作成します。

class MyDataset(torch.utils.data.Dataset): def __init__(self, samples): self.samples = samples def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] return sample
  • 【Pytorch】tensor型とは|知らないとまずいものをまとめてみた
  • PytorchのDatasetを徹底解説(自作データセットも作れる)
  • 【徹底解説】PytorchのDataLoaderの使い方
RBMの実装 class RBM(nn.Module): def __init__(self, vis_dim, hid_dim, initial_std=0.01, device='cpu'): super(RBM, self).__init__() self.device = device self.b = torch.zeros(1, vis_dim, device=device) self.c = torch.zeros(1, hid_dim, device=device) self.w = torch.empty((hid_dim, vis_dim), device=device).normal_(mean=0, std=initial_std) def _visible_to_hidden(self, v): """可視ユニットから隠れユニットをサンプル """ p = torch.sigmoid(F.linear(v, self.w, self.c)) return p.bernoulli() def _hidden_to_visible(self, h): """隠れユニットから可視ユニットをサンプル """ p = torch.sigmoid(F.linear(h, self.w.t(), self.b)) return p.bernoulli() def _visible_to_ph(self, v): """P(h=1|v)を計算 """ return torch.sigmoid(F.linear(v, self.w, self.c)) def sample(self, v, gib_num=1): """データをサンプリング """ v = v.view(-1, self.w.size(1)).to(self.device) h = self._visible_to_hidden(v) for _ in range(gib_num): v_gibb = self._hidden_to_visible(h) h = self._visible_to_hidden(v_gibb) return v_gibb def sample_ph(self, v, gib_num=15): """phをサンプリング """ v = v.view(-1, self.w.size(1)).to(self.device) ph = self._visible_to_ph(v) h = ph.bernoulli() # Gibbs Sampling 1 ~ k for _ in range(gib_num): v_gibb = self._hidden_to_visible(h) ph_gibb = self._visible_to_ph(v_gibb) h = ph_gibb.bernoulli() return ph_gibb def energy(self, v): """エネルギーを計算 """ v_term = torch.matmul(v, self.b.t()) w_x_h = torch.matmul(v, self.w.t())+self.c h_term = torch.sum(F.softplus(w_x_h), dim=1) return -h_term-v_term def pseudo_likelihood(self, v): """疑似対数尤度を計算 """ flip = torch.randint(0, v.size(1), (1,)) v_fliped = v.clone() v_fliped[:, flip] = 1-v_fliped[:, flip] energy = self.energy(v) energy_fliped = self.energy(v_fliped) return v.size(1)*F.softplus(energy_fliped - energy) def _update(self, v_pos, lr=0.1): """ミニバッチあたりの学習更新 """ # positive part ph_pos = self._visible_to_ph(v_pos) # negative part v_neg = self._hidden_to_visible(self.h_states) ph_neg = self._visible_to_ph(v_neg) lr = lr/v_pos.size(0) # Update W update = torch.matmul(ph_pos.t(), v_pos) - torch.matmul(ph_neg.t(), v_neg) self.w += lr*update self.b += lr*torch.sum(v_pos - v_neg, dim=0) self.c += lr*torch.sum(ph_pos - ph_neg, dim=0) # PCDのために隠れユニットの値を保持 self.h_states = ph_neg.bernoulli() def fit(self, data, n_epoch=10, lr=1e-1, batch_size=128): train = MyDataset(data[:int(len(data)*0.7)]) test = MyDataset(data[int(len(data)*0.7):]) train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True, num_workers=0) train_loss_avg, val_loss_avg = [], [] # pcd memory self.h_states = torch.zeros(batch_size, self.w.size(0), device=device) for epoch in tqdm(range(n_epoch)): train_loss_avg.append(0) val_loss_avg.append(0) self.train() for i, data in enumerate(train_loader): data = data.to(self.device) self._update(data) train_loss_avg[-1] += - self.pseudo_likelihood(data).mean().item() train_loss_avg[-1] /= data.size(1) self.eval() with torch.no_grad(): for i, data in enumerate(test_loader): data = data.view(-1, self.w.size(1)).to(self.device) val_loss_avg[-1] += - self.pseudo_likelihood(data).mean().item() val_loss_avg[-1] /= data.size(1) print(f"[EPOCH]: , [LOSS]: , [VAL]: ") return train_loss_avg, val_loss_avg MNISTを使用した数値実験 # RBMインスタンスの作成 model = RBM(28*28, 256, device=device) # 学習の実行 train_loss, test_loss = model.fit(data, n_epoch=100, lr=1e-1, batch_size=100) [EPOCH]: 1, [LOSS]: -1946.8083, [VAL]: -2175.9491 [EPOCH]: 2, [LOSS]: -2300.3520, [VAL]: -2301.9908 [EPOCH]: 3, [LOSS]: -2367.6407, [VAL]: -2466.6865 [EPOCH]: 4, [LOSS]: -2483.8336, [VAL]: -2523.9067 [EPOCH]: 5, [LOSS]: -2538.9095, [VAL]: -2715.3491 [EPOCH]: 6, [LOSS]: -2562.1427, [VAL]: -2594.0387 [EPOCH]: 7, [LOSS]: -2766.7965, [VAL]: -2700.4362 [EPOCH]: 8, [LOSS]: -2789.4425, [VAL]: -2731.5312 [EPOCH]: 9, [LOSS]: -2749.9166, [VAL]: -2694.6613 [EPOCH]: 10, [LOSS]: -2823.9816, [VAL]: -2704.1976 : : : fig, ax = plt.subplots() ax.plot(train_loss, marker="o", label="train") ax.plot(test_loss, marker="o", label="test") ax.set_xlabel('Epoch', fontsize=30) ax.set_ylabel(r'$- \hat/N$', fontsize=30) ax.legend(fontsize=20) ax.grid() plt.show() init_state = data[:100].to(device) # データを生成 sample = model.sample(init_state, gib_num=200) # 生成画像を表示 check_images(sample.detach().cpu().numpy())

参考資料

参考文献 ボルツマンマシン (シリーズ 情報科学における確率モデル 2) 機械学習スタートアップシリーズ これならわかる深層学習入門 深層学習 (機械学習プロフェッショナルシリーズ) 参考論文
  • A Practical Guide to Training Restricted Boltzmann Machines
  • Training Restricted Boltzmann Machines using Approximations to the Likelihood Gradient

まとめ

【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点

Python学習を効率化させるサービス
  • 【レベル別】Pythonを学ぶための本を厳選しました|入門〜上級者まで
  • 【レベル別】Pythonで機械学習を学ぶための本|最短実務応用を目指せ!
こちらの記事もオススメ Science Lab

【入門】WassersteinGANの理論を解説

2021年7月20日 努力のガリレオ

【Python】重点サンプリング(Importance Sampling)の理論と実.

2024年3月4日 努力のガリレオ Science Lab

【Python】ギブスサンプリングの応用(多変量ガウス分布)

2021年7月25日 努力のガリレオ

【簡単】主成分分析の理論とPythonによる実装(寄与率も計算)|アニメーション付き

2021年12月6日 努力のガリレオ Science Lab

【入門】生成モデルと統計的機械学習について

2020年12月25日 努力のガリレオ 【暴露】ブログを始める6つのメリット・3つのデメリット|東大生が解説! 【東大生が解説】勉強ブログで収益化(月10万円稼ぐ5ステップ)| ノートの代わりにブログにまとめて収益化! 【東大生が解説】大学生がブログで月10万円稼ぐための15ステップ 【最新】ブログの始め方ガイド|ブログ作成方法から収益化まで解説! 【初心者】趣味ブログで月10万円稼ぐ方法|好きなことを書いて稼ぐ!

【経営者/編集者】 : 東京大学院籍/TOEIC950点/最優秀塾講師賞/オンライン英会話講師試験合格/カナダ滞在【デザイナー/編集者】 : カナダ滞在経験/英語教員経験あり/日本語教育課程 ●ブログ運営、英語、プログラミング、機械学習、物理数学を中心に発信します。 ● 努力のガリレオであり続けます。