list型とTensor型のシャッフル方法

目次

list型のシャッフル方法

random.shuffle()を使う

random.shuffle()の引数にリストをセットすると、リストの内容がシャッフルされます。
破壊的メソッドなので引数に設定したリストが変更されます。

>>> x = [i for i in range(10)]
>>> x
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> random.shuffle(x)
>>> x
[9, 1, 6, 7, 2, 5, 0, 3, 4, 8]

random.sample()を使う

random.sample()の第一引数にリストを、第二引数に取得する要素数をセットすると、シャッフルされたリストが返り値として取得できます。
非破壊的メソッドなので引数に設定したリストが変更されません。

>>> x = [i for i in range(10)]
>>> y = random.sample(x, 5)
>>> y
[9, 8, 4, 3, 5]
>>> z = random.sample(x, len(x))
>>> z
[4, 8, 0, 9, 6, 1, 3, 7, 2, 5]
>>> x
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

PyTorchのTensor型のシャッフル

どの次元でシャッフルを行うかで変わってきます。torch.randperm()によってランダムな順番のインデックス配列を生成し、シャッフルに利用します。ただし、以下の実行例のうちdim=1の例のように、この方法ではある次元でのシャッフルのルールが各要素で同じになってしまうことに注意が必要です。

>>> x = torch.tensor([[[10 * k + 1 * j + 0.1 * i for i in range(3)] for j in range(4)] for k in range(3)])
>>> x.shape
torch.Size([3, 4, 3])
>>> x
tensor([[[ 0.0000,  0.1000,  0.2000],
         [ 1.0000,  1.1000,  1.2000],
         [ 2.0000,  2.1000,  2.2000],
         [ 3.0000,  3.1000,  3.2000]],

        [[10.0000, 10.1000, 10.2000],
         [11.0000, 11.1000, 11.2000],
         [12.0000, 12.1000, 12.2000],
         [13.0000, 13.1000, 13.2000]],

        [[20.0000, 20.1000, 20.2000],
         [21.0000, 21.1000, 21.2000],
         [22.0000, 22.1000, 22.2000],
         [23.0000, 23.1000, 23.2000]]])
>>>
>>> # shuffle dim=0
... x[torch.randperm(a.size()[0])]
tensor([[[10.0000, 10.1000, 10.2000],
         [11.0000, 11.1000, 11.2000],
         [12.0000, 12.1000, 12.2000],
         [13.0000, 13.1000, 13.2000]],

        [[20.0000, 20.1000, 20.2000],
         [21.0000, 21.1000, 21.2000],
         [22.0000, 22.1000, 22.2000],
         [23.0000, 23.1000, 23.2000]],

        [[ 0.0000,  0.1000,  0.2000],
         [ 1.0000,  1.1000,  1.2000],
         [ 2.0000,  2.1000,  2.2000],
         [ 3.0000,  3.1000,  3.2000]]])
>>>
>>> # shuffle dim=1
... x[:, torch.randperm(a.size()[1])]
tensor([[[ 2.0000,  2.1000,  2.2000],
         [ 1.0000,  1.1000,  1.2000],
         [ 3.0000,  3.1000,  3.2000],
         [ 0.0000,  0.1000,  0.2000]],

        [[12.0000, 12.1000, 12.2000],
         [11.0000, 11.1000, 11.2000],
         [13.0000, 13.1000, 13.2000],
         [10.0000, 10.1000, 10.2000]],

        [[22.0000, 22.1000, 22.2000],
         [21.0000, 21.1000, 21.2000],
         [23.0000, 23.1000, 23.2000],
         [20.0000, 20.1000, 20.2000]]])