文章目录
- 1. einops
- 2. code
- 3. pytorch
1. einops
einops 是一个基于爱因斯坦标记法来处理张量的库,让张量的转置、变形、归约和重复等操作写起来非常简单直观。
- conda :
conda install conda-forge::einops
2. code
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
torch.set_printoptions(precision=3, sci_mode=False)
if __name__ == "__main__":
    # Side-by-side demo of einops (rearrange / reduce / repeat) against the
    # equivalent plain PyTorch calls. Every step prints its result.

    # Base tensor: values 0..95 reshaped to (2, 3, 4, 4) and cast to float32.
    x = torch.arange(96).reshape((2, 3, 4, 4)).to(torch.float32)
    print(f"x.shape={x.shape}")
    print(f"x=\n{x}")

    # 1. Transpose: swap axes 1 and 2 with torch, then with einops, and compare.
    x_torch_trans = x.transpose(1, 2)
    x_einops_trans = rearrange(x, 'b i w h -> b w i h')
    x_check_trans = torch.allclose(x_torch_trans, x_einops_trans)
    print(f"x_torch_trans is {x_check_trans} same with x_einops_trans")

    # 2. Reshape: merge the first two axes (2 * 3 = 6) and compare both results.
    x_torch_reshape = x.reshape(6, 4, 4)
    x_einops_reshape = rearrange(x, 'b i w h -> (b i) w h')
    x_check_reshape = torch.allclose(x_torch_reshape, x_einops_reshape)
    print(f"x_einops_reshape is {x_check_reshape} same with x_check_reshape")

    # 3. image2patch: split each of the last two axes into (count, size) with
    #    size 2, then gather the two count axes into a single patch axis.
    image2patch = rearrange(x, 'b i (h1 p1) (w1 p2) -> b i (h1 w1) p1 p2', p1=2, p2=2)
    print(f"image2patch.shape={image2patch.shape}")
    print(f"image2patch=\n{image2patch}")
    # Merge axis 1 with the patch axis so all patches share one axis.
    image2patch2 = rearrange(image2patch, 'b i j h w -> b (i j) h w')
    print(f"image2patch2.shape={image2patch2.shape}")
    print(f"image2patch2=\n{image2patch2}")

    # 4. Reduce: the 'w' axis is absent from the output pattern, so it is
    #    averaged away.
    y = torch.arange(24).reshape((2, 3, 4)).to(torch.float32)
    y_einops_mean = reduce(y, 'b h w -> b h', 'mean')
    print(f"y=\n{y}")
    print(f"y_einops_mean=\n{y_einops_mean}")

    # 5. List input: rearrange is given a Python list of three identical
    #    tensors; the list index is named by the leading 'n' axis.
    y_tensor = torch.arange(24).reshape(2, 2, 2, 3)
    y_list = [y_tensor, y_tensor, y_tensor]
    y_output = rearrange(y_list, 'n b i h w -> n b i h w')
    print(f"y_tensor=\n{y_tensor}")
    print(f"y_output=\n{y_output}")

    # 6. Add a trailing axis of length 1.
    z_tensor = torch.arange(12).reshape(2, 2, 3).to(torch.float32)
    z_tensor_1 = rearrange(z_tensor, 'b h w -> b h w 1')
    print(f"z_tensor=\n{z_tensor}")
    print(f"z_tensor_1=\n{z_tensor_1}")

    # 7. Repeat: first grow the length-1 trailing axis to length 2, then repeat
    #    the original tensor twice along both 'h' and 'w'.
    z_tensor_2 = repeat(z_tensor_1, 'b h w 1 -> b h w 2')
    print(f"z_tensor_2=\n{z_tensor_2}")
    z_tensor_repeat = repeat(z_tensor, 'b h w -> b (2 h) (2 w)')
    print(f"z_tensor_repeat=\n{z_tensor_repeat}")
