基础运算

  • 可以使用 + - * / 四则运算符号(推荐)
  • 也可以使用 torch.add, torch.mul, torch.sub, torch.div

加法运算

  1. def add():
  2. # add +
  3. # 这两个Tensor加减乘除会对b自动进行Broadcasting
  4. a = torch.rand(3,4)
  5. b = torch.rand(4)
  6. print("a = {}".format(a))
  7. print("b = {}".format(b))
  8. # a、b列数相同,行数不同,将a的每行与b对应位置相加
  9. c1 = a + b
  10. c2 = torch.add(a,b)
  11. c3 = torch.eq(c1,c2)
  12. # torch.all()判断每个位置的元素是否相同
  13. c4 = torch.all(c3)
  14. print("a + b = {}".format(c1))
  15. print("a + b = {}".format(c2))
  16. print("torch.eq = {}".format(c3))
  17. print("torch all = {}".format(c4))
  18. # a = tensor([[0.8514, 0.5017, 0.3924, 0.7817],
  19. # [0.0219, 0.7352, 0.5634, 0.7285],
  20. # [0.9187, 0.1628, 0.9236, 0.3603]])
  21. # b = tensor([0.0809, 0.0295, 0.6065, 0.8024])
  22. # a + b = tensor([[0.9322, 0.5312, 0.9989, 1.5841],
  23. # [0.1028, 0.7647, 1.1700, 1.5309],
  24. # [0.9996, 0.1923, 1.5301, 1.1627]])
  25. # a + b = tensor([[0.9322, 0.5312, 0.9989, 1.5841],
  26. # [0.1028, 0.7647, 1.1700, 1.5309],
  27. # [0.9996, 0.1923, 1.5301, 1.1627]])
  28. # torch.eq = tensor([[True, True, True, True],
  29. # [True, True, True, True],
  30. # [True, True, True, True]])
  31. # torch
  32. # all = True

减法运算

  1. def minus():
  2. # 这两个Tensor加减乘除会对b自动进行Broadcasting
  3. a = torch.rand(3,4)
  4. b = torch.rand(4)
  5. print("a = {}".format(a))
  6. print("b = {}".format(b))
  7. # a、b列数相同,行数不同,将a的每行与b对应位置相加
  8. c1 = a - b
  9. c2 = torch.sub(a,b)
  10. # torch.all()判断每个位置的元素是否相同
  11. c3 = torch.eq(c1,c2)
  12. c4 = torch.all(c3)
  13. print("a - b = {}".format(c1))
  14. print("a - b = {}".format(c2))
  15. print("torch.eq = {}".format(c3))
  16. print("torch all = {}".format(c4))
  17. # a = tensor([[0.8499, 0.1003, 0.3179, 0.1217],
  18. # [0.2119, 0.7742, 0.3973, 0.7241],
  19. # [0.8559, 0.3558, 0.1549, 0.4583]])
  20. # b = tensor([0.4750, 0.9261, 0.7107, 0.1397])
  21. # a - b = tensor([[0.3749, -0.8258, -0.3928, -0.0180],
  22. # [-0.2631, -0.1519, -0.3135, 0.5844],
  23. # [0.3809, -0.5703, -0.5558, 0.3186]])
  24. # a - b = tensor([[0.3749, -0.8258, -0.3928, -0.0180],
  25. # [-0.2631, -0.1519, -0.3135, 0.5844],
  26. # [0.3809, -0.5703, -0.5558, 0.3186]])
  27. # torch.eq = tensor([[True, True, True, True],
  28. # [True, True, True, True],
  29. # [True, True, True, True]])
  30. # torch
  31. # all = True

哈达玛积 (element wise,对应元素相乘)

  1. def mul_element():
  2. # 这两个Tensor加减乘除会对b自动进行Broadcasting
  3. a = torch.rand(3,4)
  4. b = torch.rand(4)
  5. print("a = {}".format(a))
  6. print("b = {}".format(b))
  7. # a、b列数相同,行数不同,将a的每行与b对应位置相加
  8. c1 = a * b
  9. c2 = torch.mul(a,b)
  10. # torch.all()判断每个位置的元素是否相同
  11. c3 = torch.eq(c1,c2)
  12. c4 = torch.all(c3)
  13. print("a * b = {}".format(c1))
  14. print("a * b = {}".format(c2))
  15. print("torch.eq = {}".format(c3))
  16. print("torch all = {}".format(c4))
  17. # a = tensor([[0.9678, 0.8896, 0.5657, 0.7644],
  18. # [0.0581, 0.3479, 0.2008, 0.1259],
  19. # [0.4169, 0.9426, 0.1330, 0.5813]])
  20. # b = tensor([0.3827, 0.7139, 0.4547, 0.6798])
  21. # a * b = tensor([[0.3704, 0.6351, 0.2572, 0.5197],
  22. # [0.0222, 0.2484, 0.0913, 0.0856],
  23. # [0.1595, 0.6729, 0.0605, 0.3952]])
  24. # a * b = tensor([[0.3704, 0.6351, 0.2572, 0.5197],
  25. # [0.0222, 0.2484, 0.0913, 0.0856],
  26. # [0.1595, 0.6729, 0.0605, 0.3952]])
  27. # torch.eq = tensor([[True, True, True, True],
  28. # [True, True, True, True],
  29. # [True, True, True, True]])
  30. # torch all = True

