[PyTorch] 텐서 합치기 (concat, stack)
Deep Learning (TF, Keras, PyTorch)/PyTorch basics 2023. 2. 21. 22:31이번 포스팅에서는 두 개의 PyTorch 텐서를 하나로 합치는 방법을 소개하겠습니다.
(1) torch.cat((x, y), dim=0), torch.concat(dim=0): 두 개의 텐서를 위-아래로 합치기 (vertically, row wise)
(2) torch.cat((x, y), dim=1), torch.concat(dim=1): 두 개의 텐서를 좌-우로 합치기 (horizontally, column wise)
(3) torch.vstack((x, y)), torch.row_stack(): 두 개의 텐서를 위-아래로 합치기 (vertically, row wise)
(4) torch.hstack((x, y)), torch.column_stack(): 두 개의 텐서를 좌-우로 합치기 (horizontally, column wise)
(5) torch.stack((x, y), dim=0): 두 개의 텐서를 새 차원(new dimension)으로 위-아래로 합치기
(6) torch.stack((x, y)dim=1): 두 개의 텐서를 새 차원(new dimension)으로 좌-우로 합치기
두 개의 텐서를 위-아래로 합치기 (vertically, row wise concatenation) 는 torch.cat(dim=0), torch.concat(dim=0), torch.vstack(), torch.row_stack() 로 수행 가능합니다.
두 개의 텐서를 좌-우로 합치기 (horizontally, column wise concatenation) 는 torch.cat(dim=1), torch.concat(dim=1), torch.hstack(), torch.column_stack() 로 수행 가능합니다.
torch.stack((x, y), dim=0) 은 새로운 차원(new dimension)을 추가해서 두 개의 텐서를 위-아래로 합쳐주며,
torch.stack((x, y), dim=1) 은 새로운 차원을 추가해서 두 개의 텐서를 좌-우로 합쳐주는 차이점이 있습니다.
예제로 사용할 두 개의 PyTorch 텐서를 만들어보겠습니다.
import torch
x = torch.arange(12).reshape(3, 4)
y = torch.arange(12, 24).reshape(3, 4)
print(x)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
print(y)
# tensor([[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]])
(1) torch.cat(dim=0), torch.concat(dim=0)
: 두 개의 텐서를 위-아래로 합치기 (vertically, row wise)
## torch.cat(): concatenating in the axis 0
torch.cat((x, y), dim=0) # or torch.concat((x, y), 0)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]])
torch.cat((x, y), dim=0).shape
# torch.Size([6, 4])
(2) torch.cat((x, y), dim=1), torch.concat(dim=1)
: 두 개의 텐서를 좌-우로 합치기 (horizontally, column wise)
## concatenating in the axis 1
torch.cat((x, y), dim=1) # or torch.concat((x, y), 1)
# tensor([[ 0, 1, 2, 3, 12, 13, 14, 15],
# [ 4, 5, 6, 7, 16, 17, 18, 19],
# [ 8, 9, 10, 11, 20, 21, 22, 23]])
torch.cat((x, y), dim=1).shape
# torch.Size([3, 8])
(3) torch.vstack((x, y)), torch.row_stack()
: 두 개의 텐서를 위-아래로 합치기 (vertically, row wise)
## Stack tensors in sequence vertically (row wise).
torch.vstack((x, y))
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11],
# [12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]])
## or equivalently
torch.row_stack((x, y))
torch.row_stack((x, y)).shape
# torch.Size([6, 4])
(4) torch.hstack((x, y)), torch.column_stack()
: 두 개의 텐서를 좌-우로 합치기 (horizontally, column wise)
## Stack tensors in sequence horizontally (column wise).
torch.hstack((x, y))
# tensor([[ 0, 1, 2, 3, 12, 13, 14, 15],
# [ 4, 5, 6, 7, 16, 17, 18, 19],
# [ 8, 9, 10, 11, 20, 21, 22, 23]])
## or equivalently
torch.column_stack((x, y))
torch.column_stack((x, y)).shape
# torch.Size([3, 8])
(5) torch.stack((x, y), dim=0)
: 두 개의 텐서를 새 차원(new dimension)으로 위-아래로 합치기
위의 (1)~(4)번의 메소드 대비해서 torch.stack((x, y), dim=0) 은 새로운 차원(new dimension)을 추가해서 두 개의 텐서를 위-아래로 합쳐주며, torch.stack((x, y), dim=1) 은 새로운 차원을 추가해서 두 개의 텐서를 좌-우로 합쳐주는 차이점이 있습니다. (1)~(4)번의 텐서 합치기를 했을 때의 shape 과 (5)~(6)번의 텐서 합치기 후의 shape 을 유심히 비교해보시기 바랍니다.
## Concatenates a sequence of tensors along a new dimension.
## stack with axis=0 (vertically, row-wise)
torch.stack((x, y), dim=0) # axis=0
# tensor([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
## new dimension
torch.stack((x, y), dim=0).shape
# torch.Size([2, 3, 4])
(6) torch.stack((x, y)dim=1)
: 두 개의 텐서를 새 차원(new dimension)으로 좌-우로 합치기
## stack with axis=1 (horizontally, column-wise)
torch.stack((x, y), dim=1) # axis=1
# tensor([[[ 0, 1, 2, 3],
# [12, 13, 14, 15]],
# [[ 4, 5, 6, 7],
# [16, 17, 18, 19]],
# [[ 8, 9, 10, 11],
# [20, 21, 22, 23]]])
## new dimension
torch.stack((x, y), dim=1).shape
# torch.Size([3, 2, 4])
다음 포스팅에서는 이어서 '하나의 PyTorch 텐서를 복수개의 텐서로 나누기 (splitting a PyTorch tensor into multiple tensors)' 에 대해서 소개하겠습니다.
이번 포스팅이 많은 도움이 되었기를 바랍니다.
행복한 데이터 과학자 되세요! :-)
'Deep Learning (TF, Keras, PyTorch) > PyTorch basics' 카테고리의 다른 글
[PyTorch] torchvision.datasets 모듈에 내장되어 있는 데이터 가져와서 시각화하기 (0) | 2023.02.26 |
---|---|
[PyTorch] 텐서 나누기 (splitting a PyTorch tensor into multiple tensors) (0) | 2023.02.23 |
[PyTorch] 텐서의 인덱싱과 슬라이싱 (indexing & slicing of PyTorch tensor) (0) | 2023.02.19 |
[PyTorch] NumPy의 array 대비 PyTorch 의 성능 비교 (0) | 2023.02.19 |
[PyTorch] 난수를 생성해서 텐서 만들기 (generating a tensor with random numbers) (0) | 2023.02.12 |