chuban
This commit is contained in:
4
utils/__init__.py
Normal file
4
utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
工具函数包
|
||||
包含可视化和评估功能
|
||||
"""
|
||||
BIN
utils/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/audio_visualizer.cpython-312.pyc
Normal file
BIN
utils/__pycache__/audio_visualizer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/evaluation.cpython-312.pyc
Normal file
BIN
utils/__pycache__/evaluation.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/visualize.cpython-312.pyc
Normal file
BIN
utils/__pycache__/visualize.cpython-312.pyc
Normal file
Binary file not shown.
243
utils/evaluation.py
Normal file
243
utils/evaluation.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
评估工具模块,用于评估模型性能
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
from sklearn.metrics import classification_report, confusion_matrix
|
||||
import pandas as pd
|
||||
|
||||
def evaluate_model(y_true, y_pred, class_names=None):
    """Compute overall and per-class classification metrics.

    Args:
        y_true: Ground-truth labels, either one-hot encoded with shape
            (n_samples, n_classes) or a 1-D array of integer class indices.
        y_pred: Predictions, either per-class scores with shape
            (n_samples, n_classes) or a 1-D array of integer class indices.
        class_names: Optional sequence of class names; position ``i`` names
            class index ``i``.

    Returns:
        A dict with the keys 'accuracy', 'precision', 'recall', 'f1'
        (the last three are support-weighted averages), 'confusion_matrix'
        and 'class_metrics'. 'class_metrics' maps each class name (or
        ``class_<index>`` when no name is available) to its own
        precision / recall / f1 values.
    """
    y_true_classes = _to_class_indices(y_true)
    y_pred_classes = _to_class_indices(y_pred)

    has_names = class_names is not None and len(class_names) > 0
    observed = np.union1d(y_true_classes, y_pred_classes)

    # scikit-learn only reports the labels it sees unless told otherwise, so
    # pin the label list explicitly; this keeps row ``i`` of every per-class
    # result tied to class index ``i`` even when a class has no samples.
    if observed.size == 0:
        labels = None
    elif has_names:
        labels = np.arange(max(len(class_names), int(observed.max()) + 1))
    else:
        labels = observed

    cm = confusion_matrix(y_true_classes, y_pred_classes, labels=labels)
    accuracy = accuracy_score(y_true_classes, y_pred_classes)

    def score(metric, average):
        return metric(
            y_true_classes,
            y_pred_classes,
            labels=labels,
            average=average,
            zero_division=0,
        )

    # Support-weighted averages over all classes
    precision = score(precision_score, 'weighted')
    recall = score(recall_score, 'weighted')
    f1 = score(f1_score, 'weighted')

    # One value per class, in the order of ``labels``
    class_precision = score(precision_score, None)
    class_recall = score(recall_score, None)
    class_f1 = score(f1_score, None)

    label_ids = range(len(class_precision)) if labels is None else labels
    class_metrics = {}
    for row, label in enumerate(label_ids):
        label = int(label)
        if has_names and label < len(class_names):
            key = class_names[label]
        else:
            # No name supplied for this index: fall back to a generic key
            key = f'class_{label}'
        class_metrics[key] = {
            'precision': class_precision[row],
            'recall': class_recall[row],
            'f1': class_f1[row]
        }

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
        'class_metrics': class_metrics
    }


def _to_class_indices(values):
    """Return 1-D integer class indices for a label vector or a score matrix."""
    values = np.asarray(values)
    if values.ndim > 1 and values.shape[1] > 1:
        # One-hot labels or per-class scores: the winning column is the class
        return np.argmax(values, axis=1)
    return values.reshape(-1).astype(int)
|
||||
|
||||
def print_evaluation_results(results, class_names=None):
    """Print a formatted evaluation report to standard output.

    Args:
        results: Evaluation dict with the keys 'accuracy', 'precision',
            'recall', 'f1' and 'confusion_matrix'; an optional
            'class_metrics' entry maps class names to dicts holding
            'precision', 'recall' and 'f1'.
        class_names: Accepted for interface compatibility; not used.
    """
    rule = "=" * 50
    summary_rows = (
        ("准确率 (Accuracy)", 'accuracy'),
        ("精确率 (Precision)", 'precision'),
        ("召回率 (Recall)", 'recall'),
        ("F1分数 (F1)", 'f1'),
    )
    per_class_rows = (
        ("精确率 (Precision)", 'precision'),
        ("召回率 (Recall)", 'recall'),
        ("F1分数 (F1)", 'f1'),
    )

    # Report header
    print(rule)
    print("模型评估结果:")
    print(rule)

    # Overall metrics, four decimals each
    for caption, key in summary_rows:
        print(f"{caption}: {results[key]:.4f}")

    print("\n混淆矩阵:")
    print(results['confusion_matrix'])

    # Per-class section is optional
    if 'class_metrics' in results:
        print("\n每个类别的指标:")
        for class_name, metrics in results['class_metrics'].items():
            print(f"{class_name}:")
            for caption, key in per_class_rows:
                print(f" {caption}: {metrics[key]:.4f}")

    print(rule)
|
||||
|
||||
def evaluate_by_language(y_true, y_pred, language, class_names=None):
    """Evaluate model performance separately for each language.

    Args:
        y_true: Ground-truth labels, either a 1-D array of integer class
            indices or a one-hot matrix of shape (n_samples, n_classes).
        y_pred: Predictions, either a 1-D array of integer class indices or
            a score matrix of shape (n_samples, n_classes).
        language: Per-sample language tags (array-like, length n_samples).
        class_names: Optional list of class names forwarded to
            ``evaluate_model``.

    Returns:
        A dict mapping each distinct language tag to the result dict that
        ``evaluate_model`` produced for the samples of that language.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # A plain list compared with ``==`` yields a single bool, not a mask
    language = np.asarray(language)

    # evaluate_model takes argmax along axis 1, so it must always receive
    # 2-D input; 1-D index vectors are expanded to one-hot rows, for which
    # argmax returns the original index.
    n_classes = _infer_class_count((y_true, y_pred), class_names)
    true_scores = _as_score_matrix(y_true, n_classes)
    pred_scores = _as_score_matrix(y_pred, n_classes)

    results = {}
    for lang in np.unique(language):
        mask = language == lang
        results[lang] = evaluate_model(
            true_scores[mask], pred_scores[mask], class_names
        )

    return results


