list型とTensor型のシャッフル方法
目次
list型のシャッフル方法
random.shuffle()を使う
random.shuffle()の引数にリストをセットすると、リストの内容がシャッフルされます。
破壊的メソッドなので引数に設定したリストが変更されます。
>>> x = [i for i in range(10)] >>> x [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> random.shuffle(x) >>> x [9, 1, 6, 7, 2, 5, 0, 3, 4, 8]
random.sample()を使う
random.sample()の第一引数にリストを、第二引数に取得する要素数をセットすると、シャッフルされたリストが返り値として取得できます。
非破壊的メソッドなので引数に設定したリストが変更されません。
>>> x = [i for i in range(10)] >>> y = random.sample(x, 5) >>> y [9, 8, 4, 3, 5] >>> z = random.sample(x, len(x)) >>> z [4, 8, 0, 9, 6, 1, 3, 7, 2, 5] >>> x [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
PyTorchのTensor型のシャッフル
どの次元でシャッフルを行うかで変わってきます。torch.randperm()によってランダムな順番のインデックス配列を生成し、シャッフルに利用します。ただし、以下の実行例のうちdim=1の例のように、この方法ではある次元でのシャッフルのルールが各要素で同じになってしまうことに注意が必要です。
>>> x = torch.tensor([[[10 * k + 1 * j + 0.1 * i for i in range(3)] for j in range(4)] for k in range(3)]) >>> x.shape torch.Size([3, 4, 3]) >>> x tensor([[[ 0.0000, 0.1000, 0.2000], [ 1.0000, 1.1000, 1.2000], [ 2.0000, 2.1000, 2.2000], [ 3.0000, 3.1000, 3.2000]], [[10.0000, 10.1000, 10.2000], [11.0000, 11.1000, 11.2000], [12.0000, 12.1000, 12.2000], [13.0000, 13.1000, 13.2000]], [[20.0000, 20.1000, 20.2000], [21.0000, 21.1000, 21.2000], [22.0000, 22.1000, 22.2000], [23.0000, 23.1000, 23.2000]]]) >>> >>> # shuffle dim=0 ... x[torch.randperm(a.size()[0])] tensor([[[10.0000, 10.1000, 10.2000], [11.0000, 11.1000, 11.2000], [12.0000, 12.1000, 12.2000], [13.0000, 13.1000, 13.2000]], [[20.0000, 20.1000, 20.2000], [21.0000, 21.1000, 21.2000], [22.0000, 22.1000, 22.2000], [23.0000, 23.1000, 23.2000]], [[ 0.0000, 0.1000, 0.2000], [ 1.0000, 1.1000, 1.2000], [ 2.0000, 2.1000, 2.2000], [ 3.0000, 3.1000, 3.2000]]]) >>> >>> # shuffle dim=1 ... x[:, torch.randperm(a.size()[1])] tensor([[[ 2.0000, 2.1000, 2.2000], [ 1.0000, 1.1000, 1.2000], [ 3.0000, 3.1000, 3.2000], [ 0.0000, 0.1000, 0.2000]], [[12.0000, 12.1000, 12.2000], [11.0000, 11.1000, 11.2000], [13.0000, 13.1000, 13.2000], [10.0000, 10.1000, 10.2000]], [[22.0000, 22.1000, 22.2000], [21.0000, 21.1000, 21.2000], [23.0000, 23.1000, 23.2000], [20.0000, 20.1000, 20.2000]]])