PyTorch_18_现有模型的使用和修改


import torchvision

# train_data = torchvision.datasets.ImageNet("./dataset", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor)

# 这个使用的是随机参数
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这个是加载了训练好的参数
vgg16_true = torchvision.models.vgg16(pretrained=True)
print('ok')

# 可以看出,这是分成了1000多个类的
# print(vgg16_true)


train_data = torchvision.datasets.CIFAR10("./dataset", train=True, download=True,
                                          transform=torchvision.transforms.ToTensor)
# 方式一:把1000改成10
# 方式二:再加一层,让输入是1000,输出是10
# 通过现有的网络,改变他的结构,满足我们的要求
# 很多框架都会把vgg16当成一个前置的网络

# 添加的方式
# vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
# vgg16_true.features.add_module('add_linear', nn.Linear(1000, 10))
# vgg16_true.avgpool.add_module('add_linear', nn.Linear(1000, 10))
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

# 修改的方式
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

Author: Ruimin Huang
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint polocy. If reproduced, please indicate source Ruimin Huang !
  TOC