def _infer_class_count(arrays, class_names):
    """Return the smallest class count that covers every given array."""
    count = len(class_names) if class_names is not None else 0
    for arr in arrays:
        if arr.ndim > 1 and arr.shape[1] > 1:
            count = max(count, arr.shape[1])
        elif arr.size:
            count = max(count, int(arr.max()) + 1)
    return count


def _as_score_matrix(values, n_classes):
    """Return ``values`` as a 2-D matrix, one-hot encoding 1-D class indices."""
    if values.ndim > 1 and values.shape[1] > 1:
        return values
    indices = values.reshape(-1).astype(int)
    return np.eye(n_classes)[indices]
|
||||
|
||||
def print_evaluation_by_language(results, languages=None):
    """Print one evaluation report per language.

    Args:
        results: Dict mapping a language code to its evaluation result dict.
        languages: Optional dict mapping language codes to display names.
            Defaults to names for 'zh' and 'en'; codes without an entry are
            printed as-is.
    """
    display_names = {'zh': '中文', 'en': '英文'} if languages is None else languages
    rule = "=" * 50

    for code, metrics in results.items():
        title = display_names.get(code, code)
        # Language banner, followed by the standard report
        print(rule)
        print(f"{title}数据评估结果:")
        print(rule)
        print_evaluation_results(metrics)
|
||||
|
||||
def get_top_n_predictions(probabilities, class_names, n=3):
    """Return the ``n`` most probable classes, highest probability first.

    Args:
        probabilities: Predicted probabilities for one sample, either 1-D or
            with extra singleton dimensions (e.g. shape (1, n_classes)).
        class_names: Sequence of class names indexed by class index.
        n: Number of predictions to return. Values larger than the number
            of classes return every class; zero or negative values return
            empty lists.

    Returns:
        A dict with 'classes' (names) and 'probabilities' (matching values),
        both ordered from most to least probable.

    Raises:
        ValueError: If ``probabilities`` describes more than one sample.
    """
    scores = np.squeeze(np.asarray(probabilities))
    if scores.ndim == 0:
        scores = scores.reshape(1)
    if scores.ndim != 1:
        raise ValueError(
            "probabilities must describe a single sample, "
            f"got shape {np.shape(probabilities)}"
        )

    # Reverse first, then truncate: slicing with [-n:] would return the
    # whole array for n == 0.
    count = max(int(n), 0)
    ranked = np.argsort(scores)[::-1][:count]

    return {
        'classes': [class_names[i] for i in ranked],
        'probabilities': [scores[i] for i in ranked]
    }
