PyTorch_19_网络模型的保存与读取


  • model_save.py
import torch
import torchvision

# 这个是没有经过训练的
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式一:保存模型结构+模型参数
# 这里不仅保存了网络模型的一些结构,也保存了网络模型的一些参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式二:模型参数(官方推荐)
# 获取vgg16的状态,并把它保存为字典的形式
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱
class TuDui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

tudui = TuDui()

torch.save(tudui, "tudui_method1.pth")
  • model_load.py
import torch
import torchvision

#方式一,对应保存方式一
from torch import nn
# from model_save import TuDui

model = torch.load("vgg16_method1.pth")
print(model)

#方式二,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 这里面获取的就是字典
print(torch.load("vgg16_method2.pth"))
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

# 陷阱1
# 这会的时候,还需要把这个模型的类给搬过来
# 但是现在不用写这个了:tudui = TuDui()
# 就是为了确保加载的这个网络模型就是你想要的网络模型
class TuDui(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

model = torch.load("tudui_method1.pth")

print(model)

Author: Ruimin Huang
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint polocy. If reproduced, please indicate source Ruimin Huang !
  TOC