除法运算

对应元素相除

  1. def test():
  2. # 这两个Tensor加减乘除会对b自动进行Broadcasting
  3. a = torch.rand(3,4)
  4. b = torch.rand(4)
  5. print("a = {}".format(a))
  6. print("b = {}".format(b))
  7. # a、b列数相同,行数不同,将a的每行与b对应位置相加
  8. c1 = a / b
  9. c2 = torch.div(a,b)
  10. # torch.all()判断每个位置的元素是否相同
  11. c3 = torch.eq(c1,c2)
  12. c4 = torch.all(c3)
  13. print("a / b = {}".format(c1))
  14. print("a / b = {}".format(c2))
  15. print("torch.eq = {}".format(c3))
  16. print("torch all = {}".format(c4))
  17. #a = tensor([[0.6079, 0.2791, 0.0034, 0.6169],
  18. # [0.5279, 0.7804, 0.5960, 0.0359],
  19. # [0.3385, 0.2300, 0.2021, 0.7161]])
  20. # b = tensor([0.5951, 0.8573, 0.7276, 0.8717])
  21. # a * b = tensor([[1.0214, 0.3256, 0.0047, 0.7077],
  22. # [0.8870, 0.9103, 0.8190, 0.0412],
  23. # [0.5687, 0.2682, 0.2778, 0.8215]])
  24. # a * b = tensor([[1.0214, 0.3256, 0.0047, 0.7077],
  25. # [0.8870, 0.9103, 0.8190, 0.0412],
  26. # [0.5687, 0.2682, 0.2778, 0.8215]])
  27. # torch.eq = tensor([[True, True, True, True],
  28. # [True, True, True, True],
  29. # [True, True, True, True]])
  30. # torch all = True

矩阵运算

  • matmul 表示 matrix mul
  • * 表示的是 element-wise, 对应元素相乘
  • torch.mm(a,b) 只能计算 2D 不推荐,矩阵相乘
  • torch.matmul(a,b) 可以计算更高维度,落脚点依旧在行与列。 推荐
  • @ 是 matmul 的重载形式

二维矩阵相乘

二维矩阵乘法运算操作包括 torch.mm()、torch.matmul()、@

  1. def test():
  2. a = torch.ones(2,1)
  3. b = torch.ones(1,2)
  4. print("a = {}".format(a))
  5. print("b = {}".format(b))
  6. c1 = torch.mm(a,b)
  7. c2 = torch.matmul(a,b)
  8. c3 = a @ b
  9. print("c1 = {}".format(c1))
  10. print("c2 = {}".format(c2))
  11. print("c3 = {}".format(c3))
  12. # a = tensor([[1.],
  13. # [1.]])
  14. # b = tensor([[1., 1.]])
  15. # c1 = tensor([[1., 1.],
  16. # [1., 1.]])
  17. # c2 = tensor([[1., 1.],
  18. # [1., 1.]])
  19. # c3 = tensor([[1., 1.],
  20. # [1., 1.]])

多维矩阵相乘

对于高维的 Tensor(dim>2),定义其矩阵乘法仅在最后的两个维度上,要求前面的维度必须保持一致,就像矩阵的索引一样并且运算操只有 torch.matmul()。

  • 对于 2 维以上的 matrix multiply , torch.mm(a,b)就不行了。
  • 运算规则:只取最后的两维做矩阵乘法
  • 对于 [b, c, h, w] 来说,b,c 是不变的,图片的大小在改变;并且也并行的计算出了 b,c。也就是支持多个矩阵并行相乘
  • 对于不同的 size,如果符合 broadcast,先执行 broadcast,在进行矩阵相乘。
  1. def test():
  2. # 多维矩阵计算,前两个维度必须一致
  3. c = torch.rand(4, 3, 28, 64)
  4. d = torch.rand(4, 3, 64, 32)
  5. print(torch.matmul(c,d).shape)
  6. # torch.Size([4, 3, 28, 32])

注意,在这种情形下的矩阵相乘,前面的 “矩阵索引维度” 如果符合 Broadcasting 机制,也会自动做广播,然后相乘。

  1. def test():
  2. # 多维矩阵计算,前两个维度必须一致
  3. c = torch.rand(4, 3, 28, 64)
  4. d = torch.rand(4, 1, 64, 32)
  5. print(torch.matmul(c,d).shape)
  6. # torch.Size([4, 3, 28, 32])

