Numpy array のcopy で気をつけたいこと

TL; DR

  • Numpy array のcopy は気をつけたい
    • 特にNumpy array のブロードキャストの性質
  • Numpy array をcopy したい場合も、純粋なPython のList オブジェクトも、適切なcopy 関数をつかおう
    • でもスライス操作をするとその限りではなさそう?

Numpy array で気をつけたいことがあった

近頃、深層学習で当たり前のようにNumpy を触ることが増えてきていると思いますが、
Numpy の性質として気をつけたいところがありました。

なので、忘備録の意を込めて記事にしようかと思います。

Numpy 操作

以下のようにデータ操作を行いました。

>>> import numpy as np

>>> a = np.arange(0, 10)
>>> a  # array([0, 1, ..., 9])
>>> sub_a = a[:3]  # 最初の3要素をsub_a に代入。意図としてはa の中身をコピー
>>> sub_a  # array([0, 1, 2])

a の3要素を sub_a オブジェクトに代入しただけですね。
そして、この sub_a に値を変更しました。

>>> sub_a[1] = 10
>>> sub_a
array([0, 10, 2])

さて、このとき a の値はどうなっているでしょうか?
結果は、

>>> a  
array([0, 10, 2, ..., 9])

となります。

どうしてこうなるのか?

Numpy はメモリを効率的に扱うため、
見た目は異なるオブジェクトのように扱っていても、
変数の値は同じメモリを見るようになっているみたいです。

つまり、
a[0] の指しているメモリの番地
sub_a[0] が指しているメモリの番地が同じになっており、
片方を変更するともう一方も変更されるということになります。

特に気をつけたいブロードキャストの場合

Numpy では、以下のようにして、すべての要素にブロードキャストで代入することができます。

>>> sub_a[:] = 20

これはこれで大変便利ですよね。

このとき asub_a はどうなるでしょうか?

>>> sub_a
array([20, 20, 20])
>>> a
array([20, 20, ..., 9])

となってしまいます。

じゃあどうするのか?

このように別オブジェクトとしてコピーしたい場合は、
以下のように copy() メソッドを利用するのがよいです。

>>> a = np.arange(0, 10)
>>> sub_a = a[:3].copy()

>>> a
array([0 1 2 3 4 5 6 7 8 9])
>>> sub_a
array([0 1 2])

>>> sub_a[:] = 30
>>> sub_a
array([30 30 30])  # すべての要素が30になる。 
>>> a
array([0 1 2 3 4 5 6 7 8 9])  # 一方で、値が変わっていない。 

通常のPython のlist 型は大丈夫なの?

実はPython のList オブジェクトも似たような問題が知られています。
例えば、以下のような形式の場合。

>>> b = list(range(10)) # [0
>>> b  # [0, 1, 2, ..., 9]
>>> sub_b = b  # b をそのまま sub_b に代入
>>> b  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> sub_b  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> sub_b[0] = 10  # sub_b の要素を変更
>>> sub_b  # [10, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> b  # b の要素も変更されてしまっている。[10, 1, 2,..., 9]

上記は、ハマりポイントとして有名ですよね。
このように配列を別ものとしてコピーして操作したい場合は、
copy ライブラリを使って、オブジェクトをコピーする必要があります。

一次元の配列なら、 copy.copy()
多次元の配列なら、 copy.deepcopy() です。

ただし、一部であっても全体であっても、スライスでコピーした場合はその限りではないみたいです。

>>> b = list(range(10))
>>> sub_b = b[:3]
>>> sub_b[0] = 10
>>> b
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

>>> b = list(range(10))
>>> sub_b = b[:]
>>> sub_b[0] = 10
>>> sub_b
[10, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> b
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

是非参考にしてください。