利用Pytorch复现AlexNet网络

利用Pytorch复现AlexNet网络

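根据吴恩达老师在深度学习课程中的讲解,AlexNet网络的基本流程为(以 227×227×3 的输入图像为例):输入 227×227×3 → 卷积层1(96 个 11×11 卷积核,步长 4)→ 55×55×96 → 最大池化(3×3,步长 2)→ 27×27×96 → 卷积层2(256 个 5×5 卷积核,padding 为 2)→ 27×27×256 → 最大池化(3×3,步长 2)→ 13×13×256 → 卷积层3(384 个 3×3 卷积核,padding 为 1)→ 13×13×384 → 卷积层4(384 个 3×3 卷积核,padding 为 1)→ 13×13×384 → 卷积层5(256 个 3×3 卷积核,padding 为 1)→ 13×13×256 → 最大池化(3×3,步长 2)→ 6×6×256 → 展平为 9216 维向量 → 全连接层(4096)→ 全连接层(4096)→ Softmax 输出(1000 类)。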


代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class AlexNet(nn.Module):
    """AlexNet-style convolutional image classifier.

    Five convolutional layers (the 1st, 2nd and 5th followed by max pooling)
    feed three fully connected layers. ``fc1`` expects a flattened
    6 * 6 * 256 feature map, which this conv/pool stack produces for
    3 x 227 x 227 inputs.

    Args:
        num_classes: Number of output classes (width of the last layer).
    """

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        # Feature extractor. Attribute names are kept stable so that
        # ``print(net)`` output and ``state_dict`` keys do not change.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        # Classifier head.
        self.fc1 = nn.Linear(6 * 6 * 256, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Run a batch of images through the network.

        Args:
            x: Input batch; the layer sizes require 3 input channels and a
                spatial size that yields a 6 x 6 map after ``pool3``
                (e.g. 227 x 227).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding softmax
            probabilities over the class dimension.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool3(F.relu(self.conv5(x)))
        # Keep the batch dimension fixed while flattening. With ``view(-1, ...)``
        # a wrongly sized input would be silently folded into extra batch rows;
        # this way it fails loudly with a shape error at ``fc1``.
        x = x.view(x.size(0), -1)
        # The functional dropout defaults to training=True, which would keep
        # dropping activations after ``model.eval()``; tie it to module mode.
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc2(x))
        # An explicit ``dim`` avoids the deprecated implicit-dimension softmax.
        x = F.softmax(self.fc3(x), dim=1)
        return x
        
# Build the network with 1000 output classes and print its layer summary.
net = AlexNet(num_classes=1000)
print(net)

输出:

AlexNet(
  (conv1): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9216, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=1000, bias=True)
)

全部评论

相关推荐

威猛的小饼干正在背八股:挂到根本不想整理
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务