問題設定:MXNetのNDArray APIにおける配列の結合
MXNetのNArray APIで配列を縦に結合しようとした。
問題解決策:mxnet.ndarray.concatとmxnet.ndarray.stackの区別
API文書のJoining and splitting arraysでは、`mxnet.ndarray.concat` と `mxnet.ndarray.stack` が紹介されていた。真っ先に目に付いたのはconcatの方であったが、系列データを縦に結合するだけであればstackの方が適しているように見える。
Examplesは以下のようになっている。
x = [1, 2] y = [3, 4] stack(x, y) = [[1, 2], [3, 4]] stack(x, y, axis=1) = [[1, 3], [2, 4]]
しかしこの通りに入力しても、 `mxnet.nd.array` の型ではないために、以下のようにエラーとなる。
>>> x = [1, 2] >>> y = [3, 4] >>> mx.ndarray.stack(x, y) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "<string>", line 39, in stack AssertionError: Positional arguments must have NDArray type, but got [1, 2]
より視認し易く再記述するなら、以下のようになる。
>>> x_arr = mx.nd.array(x) >>> y_arr = mx.nd.array(y) >>> x_arr [ 1. 2.] <NDArray 2 @cpu(0)> >>> y_arr [ 3. 4.] <NDArray 2 @cpu(0)> >>> mx.ndarray.stack(x_arr, y_arr) [[ 1. 2.] [ 3. 4.]] <NDArray 2x2 @cpu(0)>
参考資料
MXNet – Python API — mxnet documentation (アクセス日時:2018/02/04 14:00)