校园春色亚洲色图_亚洲视频分类_中文字幕精品一区二区精品_麻豆一区区三区四区产品精品蜜桃

主頁 > 知識庫 > 解決Pytorch修改預訓練模型時遇到key不匹配的情況

解決Pytorch修改預訓練模型時遇到key不匹配的情況

熱門標簽:商家地圖標注海報 騰訊地圖標注沒法顯示 孝感營銷電話機器人效果怎么樣 海外網吧地圖標注注冊 打電話機器人營銷 ai電銷機器人的優勢 地圖標注自己和別人標注區別 聊城語音外呼系統 南陽打電話機器人

一、Pytorch修改預訓練模型時遇到key不匹配

最近想著修改網絡的預訓練模型vgg.pth,但是發現當我加載預訓練模型權重到新建的模型并保存之后。

在我使用新賦值的網絡模型時出現了key不匹配的問題

#加載后保存(未修改網絡)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 將新保存的網絡代替之前的預訓練模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet為ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights) 

此時會如下出錯誤:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

說明之前的預訓練模型 key參數為"0.weight", “0.bias”,但是經過加載保存之后變為了"vgg.0.weight", “vgg.0.bias”

我認為是因為本身的模型定義文件里self.vgg = nn.ModuleList(base)這一句。

現在的問題是因為自己定義保存的模型key參數多了一個前綴。

可以通過如下語句進行修改,并加載

from collections import OrderedDict   #導入此模塊
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面幾位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict) 

此時就不會再出錯了。

參考了這個篇。修改一下就可以應用到自己的模型啦。

//www.jb51.net/article/214214.htm

二、pytorch加載預訓練模型遇到的問題:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加載resnet預訓練模型時,遇到的一個問題,在此記錄一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其實是使用的版本的問題,pytorch0.4.1之后在BN層加入了track_running_stats這個參數,

這個參數的作用如下:

訓練時用來統計訓練時的forward過的min-batch數目,每經過一個min-batch, track_running_stats+=1

如果沒有指定momentum, 則使用1/num_batches_tracked 作為因數來計算均值和方差(running mean and variance).

其實,這個參數沒啥用.但因為官方提供的預訓練模型是pytorch0.3版本訓練出來的,因此沒有這個參數.

所以,只要過濾一下預訓練權重字典中的關鍵字即可,‘num_batches_tracked'.代碼例子,如下.

有問題的代碼:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

對'num_batches_tracked進行過濾:

   def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • Pytorch通過保存為ONNX模型轉TensorRT5的實現
  • pytorch_pretrained_bert如何將tensorflow模型轉化為pytorch模型
  • pytorch模型的保存和加載、checkpoint操作
  • PyTorch 如何檢查模型梯度是否可導
  • pytorch 預訓練模型讀取修改相關參數的填坑問題
  • PyTorch模型轉TensorRT是怎么實現的?

標簽:楊凌 迪慶 撫州 聊城 六盤水 南寧 揚州 牡丹江

巨人網絡通訊聲明:本文標題《解決Pytorch修改預訓練模型時遇到key不匹配的情況》,本文關鍵詞  解決,Pytorch,修改,預,訓練,;如發現本文內容存在版權問題,煩請提供相關信息告之我們,我們將及時溝通與處理。本站內容系統采集于網絡,涉及言論、版權與本站無關。
  • 相關文章
  • 下面列出與本文章《解決Pytorch修改預訓練模型時遇到key不匹配的情況》相關的同類信息!
  • 本頁收集關于解決Pytorch修改預訓練模型時遇到key不匹配的情況的相關信息資訊供網民參考!
  • 推薦文章
    主站蜘蛛池模板: 满城县| 贵定县| 南安市| 新乡市| 裕民县| 平和县| 澄城县| 滁州市| 大兴区| 宁河县| 定南县| 合作市| 社会| 容城县| 射阳县| 故城县| 柳林县| 合江县| 白玉县| 昌宁县| 东方市| 灵武市| 凉城县| 乌兰县| 石门县| 湘西| 永定县| 重庆市| 正镶白旗| 天镇县| 芜湖市| 武胜县| 松溪县| 通许县| 宁化县| 云和县| 冀州市| 宜阳县| 尖扎县| 新源县| 射阳县|