乡下人产国偷v产偷v自拍,国产午夜片在线观看,婷婷成人亚洲综合国产麻豆,久久综合给合久久狠狠狠9

  • <output id="e9wm2"></output>
    <s id="e9wm2"><nobr id="e9wm2"><ins id="e9wm2"></ins></nobr></s>

    • 分享

      畫pytorch模型圖,以及參數(shù)計算

       LibraryPKU 2019-09-30

          剛?cè)雙ytorch的坑,代碼還沒看太懂。之前用keras用習(xí)慣了,第一次使用pytorch還有些不適應(yīng),希望廣大老司機多多指教。

          首先說說,我們?nèi)绾慰梢暬P?。在keras中就一句話,keras.summary(),或者plot_model(),就可以把模型展現(xiàn)的淋漓盡致。

      但是pytorch中好像沒有這樣一個api讓我們直觀的看到模型的樣子。但是有網(wǎng)友提供了一段代碼,可以把模型畫出來,對我來說簡直就是如有神助啊。話不多說,上代碼吧。

      1. import torch
      2. from torch.autograd import Variable
      3. import torch.nn as nn
      4. from graphviz import Digraph
      5. class CNN(nn.Module):
      6. def __init__(self):
      7. super(CNN, self).__init__()
      8. self.conv1 = nn.Sequential(
      9. nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      10. nn.ReLU(),
      11. nn.MaxPool2d(kernel_size=2)
      12. )
      13. self.conv2 = nn.Sequential(
      14. nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      15. nn.ReLU(),
      16. nn.MaxPool2d(kernel_size=2)
      17. )
      18. self.out = nn.Linear(32*7*7, 10)
      19. def forward(self, x):
      20. x = self.conv1(x)
      21. x = self.conv2(x)
      22. x = x.view(x.size(0), -1) # (batch, 32*7*7)
      23. out = self.out(x)
      24. return out
      25. def make_dot(var, params=None):
      26. """ Produces Graphviz representation of PyTorch autograd graph
      27. Blue nodes are the Variables that require grad, orange are Tensors
      28. saved for backward in torch.autograd.Function
      29. Args:
      30. var: output Variable
      31. params: dict of (name, Variable) to add names to node that
      32. require grad (TODO: make optional)
      33. """
      34. if params is not None:
      35. assert isinstance(params.values()[0], Variable)
      36. param_map = {id(v): k for k, v in params.items()}
      37. node_attr = dict(style='filled',
      38. shape='box',
      39. align='left',
      40. fontsize='12',
      41. ranksep='0.1',
      42. height='0.2')
      43. dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
      44. seen = set()
      45. def size_to_str(size):
      46. return '('+(', ').join(['%d' % v for v in size])+')'
      47. def add_nodes(var):
      48. if var not in seen:
      49. if torch.is_tensor(var):
      50. dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      51. elif hasattr(var, 'variable'):
      52. u = var.variable
      53. name = param_map[id(u)] if params is not None else ''
      54. node_name = '%s\n %s' % (name, size_to_str(u.size()))
      55. dot.node(str(id(var)), node_name, fillcolor='lightblue')
      56. else:
      57. dot.node(str(id(var)), str(type(var).__name__))
      58. seen.add(var)
      59. if hasattr(var, 'next_functions'):
      60. for u in var.next_functions:
      61. if u[0] is not None:
      62. dot.edge(str(id(u[0])), str(id(var)))
      63. add_nodes(u[0])
      64. if hasattr(var, 'saved_tensors'):
      65. for t in var.saved_tensors:
      66. dot.edge(str(id(t)), str(id(var)))
      67. add_nodes(t)
      68. add_nodes(var.grad_fn)
      69. return dot
      70. if __name__ == '__main__':
      71. net = CNN()
      72. x = Variable(torch.randn(1, 1, 28, 28))
      73. y = net(x)
      74. g = make_dot(y)
      75. g.view()
      76. params = list(net.parameters())
      77. k = 0
      78. for i in params:
      79. l = 1
      80. print("該層的結(jié)構(gòu):" + str(list(i.size())))
      81. for j in i.size():
      82. l *= j
      83. print("該層參數(shù)和:" + str(l))
      84. k = k + l
      85. print("總參數(shù)數(shù)量和:" + str(k))
          模型很簡單,代碼也很簡單。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

          大家在可視化的時候,直接復(fù)制make_dot那段代碼即可,然后需要初始化一個net,以及這個網(wǎng)絡(luò)需要的數(shù)據(jù)規(guī)模,此處就以    這段代碼為例,初始化一個模型net,準備這個模型的輸入數(shù)據(jù)x,shape為(batch,channels,height,width) 然后把數(shù)據(jù)傳入模型得到輸出結(jié)果y。傳入make_dot即可得到下圖。

      1. net = CNN()
      2. x = Variable(torch.randn(1, 1, 28, 28))
      3. y = net(x)
      4. g = make_dot(y)
      5. g.view()
       


          最后輸出該網(wǎng)絡(luò)的各種參數(shù)。

      1. 該層的結(jié)構(gòu):[16, 1, 5, 5]
      2. 該層參數(shù)和:400
      3. 該層的結(jié)構(gòu):[16]
      4. 該層參數(shù)和:16
      5. 該層的結(jié)構(gòu):[32, 16, 5, 5]
      6. 該層參數(shù)和:12800
      7. 該層的結(jié)構(gòu):[32]
      8. 該層參數(shù)和:32
      9. 該層的結(jié)構(gòu):[10, 1568]
      10. 該層參數(shù)和:15680
      11. 該層的結(jié)構(gòu):[10]
      12. 該層參數(shù)和:10
      13. 總參數(shù)數(shù)量和:28938

        本站是提供個人知識管理的網(wǎng)絡(luò)存儲空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點。請注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購買等信息,謹防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請點擊一鍵舉報。
        轉(zhuǎn)藏 分享 獻花(0

        0條評論

        發(fā)表

        請遵守用戶 評論公約

        類似文章 更多