|
||||
|
||||
def get_emotion_accuracy_by_speaker(y_true, y_pred, speaker_ids, emotions):
    """Compute the accuracy of every (speaker, emotion) combination.

    Args:
        y_true: Ground-truth labels, 1-D class labels or a one-hot matrix.
        y_pred: Predicted labels, 1-D class labels or a score matrix.
        speaker_ids: Per-sample speaker identifiers (array-like).
        emotions: Per-sample emotion tags (array-like).

    Returns:
        A pandas DataFrame indexed by the sorted unique speakers, with one
        column per sorted unique emotion. Each cell holds the fraction of
        correctly predicted samples for that pair, or NaN when the pair has
        no samples.

    Raises:
        ValueError: If ``y_true`` and ``y_pred`` differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Lists compared with ``==`` give one bool instead of a mask, which
    # would silently mark every cell as "no samples".
    speaker_ids = np.asarray(speaker_ids)
    emotions = np.asarray(emotions)

    # Collapse one-hot labels / score matrices to class indices
    if y_true.ndim > 1 and y_true.shape[1] > 1:
        y_true = np.argmax(y_true, axis=1)
    if y_pred.ndim > 1 and y_pred.shape[1] > 1:
        y_pred = np.argmax(y_pred, axis=1)

    y_true = y_true.reshape(-1)
    y_pred = y_pred.reshape(-1)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must contain the same number of samples")

    # Per-sample correctness, computed once; a cell's accuracy is its mean
    correct = y_true == y_pred

    unique_speakers = np.unique(speaker_ids)
    unique_emotions = np.unique(emotions)

    # Cells without samples stay NaN
    accuracy_matrix = np.full((len(unique_speakers), len(unique_emotions)), np.nan)

    for i, speaker in enumerate(unique_speakers):
        speaker_mask = speaker_ids == speaker
        for j, emotion in enumerate(unique_emotions):
            mask = speaker_mask & (emotions == emotion)
            if mask.any():
                accuracy_matrix[i, j] = correct[mask].mean()

    return pd.DataFrame(
        accuracy_matrix,
        index=unique_speakers,
        columns=unique_emotions
    )
|
||||
279
utils/visualize.py
Normal file
279
utils/visualize.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
可视化工具模块,用于可视化训练历史、模型预测结果等
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import librosa
|
||||
import librosa.display
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import confusion_matrix
|
||||
import pandas as pd
|
||||
import matplotlib.font_manager as fm
|
||||
import os
|
||||
import platform
|
||||
|
||||
# 为中文显示设置字体
|
||||
def set_chinese_font():
    """Configure matplotlib to render Chinese text.

    Looks for the first existing font file in a per-platform candidate list
    (Windows, macOS, otherwise Linux), registers it with matplotlib's font
    manager and puts it at the front of the sans-serif family. Prints a
    warning and leaves the configuration untouched when no candidate exists.
    """
    candidates_by_system = {
        'Windows': [
            'C:/Windows/Fonts/simhei.ttf',  # SimHei
            'C:/Windows/Fonts/simsun.ttc',  # SimSun
            'C:/Windows/Fonts/simkai.ttf',  # KaiTi
            'C:/Windows/Fonts/msyh.ttc'  # Microsoft YaHei
        ],
        'Darwin': [
            '/System/Library/Fonts/PingFang.ttc',
            '/Library/Fonts/Arial Unicode.ttf'
        ],
    }
    linux_candidates = [
        '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
        '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
    ]
    font_paths = candidates_by_system.get(platform.system(), linux_candidates)

    font_path = next((path for path in font_paths if os.path.exists(path)), None)
    if not font_path:
        print("警告:未找到中文字体,可能无法正确显示中文")
        return

    # rcParams refers to fonts by family name, which only resolves for fonts
    # the font manager knows about; register the file explicitly.
    register_font = getattr(fm.fontManager, 'addfont', None)
    if register_font is not None:
        register_font(font_path)

    font_name = fm.FontProperties(fname=font_path).get_name()
    # Put the Chinese font first but keep the existing entries as fallbacks;
    # the filter keeps repeated calls from growing the list.
    fallbacks = [name for name in plt.rcParams['font.sans-serif'] if name != font_name]
    plt.rcParams['font.sans-serif'] = [font_name] + fallbacks
    # Render the minus sign with an ASCII hyphen, which these fonts contain
    plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
def plot_training_history(history, save_path=None):
    """Plot accuracy and loss curves of a training run side by side.

    Args:
        history: Object exposing a ``history`` dict (such as the value
            returned by a Keras ``fit`` call), or that dict itself. It must
            contain 'loss' and 'accuracy' (or the legacy 'acc'); the
            matching 'val_*' series are plotted only when present.
        save_path: File path to save the figure to. When None the figure is
            shown instead.

    Raises:
        KeyError: If the training accuracy or loss series is missing.
    """
    set_chinese_font()

    metrics = getattr(history, 'history', history)
    # Older Keras versions record accuracy under 'acc'
    if 'accuracy' not in metrics and 'acc' in metrics:
        accuracy_key = 'acc'
    else:
        accuracy_key = 'accuracy'
    val_accuracy_key = f'val_{accuracy_key}'

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy panel
    ax1.plot(metrics[accuracy_key], label='训练准确率')
    if val_accuracy_key in metrics:
        # Absent when training ran without validation data
        ax1.plot(metrics[val_accuracy_key], label='验证准确率')
    ax1.set_title('模型准确率')
    ax1.set_ylabel('准确率')
    ax1.set_xlabel('轮次')
    ax1.legend(loc='lower right')
    ax1.grid(True)

    # Loss panel
    ax2.plot(metrics['loss'], label='训练损失')
    if 'val_loss' in metrics:
        ax2.plot(metrics['val_loss'], label='验证损失')
    ax2.set_title('模型损失')
    ax2.set_ylabel('损失')
    ax2.set_xlabel('轮次')
    ax2.legend(loc='upper right')
    ax2.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        y_true: Ground-truth labels (1-D).
        y_pred: Predicted labels (1-D).
        classes: Class names used as tick labels on both axes.
        save_path: File path to save the figure to. When None the figure is
            shown instead.
        normalize: If True, each row is divided by its sum so cells show
            the share of that true class; rows without samples stay 0.
        title: Figure title.
    """
    set_chinese_font()

    # When labels are integer indices into ``classes``, force one row and
    # column per class so tick labels stay aligned even if a class has no
    # samples. Otherwise the labels are inferred from the data.
    labels = _class_index_labels(y_true, y_pred, classes)
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Classes with no true samples would otherwise divide by zero
        cm = np.divide(
            cm.astype(float),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

    fig, ax = plt.subplots(figsize=(10, 8))

    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)

    ax.set_ylabel('真实标签')
    ax.set_xlabel('预测标签')
    ax.set_title(title)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()


