Detailed Usage of addmm() and addmm_() in PyTorch

I. Function Explanation

The definition lives in torch/_C/_VariableFunctions.py (quoted in full below); it implements the following formula:

out = β × mat + α × (mat1 @ mat2)

In other words, the function takes five parameters: every element of mat is multiplied by beta, mat1 and mat2 are matrix-multiplied (rows of the left matrix times columns of the right), the product is scaled by alpha, and the two results are added together. If that still sounds abstract, the code below should make it concrete.
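First, a quick numerical check of the formula (a minimal sketch; the shapes here are arbitrary):

import torch

M = torch.randn(2, 4)
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 4)
beta, alpha = 0.5, 2.0

out = torch.addmm(M, mat1, mat2, beta=beta, alpha=alpha)  # beta * M + alpha * (mat1 @ mat2)
manual = beta * M + alpha * (mat1 @ mat2)                 # the same thing computed by hand
print(torch.allclose(out, manual))                        # True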


  
def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__
    """
    addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor

    Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
    The matrix :attr:`mat` is added to the final result.

    If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
    :math:`(m \times p)` tensor, then :attr:`mat` must be
    :ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensor
    and :attr:`out` will be a :math:`(n \times p)` tensor.

    :attr:`alpha` and :attr:`beta` are scaling factors on the matrix-matrix product between
    :attr:`mat1` and :attr:`mat2` and the added matrix :attr:`mat` respectively.

    .. math::
        out = \beta\ mat + \alpha\ (mat1 \mathbin{@} mat2)

    For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
    :attr:`alpha` must be real numbers, otherwise they should be integers.

    Args:
        beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
        mat (Tensor): matrix to be added
        alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
        mat1 (Tensor): the first matrix to be multiplied
        mat2 (Tensor): the second matrix to be multiplied
        out (Tensor, optional): the output tensor

    Example::

        >>> M = torch.randn(2, 3)
        >>> mat1 = torch.randn(2, 3)
        >>> mat2 = torch.randn(3, 3)
        >>> torch.addmm(M, mat1, mat2)
        tensor([[-4.8716,  1.4671, -1.3746],
                [ 0.7573, -3.9555, -2.8681]])
    """
    pass
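As the docstring notes, mat only has to be broadcastable with the (n × p) product. The classic use of this is the linear-layer pattern y = bias + x @ weight, where a 1-D bias is broadcast across the batch dimension. A sketch (the names x, weight, and bias are illustrative):

import torch

x = torch.randn(4, 3)        # batch of 4 samples, 3 features each
weight = torch.randn(3, 5)   # maps 3 features to 5 outputs
bias = torch.randn(5)        # 1-D, broadcast over the batch dimension

y = torch.addmm(bias, x, weight)             # bias + x @ weight, shape (4, 5)
print(torch.allclose(y, bias + x @ weight))  # True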

II. Code Example

1. Here is the complete program. Feel free to copy, paste, and run it first; each part is explained step by step below.


  
  1. """
  2. @author:nickhuang1996
  3. """
  4. import torch
  5. rectangle_height = 3
  6. rectangle_width = 3
  7. inputs = torch.randn(rectangle_height, rectangle_width)
  8. for i in range(rectangle_height):
  9. for j in range(rectangle_width):
  10. inputs[i] = i * torch.ones(rectangle_width)
  11. '''
  12. inputs and its transpose
  13. -->inputs = tensor([[0., 0., 0.],
  14. [1., 1., 1.],
  15. [2., 2., 2.]])
  16. -->inputs_t = tensor([[0., 1., 2.],
  17. [0., 1., 2.],
  18. [0., 1., 2.]])
  19. '''
  20. print("inputs:\n", inputs)
  21. inputs_t = inputs.t()
  22. print("inputs_t:\n", inputs_t)
  23. '''
  24. inputs_t @ inputs_t [[0., 1., 2.], [[0., 1., 2.], [[0., 3., 6.]
  25. = [0., 1., 2.], @ [0., 1., 2.], = [0., 3., 6.]
  26. [0., 1., 2.]] [0., 1., 2.]] [0., 3., 6.]]
  27. '''
  28. '''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
  29. a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
  30. b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
  31. c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
  32. d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
  33. '''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
  34. e = torch.addmm(inputs, inputs_t, inputs_t)
  35. f = inputs.addmm(inputs_t, inputs_t)
  36. '''1 * inputs + 1 * (inputs_t @ inputs_t)'''
  37. g = inputs.addmm(1, inputs_t, inputs_t)
  38. '''2 * inputs + 1 * (inputs_t @ inputs_t)'''
  39. g2 = inputs.addmm(2, inputs_t, inputs_t)
  40. '''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
  41. h = inputs.addmm(1, 1, inputs_t, inputs_t)
  42. '''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
  43. h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
  44. '''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
  45. h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
  46. print("a:\n", a)
  47. print("b:\n", b)
  48. print("c:\n", c)
  49. print("d:\n", d)
  50. print("e:\n", e)
  51. print("f:\n", f)
  52. print("g:\n", g)
  53. print("g2:\n", g2)
  54. print("h:\n", h)
  55. print("h12:\n", h12)
  56. print("h21:\n", h21)
  57. print("inputs:\n", inputs)
  58. '''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
  59. '''
  60. inputs @ inputs_t [[0., 0., 0.], [[0., 1., 2.], [[0., 0., 0.]
  61. = [1., 1., 1.], @ [0., 1., 2.], = [0., 3., 6.]
  62. [2., 2., 2.]] [0., 1., 2.]] [0., 6., 12.]]
  63. '''
  64. inputs.addmm_(1, -2, inputs, inputs_t) # In-place
  65. print("inputs:\n", inputs)

2. In the code above:

inputs is a 3×3 matrix:


  
tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])

inputs_t, the transpose of inputs, is also a 3×3 matrix:


  
tensor([[0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.]])
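One detail worth knowing: inputs.t() does not copy any data. It returns a view that shares storage with inputs, so a change to either tensor is visible through the other (a small sketch on a throwaway tensor):

import torch

m = torch.zeros(2, 2)
m_t = m.t()       # a view of m; no copy is made
m[0, 1] = 7.0
print(m_t[1, 0])  # tensor(7.): the change shows through the transpose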

inputs_t @ inputs_t is:


  
'''
inputs_t @ inputs_t   [[0., 1., 2.],     [[0., 1., 2.],     [[0., 3., 6.],
                  =    [0., 1., 2.],  @   [0., 1., 2.],  =   [0., 3., 6.],
                       [0., 1., 2.]]      [0., 1., 2.]]      [0., 3., 6.]]
'''

3. a, b, c, and d show the fully spelled-out form, where every argument is labeled and passed explicitly. Note that the input argument can equivalently be moved in front of the call as the tensor the method is invoked on, i.e.

torch.addmm(inputs, mat1, mat2) is equivalent to inputs.addmm(mat1, mat2)

The formula computed is:

1 × inputs + 1 × (inputs_t @ inputs_t)


  
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)

  
a:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
b:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
c:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
d:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
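A quick check (reusing a, b, c, and d from the program above) confirms that all four spellings produce identical tensors:

print(torch.equal(a, b) and torch.equal(b, c) and torch.equal(c, d))  # prints True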

4. The next example better illustrates that the input argument's position is flexible; here beta and alpha are omitted and take their defaults:

The formula computed is:

1 × inputs + 1 × (inputs_t @ inputs_t)


  
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)

  
e:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
f:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])

5. Adding one extra positional argument in front of mat1 actually supplies beta:

The formula computed is:

g  = 1 × inputs + 1 × (inputs_t @ inputs_t)

g2 = 2 × inputs + 1 × (inputs_t @ inputs_t)


  
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)   # positional beta=1
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)  # positional beta=2

  
g:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g2:
 tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])

6. Adding one more positional argument supplies alpha as well (the positional order is beta first, then alpha):

The formula computed is:

h   = 1 × inputs + 1 × (inputs_t @ inputs_t)

h12 = 1 × inputs + 2 × (inputs_t @ inputs_t)

h21 = 2 × inputs + 1 × (inputs_t @ inputs_t)


  
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)    # positional beta=1, alpha=1
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)  # positional beta=1, alpha=2
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)  # positional beta=2, alpha=1

  
h:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
h12:
 tensor([[ 0.,  6., 12.],
        [ 1.,  7., 13.],
        [ 2.,  8., 14.]])
h21:
 tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])
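A caveat: this bare positional form for beta and alpha comes from older PyTorch releases; in more recent versions (around 1.5 and later, to the best of my knowledge) it is deprecated in favor of keyword arguments. The keyword equivalents of g2, h12, and h21 would be:

g2_kw  = inputs.addmm(inputs_t, inputs_t, beta=2)           # 2 * inputs + 1 * (inputs_t @ inputs_t)
h12_kw = inputs.addmm(inputs_t, inputs_t, beta=1, alpha=2)  # 1 * inputs + 2 * (inputs_t @ inputs_t)
h21_kw = inputs.addmm(inputs_t, inputs_t, beta=2, alpha=1)  # 2 * inputs + 1 * (inputs_t @ inputs_t)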

7. Note that none of the steps above modify inputs, which is still:


  
inputs:
 tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])

8. addmm_() computes the same thing as addmm(); the difference is that addmm_() is an in-place operation, that is, it modifies the original tensor directly, as if the result were assigned back to the original variable. For example:

After the call, inputs already holds the updated values; there is no need to write result = inputs.addmm_(...), because inputs itself is the modified tensor.

inputs @ inputs_t is:


  
'''
inputs @ inputs_t   [[0., 0., 0.],     [[0., 1., 2.],     [[0., 0.,  0.],
                =    [1., 1., 1.],  @   [0., 1., 2.],  =   [0., 3.,  6.],
                     [2., 2., 2.]]      [0., 1., 2.]]      [0., 6., 12.]]
'''

The formula computed is:

inputs = 1 × inputs - 2 × (inputs @ inputs_t)


  
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t)  # in-place

  
inputs:
 tensor([[  0.,   0.,   0.],
        [  1.,  -5., -11.],
        [  2., -10., -22.]])
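The trailing underscore follows PyTorch's general naming convention: a _ suffix marks the in-place variant of an operation, and the call also returns the mutated tensor itself. If the original values are still needed afterward, clone the tensor first (a sketch):

backup = inputs.clone()                 # an independent copy, untouched by the in-place op
inputs.addmm_(1, -2, inputs, inputs_t)  # mutates inputs and returns it
# inputs now holds the new values; backup still holds the old ones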

III. Full Program Output


  
inputs:
 tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])
inputs_t:
 tensor([[0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.]])
a:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
b:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
c:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
d:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
e:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
f:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g2:
 tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])
h:
 tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
h12:
 tensor([[ 0.,  6., 12.],
        [ 1.,  7., 13.],
        [ 2.,  8., 14.]])
h21:
 tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])
inputs:
 tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])
inputs:
 tensor([[  0.,   0.,   0.],
        [  1.,  -5., -11.],
        [  2., -10., -22.]])

Source: nickhuang1996.blog.csdn.net. Author: 悲恋花丶无心之人. Copyright belongs to the original author; please contact the author before reposting.

Original link: nickhuang1996.blog.csdn.net/article/details/90638449

(End)