x.shape=torch.Size([2, 3, 4, 4])
x=
tensor([[[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]],
[[16., 17., 18., 19.],
[20., 21., 22., 23.],
[24., 25., 26., 27.],
[28., 29., 30., 31.]],
[[32., 33., 34., 35.],
[36., 37., 38., 39.],
[40., 41., 42., 43.],
[44., 45., 46., 47.]]],
[[[48., 49., 50., 51.],
[52., 53., 54., 55.],
[56., 57., 58., 59.],
[60., 61., 62., 63.]],
[[64., 65., 66., 67.],
[68., 69., 70., 71.],
[72., 73., 74., 75.],
[76., 77., 78., 79.]],
[[80., 81., 82., 83.],
[84., 85., 86., 87.],
[88., 89., 90., 91.],
[92., 93., 94., 95.]]]])
x_torch_trans is True same with x_einops_trans
x_einops_reshape is True same with x_check_reshape
image2patch.shape=torch.Size([2, 3, 4, 2, 2])
image2patch=
tensor([[[[[ 0., 1.],
[ 4., 5.]],
[[ 2., 3.],
[ 6., 7.]],
[[ 8., 9.],
[12., 13.]],
[[10., 11.],
[14., 15.]]],
[[[16., 17.],
[20., 21.]],
[[18., 19.],
[22., 23.]],
[[24., 25.],
[28., 29.]],
[[26., 27.],
[30., 31.]]],
[[[32., 33.],
[36., 37.]],
[[34., 35.],
[38., 39.]],
[[40., 41.],
[44., 45.]],
[[42., 43.],
[46., 47.]]]],
[[[[48., 49.],
[52., 53.]],
[[50., 51.],
[54., 55.]],
[[56., 57.],
[60., 61.]],
[[58., 59.],
[62., 63.]]],
[[[64., 65.],
[68., 69.]],
[[66., 67.],
[70., 71.]],
[[72., 73.],
[76., 77.]],
[[74., 75.],
[78., 79.]]],
[[[80., 81.],
[84., 85.]],
[[82., 83.],
[86., 87.]],
[[88., 89.],
[92., 93.]],
[[90., 91.],
[94., 95.]]]]])
image2patch2.shape=torch.Size([2, 12, 2, 2])
image2patch2=
tensor([[[[ 0., 1.],
[ 4., 5.]],
[[ 2., 3.],
[ 6., 7.]],
[[ 8., 9.],
[12., 13.]],
[[10., 11.],
[14., 15.]],
[[16., 17.],
[20., 21.]],
[[18., 19.],
[22., 23.]],
[[24., 25.],
[28., 29.]],
[[26., 27.],
[30., 31.]],
[[32., 33.],
[36., 37.]],
[[34., 35.],
[38., 39.]],
[[40., 41.],
[44., 45.]],
[[42., 43.],
[46., 47.]]],
[[[48., 49.],
[52., 53.]],
[[50., 51.],
[54., 55.]],
[[56., 57.],
[60., 61.]],
[[58., 59.],
[62., 63.]],
[[64., 65.],
[68., 69.]],
[[66., 67.],
[70., 71.]],
[[72., 73.],
[76., 77.]],
[[74., 75.],
[78., 79.]],
[[80., 81.],
[84., 85.]],
[[82., 83.],
[86., 87.]],
[[88., 89.],
[92., 93.]],
[[90., 91.],
[94., 95.]]]])
y=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]],
[[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]]])
y_einops_mean=
tensor([[ 1.500, 5.500, 9.500],
[13.500, 17.500, 21.500]])
y_tensor=
tensor([[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]],
[[[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23]]]])
y_output=
tensor([[[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]],
[[[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23]]]],
[[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]],
[[[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23]]]],
[[[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]],
[[[12, 13, 14],
[15, 16, 17]],
[[18, 19, 20],
[21, 22, 23]]]]])
z_tensor=
tensor([[[ 0., 1., 2.],
[ 3., 4., 5.]],
[[ 6., 7., 8.],
[ 9., 10., 11.]]])
z_tensor_1=
tensor([[[[ 0.],
[ 1.],
[ 2.]],
[[ 3.],
[ 4.],
[ 5.]]],
[[[ 6.],
[ 7.],
[ 8.]],
[[ 9.],
[10.],
[11.]]]])
z_tensor_2=
tensor([[[[ 0., 0.],
[ 1., 1.],
[ 2., 2.]],
[[ 3., 3.],
[ 4., 4.],
[ 5., 5.]]],
[[[ 6., 6.],
[ 7., 7.],
[ 8., 8.]],
[[ 9., 9.],
[10., 10.],
[11., 11.]]]])
z_tensor_repeat=
tensor([[[ 0., 1., 2., 0., 1., 2.],
[ 3., 4., 5., 3., 4., 5.],
[ 0., 1., 2., 0., 1., 2.],
[ 3., 4., 5., 3., 4., 5.]],
[[ 6., 7., 8., 6., 7., 8.],
[ 9., 10., 11., 9., 10., 11.],
[ 6., 7., 8., 6., 7., 8.],
[ 9., 10., 11., 9., 10., 11.]]])