有人用过transformer框架么,请教一个问题

代码如下:
from transformers import *
import torch 
import logging
logging.basicConfig(level=logging.INFO)
bert_model_path = "../pretrain_model/bert_base_cased"
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
model = BertForSequenceClassification.from_pretrained(bert_model_path)
classes = ["not paraphrase", "is paraphrase"]
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")
paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]
paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]
print("Should be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")
print("\nShould not be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")
预期输出格式:
Should be paraphrase not paraphrase: 10% is paraphrase: 90% Should not be paraphrase not paraphrase: 94% is paraphrase: 6% 
在服务器上输出:
weight.t() size:  torch.Size([768, 3072])
input size:  torch.Size([1, 21, 3072])
weight.t() size:  torch.Size([3072, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 768])
input size:  torch.Size([1, 21, 768])
weight.t() size:  torch.Size([768, 3072])
input size:  torch.Size([1, 21, 3072])
weight.t() size:  torch.Size([3072, 768])
#还有很多行像上面一样格式数据输出 Should be paraphrase not paraphrase: 10%  is paraphrase: 90%  Should not be paraphrase  not paraphrase: 94%  is paraphrase: 6% 
求问中间输出是怎么回事,我并没有主动打印上面的信息啊,我换了好几个模型都是这样子,请问有人碰到过这样的情况么。
我去看transformer的loging信息,但是并没有发现有weight.t()这样的信息打印。去github上看issue也没发现有人碰到这样的问题。
google直接搜不到
###救救孩子吧,开学不出成果会被延毕啊啊啊啊啊。



#算法题目求助#
全部评论
不是框架的问题,一点点debug吧,要有耐心
点赞
送花
回复 分享
发布于 2020-07-01 19:51
我以为我发了个帖子。。。
点赞
送花
回复 分享
发布于 2020-07-02 20:02
元戎启行
校招火热招聘中
官网直投
进到model = BertForSequenceClassification.from_pretrained(bert_model_path) 这里看看有没有什么打印的信息
点赞
送花
回复 分享
发布于 2020-07-02 20:08
是**的cuda版本的锅 更新到 cuda10就好了
点赞
送花
回复 分享
发布于 2020-07-06 14:28

相关推荐

打开招聘网站会发现,算法岗种类很多。🌟 有计算机视觉,NLP,搜广推,风控,运筹,数据挖掘,机器学习,通信等等。那这些算法岗有什么区别呢?千呼万唤使出来,今天来分享一下搜广推。搜广推可以说是支撑了互联网业务逻辑的核心,即提升流量分发效率。重要性不言而喻。搜广推由于和业务十分接近,导致各家大厂需求很多。然而,搜广推由于多年发展,在成熟业务里已经很难取得提升,各家大厂都在寻找新的范式/新的问题/新的指标来提升算法效果。因此,搜广推的工作内容已经进入了严重内卷的阶段。那搜广推还值得应届生进入吗?搜广推仍然是应届生进入算法岗的首选方向之一。因为,搜广推和业务强绑定,并且是互联网赚钱的核心,需求会存在很长一段时间,甚至它的保质期约等于公司业务的发展周期。此外,搜广推和业务场景强相关,日积月累都是经验,不但是技术经验,更是产品经验。越往后发展,产品经验越重要,很多产品经理大佬也都是搜广推算法转行的。搜广推我比较看好的方向是用户价值建模,多场景建模以及多任务建模。更好的挖掘现有用户价值,提升留存,增强场景渗透率,增加更多的长期指标都是目前的趋势,而上述方向可以很好的提升这些目标。总结,搜广推和互联网业务逻辑强绑定,如果互联网运行模式不变,搜广推就会有需求。 #校招过来人的经验分享#
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务