有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。
4.5.1 存取Tensor
我们可以直接使用save函数和load函数分别存储和读取Tensor。
实质上是对pickle模块的一层封装
import torchfrom torch import nn# 存储x = torch.ones(3)torch.save(x, "x.pt")# 读取y = torch.load("x.pt")print(y)# Tensor 列表的存取x = torch.zeros(3)y = torch.ones(4)torch.save([x, y], "xy_list.pt")xy_list = torch.load("xy_list.pt")print(xy_list)# 存储并读取一个从字符串映射到 Tensor 的字典torch.save({'x': x, 'y': y}, "xy_dict.pt")xy_dict = torch.load("xy_dict.pt")print(xy_dict)
运行结果
tensor([1., 1., 1.])[tensor([0., 0., 0.]), tensor([1., 1., 1., 1.])]{'x': tensor([0., 0., 0.]), 'y': tensor([1., 1., 1., 1.])}
4.5.2 存取模型
4.5.2.1 state_dict
在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。
lass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)# state_dict是一个从参数名称隐射到参数 Tesnor 的字典对象。net = MLP()print(net.state_dict())for i in net.state_dict().items():print(i)
运行结果
OrderedDict([('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],[ 0.5385, 0.5153, -0.3220]])), ('hidden.bias', tensor([-0.4603, -0.4822])), ('output.weight', tensor([[-0.5111, -0.7002]])), ('output.bias', tensor([-0.3096]))])('hidden.weight', tensor([[-0.0567, 0.1704, 0.3017],[ 0.5385, 0.5153, -0.3220]]))('hidden.bias', tensor([-0.4603, -0.4822]))('output.weight', tensor([[-0.5111, -0.7002]]))('output.bias', tensor([-0.3096]))
只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)print(optimizer.state_dict())
运行结果
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140603236014432, 140603236015072, 140602887169280, 140602887160128]}]}
4.5.2.2 两种存取模型的方式
存取模型有两种方式:仅存取模型参数和存取整个模型。前者是推荐用法。
# 仅存取模型参数 state_dict, 推荐torch.save(net.state_dict(), "state_dict.pt")model = MLP()model.load_state_dict(torch.load("state_dict.pt"))print(model)# 存取整个模型, 即结构 + 参数torch.save(net, "whole_model.pt")model = torch.load("whole_model.pt")print(model)
运行结果
MLP((hidden): Linear(in_features=3, out_features=2, bias=True)(act): ReLU()(output): Linear(in_features=2, out_features=1, bias=True))MLP((hidden): Linear(in_features=3, out_features=2, bias=True)(act): ReLU()(output): Linear(in_features=2, out_features=1, bias=True))/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading."type " + obj.__name__ + ". It won't be checked "/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading."type " + obj.__name__ + ". It won't be checked "/home/luzhan/anaconda3/lib/python3.7/site-packages/torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type ReLU. It won't be checked for correctness upon loading."type " + obj.__name__ + ". It won't be checked "
4.5.2.3 简单实例
# 简单实践x = torch.randn(2, 3)y = net(x)PATH = "./test.pt"torch.save(net.state_dict(), PATH)net2 = MLP()net2.load_state_dict(torch.load(PATH))y2 = net2(x)print(y2 == y)
因为模型参数相等,因此使用同一种模型实例化后的计算结果也是相等的。
tensor([[True],[True]])
4.5.3 其他场景
例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等,使用的时候可以参考官方文档。
