import torchvision
# train_data = torchvision.datasets.ImageNet("./dataset", split='train', download=True,
# transform=torchvision.transforms.ToTensor)
# 这个使用的是随机参数
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这个是加载了训练好的参数
vgg16_true = torchvision.models.vgg16(pretrained=True)
print('ok')
# 可以看出,这是分成了1000多个类的
# print(vgg16_true)
train_data = torchvision.datasets.CIFAR10("./dataset", train=True, download=True,
transform=torchvision.transforms.ToTensor)
# 方式一:把1000改成10
# 方式二:再加一层,让输入是1000,输出是10
# 通过现有的网络,改变他的结构,满足我们的要求
# 很多框架都会把vgg16当成一个前置的网络
# 添加的方式
# vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
# vgg16_true.features.add_module('add_linear', nn.Linear(1000, 10))
# vgg16_true.avgpool.add_module('add_linear', nn.Linear(1000, 10))
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
# 修改的方式
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
Previous

2021-11-19
Next

2021-11-17