# Evaluation script: test-set metrics, confusion matrix, and report for a BERT classifier.
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from MyData import MyDataset
from net import Model
# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer loaded from a local snapshot of bert-base-chinese.
# NOTE(review): the path is a rootless Windows-style raw string — it resolves
# relative to the current drive root; confirm it is valid on the target machine.
token = BertTokenizer.from_pretrained( r"\models\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
def collate_fn(data):
    """Collate (text, label) pairs into BERT-ready tensor batches.

    :param data: list of (sentence, label) tuples from the dataset
    :return: (input_ids, attention_mask, token_type_ids, label) tensors
    """
    sents = [i[0] for i in data]
    label = [i[1] for i in data]
    # Tokenize the whole batch at once; pad/truncate to a fixed length of 512
    # so every batch has a uniform shape.
    data = token.batch_encode_plus(
        batch_text_or_text_pairs=sents,
        truncation=True,
        max_length=512,
        padding="max_length",
        return_tensors="pt",
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    label = torch.LongTensor(label)
    return input_ids, attention_mask, token_type_ids, label


def evaluate_model(model, test_loader, device):
    """Evaluate the model's performance on the test set.

    :param model: model under evaluation
    :param test_loader: test-set DataLoader
    :param device: compute device
    :return: dict of evaluation metrics
    """
    model.eval()
    all_preds, all_labels = [], []
    # Inference only: disable autograd for the whole loop to save time/memory.
    with torch.no_grad():
        for input_ids, attention_mask, token_type_ids, labels in test_loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)
            outputs = model(input_ids, attention_mask, token_type_ids)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    # zero_division=0 matches sklearn's default result for classes with no
    # predicted samples but suppresses the UndefinedMetricWarning.
    metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision_macro': precision_score(all_labels, all_preds, average='macro', zero_division=0),
        'recall_macro': recall_score(all_labels, all_preds, average='macro', zero_division=0),
        'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0),
        'precision_weighted': precision_score(all_labels, all_preds, average='weighted', zero_division=0),
        'recall_weighted': recall_score(all_labels, all_preds, average='weighted', zero_division=0),
        'f1_weighted': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
        'confusion_matrix': confusion_matrix(all_labels, all_preds),
        'classification_report': classification_report(all_labels, all_preds, digits=4)
    }
    return metrics


def plot_confusion_matrix(cm, class_names, save_path=None):
    """Plot the confusion matrix and optionally save it to disk.

    :param cm: confusion matrix (2-D integer array)
    :param class_names: list of class display names
    :param save_path: output image path (optional)
    """
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.title('混淆矩阵')
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"混淆矩阵已保存至: {save_path}")
    plt.show()


def save_metrics_to_file(metrics, save_path):
    """Write the evaluation metrics to a UTF-8 text report.

    :param metrics: evaluation-metrics dict (as produced by evaluate_model)
    :param save_path: output file path
    """
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write("模型评估报告\n")
        f.write("=" * 50 + "\n")
        f.write(f"准确率 (Accuracy): {metrics['accuracy']:.4f}\n\n")
        f.write("宏平均指标 (Macro-average):\n")
        f.write(f"  精确率 (Precision): {metrics['precision_macro']:.4f}\n")
        f.write(f"  召回率 (Recall): {metrics['recall_macro']:.4f}\n")
        f.write(f"  F1分数 (F1 Score): {metrics['f1_macro']:.4f}\n\n")
        f.write("加权平均指标 (Weighted-average):\n")
        f.write(f"  精确率 (Precision): {metrics['precision_weighted']:.4f}\n")
        f.write(f"  召回率 (Recall): {metrics['recall_weighted']:.4f}\n")
        f.write(f"  F1分数 (F1 Score): {metrics['f1_weighted']:.4f}\n\n")
        f.write("分类报告 (Classification Report):\n")
        f.write(metrics['classification_report'])
        f.write("\n\n混淆矩阵 (Confusion Matrix):\n")
        # np.savetxt accepts the open file handle; one row per line.
        np.savetxt(f, metrics['confusion_matrix'], fmt='%d')
    print(f"评估报告已保存至: {save_path}")


if __name__ == '__main__':
    test_dataset = MyDataset("test")
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=100,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn
    )
    print(f"使用设备: {DEVICE}")
    model = Model().to(DEVICE)
    model_path = "params/best_bert.pth"
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"模型参数文件不存在: {model_path}")
    # map_location so weights saved on a CUDA machine still load on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    metrics = evaluate_model(model, test_loader, DEVICE)
    print("\n" + "=" * 50)
    print(f"准确率 (Accuracy): {metrics['accuracy']:.4f}")
    print("\n宏平均指标 (Macro-average):")
    print(f"  精确率 (Precision): {metrics['precision_macro']:.4f}")
    print(f"  召回率 (Recall): {metrics['recall_macro']:.4f}")
    print(f"  F1分数 (F1 Score): {metrics['f1_macro']:.4f}")
    print("\n加权平均指标 (Weighted-average):")
    print(f"  精确率 (Precision): {metrics['precision_weighted']:.4f}")
    print(f"  召回率 (Recall): {metrics['recall_weighted']:.4f}")
    print(f"  F1分数 (F1 Score): {metrics['f1_weighted']:.4f}")
    print("\n分类报告 (Classification Report):")
    print(metrics['classification_report'])
    class_names = ["类别0", "类别1"]
    plot_confusion_matrix(metrics['confusion_matrix'], class_names, "confusion_matrix.png")
    save_metrics_to_file(metrics, "evaluation_report.txt")
    print("评估完成!")