迁移学习-预训练模型的保存与加载

1,模型保存和读取:

# 如果要保存最好的参数,使用: best_model_state = deepcopy(model.state_dict())
model_save_path = os.path.join('', 'model.pt')
torch.save(model.state_dict(), model_save_path)

# 模型参数读取
model = LeNet5()
model_save_path = os.path.join(model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
    loaded_paras = torch.load(model_save_path)
    model.load_state_dict(loaded_paras)
    
# 也可以保存优化器等:
# model_save_path = os.path.join(model_save_dir, 'model.pt')
# torch.save({
# 'epoch': epoch,
# 'model_state_dict': model.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'loss': loss,
# }, model_save_path)

# 读取:
checkpoint = torch.load(model_save_path) 
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
epoch = checkpoint['epoch'] 5 loss = checkpoint['loss']

2,迁移学习

按照上面的方法对模型进行保存和读取,当迁移的模型部分不同时,可以根据参数名称和大小,选择性的保留读取进来的参数:

def para_state_dict(model, model_save_dir): 
state_dict = deepcopy(model.state_dict()) 
model_save_path = os.path.join(model_save_dir, 'model.pt') 
if os.path.exists(model_save_path): 
    loaded_paras = torch.load(model_save_path) 
    for key in state_dict: # 在新的网络模型中遍历对应参数 
        if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size(): 
        print("成功初始化参数:", key) 
        state_dict[key] = loaded_paras[key] 
return state_dict
 
全部评论

相关推荐

湫湫湫不会java:先投着吧,大概率找不到实习,没实习的时候再加个项目,然后把个人评价和荣誉奖项删了,赶紧成为八股战神吧,没实习没学历,秋招机会估计不多,把握机会。或者说秋招时间去冲实习,春招冲offer,但是压力会比较大
点赞 评论 收藏
分享
牛客90772103...:格林美(无锡)
点赞 评论 收藏
分享
不愿透露姓名的神秘牛友
07-04 14:23
steelhead:你回的有问题,让人感觉你就是来学习的
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务