迁移学习-预训练模型的保存与加载
1. 模型的保存与读取:
# To keep only the best parameters seen during training, snapshot them with:
#     best_model_state = deepcopy(model.state_dict())
# Save the parameters. Write into the same directory the loading code below
# reads from; joining an empty string here would save into the current
# working directory instead of model_save_dir, breaking the round trip.
model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save(model.state_dict(), model_save_path)

# Load the parameters back into a freshly constructed model.
# (If the file was saved on GPU and is loaded on a CPU-only machine,
# pass map_location='cpu' to torch.load.)
model = LeNet5()
model_save_path = os.path.join(model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
    loaded_paras = torch.load(model_save_path)
    model.load_state_dict(loaded_paras)
# Besides the model parameters you can also save the optimizer state and
# training progress (everything needed to resume training):
# model_save_path = os.path.join(model_save_dir, 'model.pt')
# torch.save({
#     'epoch': epoch,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'loss': loss,
# }, model_save_path)
# Loading such a checkpoint. NOTE: the file must have been written in the
# dict format above; a bare state_dict saved as in part 1 (same 'model.pt'
# name) has no 'model_state_dict' key and would raise KeyError here.
checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
2. 迁移学习
按照上面的方法保存和读取模型后,若迁移到的新模型与预训练模型只有部分结构相同,可以根据参数名称和形状,有选择地保留读取进来的参数:
def para_state_dict(model, model_save_dir, map_location=None):
    """Build a state dict for ``model``, filling in compatible pretrained weights.

    Starts from a deep copy of ``model``'s own state dict. If
    ``<model_save_dir>/model.pt`` exists, every entry whose key is also
    present in the saved file *and* whose tensor shape matches is replaced by
    the saved tensor. Entries missing from the file, or whose shape differs,
    keep the model's current values, so the result can always be passed to
    ``model.load_state_dict``.

    The saved file may be either a bare state dict or a checkpoint dict that
    stores the weights under ``'model_state_dict'``.

    Args:
        model: The (new) ``torch.nn.Module`` whose parameters are initialised.
        model_save_dir: Directory containing ``model.pt``.
        map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load
            GPU-saved weights on a CPU-only machine). ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        A new dict mapping parameter/buffer names to tensors.
    """
    # Deep-copy so the returned dict does not share tensors with the live model.
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if not os.path.exists(model_save_path):
        # Nothing to transfer: keep the model's own initialisation.
        return state_dict

    loaded_paras = torch.load(model_save_path, map_location=map_location)
    # A checkpoint saved together with optimizer state nests the weights;
    # without unwrapping, no key would match and nothing would be transferred.
    if isinstance(loaded_paras.get('model_state_dict'), dict):
        loaded_paras = loaded_paras['model_state_dict']

    for key in state_dict:  # iterate over the NEW model's parameter names
        pretrained = loaded_paras.get(key)
        # Transfer only tensors that exist in the file with an identical shape.
        if pretrained is not None and state_dict[key].shape == pretrained.shape:
            print("成功初始化参数:", key)
            state_dict[key] = pretrained
    return state_dict