前項ではデータのスライシングについて見てきました。そしてスライスしたデータは基本的にリシェイプする必要があります。
たとえば scikit-learn では、出力変数 \(y\) として使用するためにスライスした 1 行 \(n\) 列の配列は、\(n\) 行 1 列の配列にリシェイプしてから使用する必要があります。また、Keras のいくつかのアルゴリズムは、データを標本(samples)・タイムステップ(timestep)・特徴(features)で構成される 3 次元配列にして扱う必要があります。
そこで、このページでは NumPy の配列を、それぞれのライブラリに沿って必要な形状にリシェイプする方法を解説します。
当ページでわかること
- 機械学習アルゴリズムに合うようにデータをリシェイプする方法
データの形状(shape)について
配列のリシェイプについて解説する前に、配列の shape
属性についておさらしておきましょう。この属性は配列の行数と列数を、(行数, 列数)
というようにタプルとして格納しています。以下のコードをご確認ください。
まず 1 次元配列では、shape
属性には列数のみが格納されています。
# NumPyのインポート
import numpy as np
# 行列を作成
data = np.array([1,2,3,4,5])
# shape属性の確認
print(data.shape)
2 次元配列では、shape
属性には行数と列数が格納されています。
# 2次元行列を作成
data = np.array([
[1,2],
[3,4],
[5,6]])
# shape属性の確認
print(data.shape)
以下のように書くとわかりやすいでしょう。
# 2次元行列を作成
data = np.array([
[1,2],
[3,4],
[5,6]])
# shape属性の確認
print('行: %d' % data.shape[0])
print('列: %d' % data.shape[1])
この shape
属性は、配列をリシェイプする際によく確認することになりますので、あらためて頭に入れておきましょう。
データのリシェイプ
それでは、ここからデータをリシェイプする方法を解説します。特に実務上必ず行うことになる、以下の 2 つのリシェイプ方法を解説します。
- 1 行 \(n\) 列の配列を、\(n\) 行 1 列の配列にリシェイプする
- 2 次元配列を 3 次元配列にリシェイプする
なお配列のリシェイプには reshape()
メソッドを使います。それでは見ていきましょう。
1 行 \(n\) 列の配列を、\(n\) 行 1 列の配列にリシェイプする
これは出力変数 \(y\) としてスライスしたベクトルデータを、scikit-learn などの機械学習アルゴリズムに投入するときなどのよく行う操作です。これは reshape()
メソッドを使って、data.reshape((data.shape[0],1))
と書きます。
実際のコードを見てみましょう。
# NumPyのインポート
import numpy as np
# 1次元配列を作成
data = np.array([1,2,3,4,5])
print(data)
print(data.shape)
これを次のように書くと \(n\) 行 1 列の 2 次元配列に変換されます。
# 1次元配列をn行1列にリシェイプ
data = data.reshape((data.shape[0],1))
print(data)
print(data.shape)
2次元データを3次元データに変換する
2 次元データを、1 つ以上のタイムステップと 1 つ以上の特徴を必要とするアルゴリズムで使用できるようにするために、3 次元データにリシェイプすることも、よく行われる配列のリシェイプ操作です。このような操作を必要とする代表例としては Kerasのディープラーニング・ライブラリの LSTM recurrent neural network model があります。
この場合は、data.reshape((data.shape[0], data.shape[1], 1))
というように書きます。
具体例を見てみましょう。まず2次元データを作成します。
# NumPyのインポート
import numpy as np
# 2次元データを作成
data = np.array([
[1,2],
[3,4],
[5,6]])
print(data)
print("\n")
print(data.shape)
これを3 次元データに変換するには次のように書きます。
# 3次元データにリシェイプ
data = data.reshape((data.shape[0], data.shape[1], 1))
print(data)
print("\n")
print(data.shape)
なお、それぞれの操作を詳しく理解したい場合は、ぜひ以下の一緒に読んでおきたいページを参考にしてください。
一緒に読んでおきたいページ