Has anyone used the transformers library? I'd like to ask about a problem.
Here is the code:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import logging

logging.basicConfig(level=logging.INFO)

# Load a locally saved BERT checkpoint
bert_model_path = "../pretrain_model/bert_base_cased"
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
model = BertForSequenceClassification.from_pretrained(bert_model_path)

classes = ["not paraphrase", "is paraphrase"]

sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, return_tensors="pt")

paraphrase_classification_logits = model(**paraphrase)[0]
not_paraphrase_classification_logits = model(**not_paraphrase)[0]

paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]

print("Should be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(paraphrase_results[i] * 100)}%")

print("\nShould not be paraphrase")
for i in range(len(classes)):
    print(f"{classes[i]}: {round(not_paraphrase_results[i] * 100)}%")

Expected output:
Should be paraphrase
not paraphrase: 10%
is paraphrase: 90%

Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%

Actual output on the server:
weight.t() size: torch.Size([768, 3072]) input size: torch.Size([1, 21, 3072])
weight.t() size: torch.Size([3072, 768]) input size: torch.Size([1, 21, 768])
weight.t() size: torch.Size([768, 768]) input size: torch.Size([1, 21, 768])
weight.t() size: torch.Size([768, 768]) input size: torch.Size([1, 21, 768])
weight.t() size: torch.Size([768, 768]) input size: torch.Size([1, 21, 768])
weight.t() size: torch.Size([768, 768]) input size: torch.Size([1, 21, 768])
weight.t() size: torch.Size([768, 3072]) input size: torch.Size([1, 21, 3072])
weight.t() size: torch.Size([3072, 768])
# ...many more lines in the same format...

Should be paraphrase
not paraphrase: 10%
is paraphrase: 90%

Should not be paraphrase
not paraphrase: 94%
is paraphrase: 6%

Where are those intermediate lines coming from? I never print anything like that myself, and I get the same output with several different models. Has anyone run into this?
I went through the transformers logging output, but there is no log message in the "weight.t() size" format. I also searched the GitHub issues and found no one reporting this problem.
Googling it turns up nothing either.
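Prints like these usually come from a plain print() statement someone added to a locally installed copy of the library (or of torch), not from the logging framework, so they won't show up in the logging config or on GitHub. One way to check is to search the installed package's source files for the literal string. Below is a stdlib-only sketch; find_print_in_package is a hypothetical helper name, and the json example is just a stand-in to show it works:

```python
import importlib.util
import os

def find_print_in_package(pkg_name: str, needle: str):
    """Locate an installed package's directory and search its .py sources
    for lines containing `needle`. Returns (path, line number, line) hits."""
    spec = importlib.util.find_spec(pkg_name)
    pkg_dir = os.path.dirname(spec.origin)
    hits = []
    for root, _dirs, files in os.walk(pkg_dir):
        for name in files:
            if not name.endswith(".py"):
                continue
            path = os.path.join(root, name)
            with open(path, encoding="utf-8", errors="ignore") as fh:
                for lineno, line in enumerate(fh, 1):
                    if needle in line:
                        hits.append((path, lineno, line.strip()))
    return hits

# Demo on a stdlib package; on the server you would instead run, e.g.:
#   find_print_in_package("transformers", "weight.t() size")
#   find_print_in_package("torch", "weight.t() size")
print(find_print_in_package("json", "JSONDecodeError"))
```

If a hit turns up in site-packages, reinstalling the package cleanly (pip install --force-reinstall) should remove the stray prints.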
Please help a poor student: if I don't get results out before the semester starts, my graduation will be delayed.