deeplearning

Pytorchを使った線形回帰をGoogle Colaboratoryで実装してみる

機械学習ライブラリのpytorchを用いて簡単な線形回帰を実装してみます。

Pytorchとは?

Facebookの人工知能研究グループにより初期開発された 、Python向けのオープンソース機械学習ライブラリです。読み方はパイトーチです。

Deep Learningライブラリの中でも近年人気が出てきています。非常に簡単に記述できるのでおすすめです。

Google Colaboratoryとは?

Googleが提供している、無料で使用できるpython実行環境です。

GPUも使用でき機械学習に必要なライブラリもあらかじめインストールされています。ローカルで機械学習の環境を構築するのは大変なので、今回はGoogle Colaboratoryを使用していこうと思います。

使い方は下記参照。1分程で使えるようになります。

線形回帰を実装してみる

必要なライブラリをインストールします。

import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

ランダムな分布モデルを作成してプロットしてみましょう。
この図に対して線形回帰します。

x = torch.randn(100, 1) * 10
y = x + torch.randn(100, 1) * 3
plt.plot(x.numpy(), y.numpy(), "o")
plt.ylabel("y")
plt.xlabel("x")

実行してみたら下記の図が出力されたと思います。

線形回帰のモデルを定義します。

class LR(nn.Module):
  def __init__(self, input_size, output_size):
    super().__init__()
    self.linear = nn.Linear(input_size, output_size)
  def forward(self, x):
    pred = self.linear(x)
    return pred

乱数のseedを固定します。線形回帰のモデルのインスタンスを作成します。

torch.manual_seed(1)
model = LR(1, 1)

モデルのパラメーターを取り出す関数を定義。

[w, b] = model.parameters()
def get_params():
  return (w[0][0].item(), b[0].item())

プロットする関数を定義。

def plot_fit(title):
  plt_title = title
  w1, b1 = get_params()
  x1 = np.array([-30, 30])
  y1 = w1*x1 + b1
  plt.plot(x1, y1, "r")
  plt.scatter(x, y)
  plt.show()

学習前の図をプロットしてみましょう。

plot_fit("initial Model")

損失関数は二乗平均誤差、学習方法は確率的勾配降下法に定義します。learning rateは0.01にしてます。

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

学習してみましょう。エポックは100にしてます。lossを記録しておきます。

epochs = 100
losses = []
for i in range(epochs):
  y_pred = model.forward(x)
  loss = criterion(y_pred, y)
  print("epoch:", i, "loss:", loss.item())
  
  losses.append(loss)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

学習経過を見てみましょう。

plt.plot(range(epochs), losses)
plt.ylabel("Loss")
plt.xlabel("epoch")

学習できているのがわかると思います。
学習後の図をプロットしてみましょう。

plot_fit("Trained Model")

ちゃんと学習できているのがわかると思います。

以上です。お疲れ様でした!

簡単にpytorchを使用して線形回帰を学習してみました。Pytorchと Google Colaboratory を使えばかなり簡単に機械学習を体験できるので、ぜひみなさんも試してみてください。

Sponsored Link

-deeplearning