def _class_index_labels(y_true, y_pred, classes):
    """Return ``arange(len(classes))`` if all labels index ``classes``, else None."""
    true_arr = np.asarray(y_true)
    pred_arr = np.asarray(y_pred)
    if true_arr.ndim != 1 or pred_arr.ndim != 1:
        return None
    if true_arr.size == 0 or pred_arr.size == 0:
        return None
    if not (np.issubdtype(true_arr.dtype, np.integer)
            and np.issubdtype(pred_arr.dtype, np.integer)):
        return None
    lowest = min(true_arr.min(), pred_arr.min())
    highest = max(true_arr.max(), pred_arr.max())
    if lowest < 0 or highest >= len(classes):
        return None
    return np.arange(len(classes))
|
||||
|
||||
def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None):
    """Visualise an audio signal and, optionally, selected feature values.

    The first figure stacks the waveform, a log-frequency spectrogram and a
    mel spectrogram. When both ``features`` and ``feature_names`` are given,
    a second figure shows the selected feature values as a bar chart.

    Args:
        audio: Audio samples.
        sr: Sampling rate.
        features: Dict mapping feature names to values.
        feature_names: Names of the features to show; names missing from
            ``features`` are ignored.
        save_path: File path for the first figure. The bar chart is saved
            next to it with ``_features`` inserted before the extension.
            When None the figures are shown instead.
    """
    set_chinese_font()

    plt.figure(figsize=(15, 10))

    # Waveform
    ax1 = plt.subplot(3, 1, 1)
    librosa.display.waveshow(audio, sr=sr)
    ax1.set_title('音频波形')
    ax1.set_ylabel('振幅')

    # Spectrogram in dB relative to the peak
    ax2 = plt.subplot(3, 1, 2)
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
    librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr)
    ax2.set_title('声谱图')
    plt.colorbar(format='%+2.0f dB')

    # Mel spectrogram in dB relative to the peak
    ax3 = plt.subplot(3, 1, 3)
    S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
    S_dB = librosa.power_to_db(S, ref=np.max)
    librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr)
    ax3.set_title('梅尔频谱图')
    plt.colorbar(format='%+2.0f dB')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

    # Optional bar chart of the selected features
    if features and feature_names:
        selected_features = {k: features[k] for k in feature_names if k in features}

        plt.figure(figsize=(12, 6))
        plt.bar(list(selected_features.keys()), list(selected_features.values()))
        plt.xticks(rotation=90)
        plt.title('选定特征值')
        plt.tight_layout()

        if save_path:
            # Insert the suffix before the extension only. Replacing every
            # '.' would also rewrite dots in directory names, and would
            # leave an extension-less path unchanged so this figure would
            # overwrite the first one.
            root, ext = os.path.splitext(save_path)
            feature_save_path = f"{root}_features{ext}"
            plt.savefig(feature_save_path)
            plt.close()
        else:
            plt.show()
