Accel Brain; Console×

[備忘録] MXNetのNDArray APIで配列を縦に結合する

スポンサーリンク

問題設定:MXNetのNDArray APIにおける配列の結合

MXNetのNArray APIで配列を縦に結合しようとした。

問題解決策:mxnet.ndarray.concatとmxnet.ndarray.stackの区別

API文書のJoining and splitting arraysでは、`mxnet.ndarray.concat` と `mxnet.ndarray.stack` が紹介されていた。真っ先に目に付いたのはconcatの方であったが、系列データを縦に結合するだけであればstackの方が適しているように見える。

Examplesは以下のようになっている。

x = [1, 2]
y = [3, 4]

stack(x, y) = [[1, 2],
               [3, 4]]
stack(x, y, axis=1) = [[1, 3],
                       [2, 4]]

しかしこの通りに入力しても、 `mxnet.nd.array` の型ではないために、以下のようにエラーとなる。

>>> x = [1, 2]
>>> y = [3, 4]
>>> mx.ndarray.stack(x, y)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 39, in stack
AssertionError: Positional arguments must have NDArray type, but got [1, 2]

より視認し易く再記述するなら、以下のようになる。

>>> x_arr = mx.nd.array(x)
>>> y_arr = mx.nd.array(y)
>>> x_arr

[ 1.  2.]
<NDArray 2 @cpu(0)>
>>> y_arr

[ 3.  4.]
<NDArray 2 @cpu(0)>


>>> mx.ndarray.stack(x_arr, y_arr)

[[ 1.  2.]
 [ 3.  4.]]
<NDArray 2x2 @cpu(0)>

参考資料

MXNet – Python API — mxnet documentation (アクセス日時:2018/02/04 14:00)

フォローする