import torch
tensor = torch.LongTensor([[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]])
print(tensor)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6]],
#
#         [[ 7,  8,  9],
#          [10, 11, 12]]])
# 1. Shape & dtype basics
print(tensor.shape)
# torch.Size([2, 2, 3])
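# To round out the "shape & dtype" heading: LongTensor stores 64-bit integers, and ndim gives the number of dimensions
print(tensor.dtype)
# torch.int64
print(tensor.ndim)
# 3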
# 2. Reshape family (no math change)
# A handy point about reshape: -1 lets PyTorch infer that dimension automatically; here there are 12 elements, so with 6 columns the -1 resolves to 2 (2 * 6 = 12)
print(tensor.reshape(-1, 6))
# tensor([[ 1,  2,  3,  4,  5,  6],
#         [ 7,  8,  9, 10, 11, 12]])
print(tensor.reshape(4, 3))
# tensor([[ 1,  2,  3],
#         [ 4,  5,  6],
#         [ 7,  8,  9],
#         [10, 11, 12]])
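# view is the stricter sibling of reshape: it gives the same result here, but requires contiguous memory (see section 8)
print(tensor.view(2, 6))
# tensor([[ 1,  2,  3,  4,  5,  6],
#         [ 7,  8,  9, 10, 11, 12]])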
# 3. Reorder dimensions
print(tensor.permute(0, 2, 1).shape)
# torch.Size([2, 3, 2])
print(tensor.permute(0, 2, 1))
# tensor([[[ 1,  4],
#          [ 2,  5],
#          [ 3,  6]],
#
#         [[ 7, 10],
#          [ 8, 11],
#          [ 9, 12]]])
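# For comparison, transpose swaps exactly two dimensions; for this tensor, transpose(1, 2) matches permute(0, 2, 1)
print(tensor.transpose(1, 2).shape)
# torch.Size([2, 3, 2])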
# 4. Add/remove size-1 dims
print(tensor.unsqueeze(0))
# tensor([[[[ 1,  2,  3],
#           [ 4,  5,  6]],
#
#          [[ 7,  8,  9],
#           [10, 11, 12]]]])
print(tensor.squeeze(0))
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6]],
#
#         [[ 7,  8,  9],
#          [10, 11, 12]]])
# 5. Broadcasting
tensor = torch.FloatTensor([[1,2,3],[4,5,6]])
mean = tensor.mean(dim=1)
print(mean)
# tensor([2., 5.])
mean = mean.unsqueeze(1).expand_as(tensor)
print(mean)
# tensor([[2., 2., 2.],
#         [5., 5., 5.]])
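# Broadcasting often makes the explicit expand_as unnecessary: a (2, 1) tensor broadcasts against (2, 3) on its own
centered = tensor - tensor.mean(dim=1, keepdim=True)
print(centered)
# tensor([[-1.,  0.,  1.],
#         [-1.,  0.,  1.]])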
# 6. Indexing & slicing
tensor = torch.randn(2,3,4,5)
print(tensor.shape)
# torch.Size([2, 3, 4, 5])
y = tensor[:, 0]
print(y.shape)
# torch.Size([2, 4, 5])
y = tensor[..., -1]
print(y.shape)
# torch.Size([2, 3, 4])
# mask
tensor = torch.LongTensor([1,2,3,4,5,6])
y = torch.zeros_like(tensor)
y = torch.where(tensor > 3, tensor, y)
print(y)
# tensor([0, 0, 0, 4, 5, 6])
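# Boolean indexing with the same mask pulls out only the selected elements (the result is 1-D)
print(tensor[tensor > 3])
# tensor([4, 5, 6])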
# 7. Combine / split
# cat
a = torch.randn(2,3)
b = torch.randn(2,3)
c = torch.cat([a,b], dim=0)
print(c.shape)
# torch.Size([4, 3])
c = torch.cat([a,b], dim=1)
print(c.shape)
# torch.Size([2, 6])
# stack (create a new dimension)
c = torch.stack([a,b], dim=0)
print(c.shape)
# torch.Size([2, 2, 3])
c = torch.stack([a,b,a], dim=1)
print(c.shape)
# torch.Size([2, 3, 3])
c = torch.stack([a,b], dim=2)
print(c.shape)
# torch.Size([2, 3, 2])
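# The "split" half of this section: chunk is roughly the inverse of cat, cutting a tensor into pieces along a dimension
c = torch.cat([a, b], dim=0)  # (4, 3)
parts = torch.chunk(c, 2, dim=0)
print(parts[0].shape, parts[1].shape)
# torch.Size([2, 3]) torch.Size([2, 3])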
# expand vs repeat
# expand does not copy data: it returns a view of the same storage, and only size-1 dimensions can be expanded
# repeat / repeat_interleave copy data
a = torch.randn(2, 1)
b = torch.randn(2, 7)
c = a.expand_as(b)  # works because dim 1 of a has size 1
print(c.shape)
# torch.Size([2, 7])
a = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
c = a.repeat_interleave(2, dim=0)
print(c)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])
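# A quick sketch of the copy vs no-copy claim above: a write through the base tensor
# shows up in the expanded view (shared storage) but not in the repeated copy
base = torch.zeros(2, 1)
expanded = base.expand(2, 7)  # view, no copy
repeated = base.repeat(1, 7)  # real copy
base[0, 0] = 1.0
print(expanded[0])
# tensor([1., 1., 1., 1., 1., 1., 1.])
print(repeated[0])
# tensor([0., 0., 0., 0., 0., 0., 0.])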
# 8. Memory & performance (important!)
tensor = torch.LongTensor([[1,2,3],[4,5,6]])
# stride: how many elements to step over in memory when the index along each dimension increases by 1
print(tensor.stride())
# (3, 1)
# view does NOT move data in memory, it only reinterprets the existing memory layout
y = tensor.view(-1)
print(y)
# tensor([1, 2, 3, 4, 5, 6])
print(y.stride())
# (1,)
# permute changes the order of dimensions and the stride
y = tensor.permute(1, 0)
print(y)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])
print(y.stride())
# (1, 3)
# permute does not move anything in memory, so afterwards the logical element order no longer matches the memory order, and view cannot reshape the result
# y.view(-1) would raise an error here; call y.contiguous() first to get a contiguous copy, then view works
y = tensor.permute(1, 0).contiguous().view(-1)
print(y)
# tensor([1, 4, 2, 5, 3, 6])
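# A final check on the layout story: view shares storage with the original tensor,
# while the permuted tensor is not contiguous until contiguous() allocates a reordered copy
print(tensor.view(-1).data_ptr() == tensor.data_ptr())
# True
print(tensor.permute(1, 0).is_contiguous())
# False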
PyTorch Tensor Transformations
December 26, 2025