|
||||
|
||||
def plot_feature_importance(feature_names, importances, top_n=20, save_path=None):
    """Draw a horizontal bar chart of the most important features.

    Args:
        feature_names: Feature names.
        importances: Importance value of each feature, in the same order.
        top_n: Number of highest-ranked features to display.
        save_path: File path to save the figure to. When None the figure is
            shown instead.
    """
    set_chinese_font()

    # Pair names with scores, then keep the top_n highest-scoring rows
    table = pd.DataFrame({
        '特征': feature_names,
        '重要性': importances
    })
    ranked = table.sort_values('重要性', ascending=False)
    top_rows = ranked.head(top_n)

    plt.figure(figsize=(12, 8))
    sns.barplot(x='重要性', y='特征', data=top_rows)
    plt.title(f'前{top_n}个最重要的特征')
    plt.tight_layout()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    plt.close()
|
||||
|
||||
def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None):
    """Visualise the prediction made for a single audio file.

    The upper panel shows the waveform; the lower panel shows the predicted
    probability of each emotion. The bar of the true label is blue, the bar
    of a wrong prediction is red, and a correct prediction is green.

    Args:
        audio_path: Path of the audio file (loaded at 22050 Hz).
        emotion: True emotion label.
        predicted_emotion: Predicted emotion label.
        probabilities: Predicted probability of each emotion.
        emotion_labels: Emotion labels, in the order of ``probabilities``.
        save_path: File path to save the figure to. When None the figure is
            shown instead.
    """
    set_chinese_font()

    audio, sr = librosa.load(audio_path, sr=22050)

    plt.figure(figsize=(15, 10))

    # Upper panel: waveform
    wave_ax = plt.subplot(2, 1, 1)
    librosa.display.waveshow(audio, sr=sr)
    wave_ax.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}')
    wave_ax.set_ylabel('振幅')

    # Lower panel: probability of each emotion
    prob_ax = plt.subplot(2, 1, 2)
    sns.barplot(x=emotion_labels, y=probabilities)
    prob_ax.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})')
    prob_ax.set_ylabel('预测概率')
    prob_ax.set_ylim([0, 1])
    plt.xticks(rotation=45)

    # Recolour the bars of the true and the predicted label
    for idx, label in enumerate(emotion_labels):
        is_true = label == emotion
        if is_true and label == predicted_emotion:
            colour = 'green'  # correct prediction
        elif is_true:
            colour = 'blue'  # true label only
        elif label == predicted_emotion:
            colour = 'red'  # predicted label only (wrong prediction)
        else:
            continue
        prob_ax.patches[idx].set_facecolor(colour)

    plt.tight_layout()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path)
    plt.close()
|
||||
Reference in New Issue
Block a user