"""
|
|
评估工具模块,用于评估模型性能
|
|
"""
|
|
|
|
import numpy as np
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
|
from sklearn.metrics import classification_report, confusion_matrix
|
|
import pandas as pd
|
|
|
|
def evaluate_model(y_true, y_pred, class_names=None):
    """
    Evaluate model performance.

    Args:
        y_true: ground-truth labels (one-hot encoded or class indices)
        y_pred: predicted probabilities (or class indices)
        class_names: optional list of class names

    Returns:
        Dictionary of evaluation results.
    """
    # Convert one-hot labels / probability vectors to class indices;
    # 1-D inputs are assumed to already be class indices.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_true_classes = np.argmax(y_true, axis=1) if y_true.ndim > 1 else y_true
    y_pred_classes = np.argmax(y_pred, axis=1) if y_pred.ndim > 1 else y_pred

    # Confusion matrix
    cm = confusion_matrix(y_true_classes, y_pred_classes)

    # Overall accuracy
    accuracy = accuracy_score(y_true_classes, y_pred_classes)

    # Weighted-average precision, recall, and F1 score
    precision = precision_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)
    recall = recall_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)
    f1 = f1_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)

    # Per-class precision, recall, and F1 score
    class_precision = precision_score(y_true_classes, y_pred_classes, average=None, zero_division=0)
    class_recall = recall_score(y_true_classes, y_pred_classes, average=None, zero_division=0)
    class_f1 = f1_score(y_true_classes, y_pred_classes, average=None, zero_division=0)

    # Collect per-class metrics
    class_metrics = {}

    # Use class names as keys when provided
    if class_names is not None and len(class_names) > 0:
        for i, name in enumerate(class_names):
            if i < len(class_precision):
                class_metrics[name] = {
                    'precision': class_precision[i],
                    'recall': class_recall[i],
                    'f1': class_f1[i]
                }
    else:
        # Otherwise fall back to class indices as keys
        for i in range(len(class_precision)):
            class_metrics[f'class_{i}'] = {
                'precision': class_precision[i],
                'recall': class_recall[i],
                'f1': class_f1[i]
            }

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
        'class_metrics': class_metrics
    }

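# A minimal usage sketch for evaluate_model (toy values and hypothetical
# emotion labels, shown only for illustration):
#
#     y_true = np.eye(3)[[0, 1, 2, 1]]          # one-hot labels, 4 samples
#     y_pred = np.array([[0.8, 0.1, 0.1],
#                        [0.2, 0.7, 0.1],
#                        [0.1, 0.2, 0.7],
#                        [0.6, 0.3, 0.1]])      # predicted probabilities
#     results = evaluate_model(y_true, y_pred, ['neutral', 'happy', 'sad'])
#     results['accuracy']                        # 0.75 on this toy input
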
def print_evaluation_results(results, class_names=None):
    """
    Print evaluation results.

    Args:
        results: dictionary of evaluation results from evaluate_model
        class_names: optional list of class names (currently unused)
    """
    print("=" * 50)
    print("Model evaluation results:")
    print("=" * 50)

    print(f"Accuracy: {results['accuracy']:.4f}")
    print(f"Precision: {results['precision']:.4f}")
    print(f"Recall: {results['recall']:.4f}")
    print(f"F1 score: {results['f1']:.4f}")

    print("\nConfusion matrix:")
    print(results['confusion_matrix'])

    if 'class_metrics' in results:
        print("\nPer-class metrics:")
        for class_name, metrics in results['class_metrics'].items():
            print(f"{class_name}:")
            print(f"  Precision: {metrics['precision']:.4f}")
            print(f"  Recall: {metrics['recall']:.4f}")
            print(f"  F1 score: {metrics['f1']:.4f}")

    print("=" * 50)

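# Typically chained directly after evaluate_model, e.g. (illustrative):
#
#     print_evaluation_results(evaluate_model(y_true, y_pred, class_names))
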
def evaluate_by_language(y_true, y_pred, language, class_names=None):
    """
    Evaluate model performance separately for each language.

    Args:
        y_true: ground-truth labels (class indices or one-hot encoded)
        y_pred: predicted labels (class indices or predicted probabilities)
        language: array of language tags, one per sample
        class_names: optional list of class names

    Returns:
        Dictionary mapping each language to its evaluation results.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    language = np.asarray(language)

    # Convert one-hot labels to class indices
    if y_true.ndim > 1 and y_true.shape[1] > 1:
        y_true = np.argmax(y_true, axis=1)

    # Convert predicted probabilities to class indices
    if y_pred.ndim > 1 and y_pred.shape[1] > 1:
        y_pred = np.argmax(y_pred, axis=1)

    # Collect the distinct language tags
    unique_languages = np.unique(language)

    results = {}

    # Compute evaluation metrics for each language separately
    for lang in unique_languages:
        # Select the samples belonging to this language
        mask = (language == lang)
        if np.sum(mask) == 0:
            continue

        lang_y_true = y_true[mask]
        lang_y_pred = y_pred[mask]

        # Evaluate on this language's samples
        lang_results = evaluate_model(lang_y_true, lang_y_pred, class_names)
        results[lang] = lang_results

    return results

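# A minimal usage sketch for evaluate_by_language (toy values; the 'zh'/'en'
# tags are just illustrative):
#
#     y_true = np.array([0, 1, 1, 0])
#     y_pred = np.array([0, 1, 0, 0])
#     language = np.array(['zh', 'zh', 'en', 'en'])
#     by_lang = evaluate_by_language(y_true, y_pred, language)
#     by_lang['zh']['accuracy']    # 1.0 -- both 'zh' samples are correct
#     by_lang['en']['accuracy']    # 0.5 -- one of two 'en' samples is correct
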
def print_evaluation_by_language(results, languages=None):
    """
    Print evaluation results broken down by language.

    Args:
        results: dictionary of per-language evaluation results
        languages: optional dictionary mapping language codes to display names
    """
    if languages is None:
        languages = {
            'zh': 'Chinese',
            'en': 'English'
        }

    for lang, lang_results in results.items():
        lang_name = languages.get(lang, lang)
        print("=" * 50)
        print(f"Evaluation results for {lang_name} data:")
        print("=" * 50)
        print_evaluation_results(lang_results)

def get_top_n_predictions(probabilities, class_names, n=3):
    """
    Get the top-N predictions with the highest probabilities.

    Args:
        probabilities: 1-D array of predicted probabilities for one sample
        class_names: list of class names
        n: number of top predictions to return

    Returns:
        Dictionary with the top-N classes and their probabilities.
    """
    # Indices of the n highest probabilities, in descending order
    top_n_indices = np.argsort(probabilities)[-n:][::-1]

    # Look up the corresponding class names and probabilities
    top_n_classes = [class_names[i] for i in top_n_indices]
    top_n_probs = [probabilities[i] for i in top_n_indices]

    return {
        'classes': top_n_classes,
        'probabilities': top_n_probs
    }

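# A minimal usage sketch for get_top_n_predictions (hypothetical emotion
# labels, for illustration only):
#
#     probs = np.array([0.05, 0.60, 0.25, 0.10])
#     get_top_n_predictions(probs, ['angry', 'happy', 'sad', 'neutral'], n=2)
#     # -> {'classes': ['happy', 'sad'], 'probabilities': [0.6, 0.25]}
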
def get_emotion_accuracy_by_speaker(y_true, y_pred, speaker_ids, emotions):
    """
    Compute per-speaker accuracy for each emotion.

    Args:
        y_true: ground-truth labels (class indices or one-hot encoded)
        y_pred: predicted labels (class indices or predicted probabilities)
        speaker_ids: array of speaker IDs, one per sample
        emotions: array of emotion labels, one per sample

    Returns:
        DataFrame of accuracies indexed by speaker, with one column per
        emotion; cells with no samples are NaN.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    speaker_ids = np.asarray(speaker_ids)
    emotions = np.asarray(emotions)

    # Convert one-hot labels to class indices
    if y_true.ndim > 1 and y_true.shape[1] > 1:
        y_true = np.argmax(y_true, axis=1)

    # Convert predicted probabilities to class indices
    if y_pred.ndim > 1 and y_pred.shape[1] > 1:
        y_pred = np.argmax(y_pred, axis=1)

    # Collect the distinct speaker IDs and emotion labels
    unique_speakers = np.unique(speaker_ids)
    unique_emotions = np.unique(emotions)

    # Accuracy matrix: rows are speakers, columns are emotions
    accuracy_matrix = np.zeros((len(unique_speakers), len(unique_emotions)))

    # Compute each speaker's accuracy on each emotion
    for i, speaker in enumerate(unique_speakers):
        for j, emotion in enumerate(unique_emotions):
            # Select the samples for this speaker and emotion
            mask = (speaker_ids == speaker) & (emotions == emotion)
            if np.sum(mask) == 0:
                accuracy_matrix[i, j] = np.nan
                continue

            speaker_emotion_true = y_true[mask]
            speaker_emotion_pred = y_pred[mask]
            accuracy_matrix[i, j] = accuracy_score(speaker_emotion_true,
                                                   speaker_emotion_pred)

    # Wrap the matrix in a labeled DataFrame
    return pd.DataFrame(
        accuracy_matrix,
        index=unique_speakers,
        columns=unique_emotions
    )

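
# Smoke-test demo with synthetic data (hypothetical class names, language
# tags, and speaker IDs; shapes chosen only to exercise the functions above):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_samples, n_classes = 20, 3
    class_names = ['neutral', 'happy', 'sad']      # illustrative labels

    # Synthetic one-hot ground truth and normalized random probabilities
    true_idx = rng.integers(0, n_classes, size=n_samples)
    y_true = np.eye(n_classes)[true_idx]
    y_pred = rng.random((n_samples, n_classes))
    y_pred /= y_pred.sum(axis=1, keepdims=True)

    results = evaluate_model(y_true, y_pred, class_names)
    print_evaluation_results(results)

    # Per-language breakdown on the same synthetic data
    language = rng.choice(['zh', 'en'], size=n_samples)
    by_lang = evaluate_by_language(y_true, y_pred, language, class_names)
    print_evaluation_by_language(by_lang)

    # Top-2 predictions for the first sample
    print(get_top_n_predictions(y_pred[0], class_names, n=2))

    # Speaker-by-emotion accuracy table
    speaker_ids = rng.choice(['spk1', 'spk2'], size=n_samples)
    emotions = np.array(class_names)[true_idx]
    print(get_emotion_accuracy_by_speaker(y_true, y_pred, speaker_ids, emotions))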