幂运算

  1. def test():
  2. # troch.full(size, fill_value)
  3. # 参数:
  4. # size: 生成张量的大小,list, tuple, torch.size
  5. # fill_value: 填充张量的数
  6. a = torch.full([2, 2], 3)
  7. print("a = {}".format(a))
  8. b1 = a.pow(2) # 也可以a**2
  9. b2 = a**2
  10. print("b1 = {}".format(b1))
  11. print("b2 = {}".format(b2))
  12. # a = tensor([[3., 3.],
  13. # [3., 3.]])
  14. # b1 = tensor([[9., 9.],
  15. # [9., 9.]])
  16. # b2 = tensor([[9., 9.],
  17. # [9., 9.]])
  18. #

开方运算

  • pow(a, n) : a 的 n 次方
  • ** 也表示次方(可以是 2,0.5,0.25,3) 推荐
  • sqrt() 表示 square root 平方根
  • rsqrt() 表示平方根的倒数
  1. def test():
  2. a = torch.full([2, 2], 9)
  3. print("a = {}".format(a))
  4. b1 = a.sqrt() # 也可以a**(0.5)
  5. # 平方根的倒数
  6. b2 = a.rsqrt()
  7. print("b1 = {}".format(b1))
  8. print("b2 = {}".format(b2))
  9. # a = tensor([[9., 9.],
  10. # [9., 9.]])
  11. # b1 = tensor([[3., 3.],
  12. # [3., 3.]])
  13. # b2 = tensor([[0.3333, 0.3333],
  14. # [0.3333, 0.3333]])

指数与对数运算

注意log是以自然对数为底数的,以 2 为底的用log2,以 10 为底的用log10

  • exp(n) 表示:e 的 n 次方
  • log(a) 表示:ln(a)
  • log2() 、 log10()
  1. def test():
  2. a = torch.ones(2,2)
  3. print("a = {}".format(a))
  4. # 得到 2*2 矩阵的全是 e 的Tensor,相当于a的所有元素乘以e
  5. b = torch.exp(a)
  6. c = torch.log(a)
  7. print("b = {}".format(b))
  8. print("c = {}".format(c))
  9. # a = tensor([[1., 1.],
  10. # [1., 1.]])
  11. # b = tensor([[2.7183, 2.7183],
  12. # [2.7183, 2.7183]])
  13. # c = tensor([[0., 0.],
  14. # [0., 0.]])

近似值运算

近似相关 1

  • floor、ceil 向下取整、向上取整
  • round 4 舍 5 入
  • trunc、frac 裁剪
  1. def test():
  2. a = torch.tensor(3.14)
  3. b = torch.tensor(3.49)
  4. c = torch.tensor(3.5)
  5. # 取下,取上,取整数,取小数
  6. print("a.floor = {},a.ceil = {},a.trunc = {},a.frac = {}"
  7. .format(a.floor(),a.ceil(),a.trunc(),a.frac()))
  8. # 四舍五入
  9. print("b.rounc = {}, c.round = {}".format(b.round(),c.round()))
  10. # a.floor = 3.0, a.ceil = 4.0, a.trunc = 3.0, a.frac = 0.1400001049041748
  11. # b.rounc = 3.0, c.round = 4.0

裁剪运算

即对 Tensor 中的元素进行范围过滤,不符合条件的可以把它变换到范围内部(边界)上,常用于梯度裁剪(gradient clipping),即在发生梯度离散或者梯度爆炸时对梯度的处理,实际使用时可以查看梯度的(L2 范数)模来看看需不需要做处理:w.grad.norm(2)

近似相关 2 (用的更多一些)

  • gradient clipping 梯度裁剪
  • (min) 小于 min 的都变为某某值
  • (min, max) 不在这个区间的都变为某某值
  • 梯度爆炸:一般来说,当梯度达到 100 左右的时候,就已经很大了,正常在 10 左右,通过打印梯度的模来查看 w.grad.norm(2)
  • 对于 w 的限制叫做 weight clipping,对于 weight gradient clipping 称为 gradient clipping。
  1. def test():
  2. # 两行三列,切元素在0-15之间随机生成
  3. grad = torch.rand(2, 3) * 15 # 0~15随机生成
  4. print("grad = {}".format(grad))
  5. # 最大值最小值平均值
  6. print("grad.max = {}, grad.min = {}, grad.median = {}"
  7. .format(grad.max(), grad.min(), grad.median()))
  8. # 最小是10,小于10的都变成10
  9. print("grad.clamp(10) = {}".format(grad.clamp(10)))
  10. # 最小是3, 小于3的都变成3; 最大是10, 大于10的都变成10
  11. print("grad.clamp(3, 10) = {}".format(grad.clamp(3, 10))) #
  12. # grad = tensor([[7.2015, 13.5902, 3.7276],
  13. # [3.9825, 2.9701, 11.7545]])
  14. # grad.max = 13.590229034423828, grad.min = 2.9700870513916016, grad.median = 3.982494831085205
  15. # grad.clamp(10) = tensor([[10.0000, 13.5902, 10.0000],
  16. # [10.0000, 10.0000, 11.7545]])
  17. # grad.clamp(3, 10) = tensor([[7.2015, 10.0000, 3.7276],
  18. # [3.9825, 3.0000, 10.0000]])

References