This commit is contained in:
lzc
2025-07-02 13:54:05 +08:00
commit 4b3870440c
6351 changed files with 282880 additions and 0 deletions

4
utils/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""
工具函数包
包含可视化和评估功能
"""

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

243
utils/evaluation.py Normal file
View File

@@ -0,0 +1,243 @@
"""
评估工具模块,用于评估模型性能
"""
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd
def evaluate_model(y_true, y_pred, class_names=None):
"""
评估模型性能
Args:
y_true: 真实标签(独热编码形式)
y_pred: 预测概率
class_names: 类别名称列表
Returns:
评估结果字典
"""
# 转换为类别索引
y_true_classes = np.argmax(y_true, axis=1)
y_pred_classes = np.argmax(y_pred, axis=1)
# 计算混淆矩阵
cm = confusion_matrix(y_true_classes, y_pred_classes)
# 计算准确率
accuracy = accuracy_score(y_true_classes, y_pred_classes)
# 计算精确率、召回率和F1分数
precision = precision_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)
recall = recall_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)
f1 = f1_score(y_true_classes, y_pred_classes, average='weighted', zero_division=0)
# 计算每个类别的指标
class_precision = precision_score(y_true_classes, y_pred_classes, average=None, zero_division=0)
class_recall = recall_score(y_true_classes, y_pred_classes, average=None, zero_division=0)
class_f1 = f1_score(y_true_classes, y_pred_classes, average=None, zero_division=0)
# 整理每个类别的指标
class_metrics = {}
# 如果提供了类别名称,使用类别名称作为键
if class_names is not None and len(class_names) > 0:
for i, name in enumerate(class_names):
if i < len(class_precision):
class_metrics[name] = {
'precision': class_precision[i],
'recall': class_recall[i],
'f1': class_f1[i]
}
else:
# 否则使用类别索引作为键
for i in range(len(class_precision)):
class_metrics[f'class_{i}'] = {
'precision': class_precision[i],
'recall': class_recall[i],
'f1': class_f1[i]
}
# 返回结果字典
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'confusion_matrix': cm,
'class_metrics': class_metrics
}
def print_evaluation_results(results, class_names=None):
"""
打印评估结果
Args:
results: 评估结果字典
class_names: 类别名称列表
"""
print("=" * 50)
print("模型评估结果:")
print("=" * 50)
print(f"准确率 (Accuracy): {results['accuracy']:.4f}")
print(f"精确率 (Precision): {results['precision']:.4f}")
print(f"召回率 (Recall): {results['recall']:.4f}")
print(f"F1分数 (F1): {results['f1']:.4f}")
print("\n混淆矩阵:")
print(results['confusion_matrix'])
if 'class_metrics' in results:
print("\n每个类别的指标:")
for class_name, metrics in results['class_metrics'].items():
print(f"{class_name}:")
print(f" 精确率 (Precision): {metrics['precision']:.4f}")
print(f" 召回率 (Recall): {metrics['recall']:.4f}")
print(f" F1分数 (F1): {metrics['f1']:.4f}")
print("=" * 50)
def evaluate_by_language(y_true, y_pred, language, class_names=None):
"""
按语言分类评估模型性能
Args:
y_true: 真实标签(数值或独热编码)
y_pred: 预测标签(数值或预测概率)
language: 语言标签数组
class_names: 类别名称列表
Returns:
按语言分类的评估结果字典
"""
# 如果输入是独热编码,转换为类别索引
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 如果输入是预测概率,转换为类别索引
if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
y_pred = np.argmax(y_pred, axis=1)
# 获取所有语言类型
unique_languages = np.unique(language)
results = {}
# 对每种语言分别计算评估指标
for lang in unique_languages:
# 筛选特定语言的样本
mask = (language == lang)
if np.sum(mask) == 0:
continue
lang_y_true = y_true[mask]
lang_y_pred = y_pred[mask]
# 计算该语言的评估指标
lang_results = evaluate_model(lang_y_true, lang_y_pred, class_names)
results[lang] = lang_results
return results
def print_evaluation_by_language(results, languages=None):
"""
打印按语言分类的评估结果
Args:
results: 按语言分类的评估结果字典
languages: 语言名称字典,将语言代码映射到语言名称
"""
if languages is None:
languages = {
'zh': '中文',
'en': '英文'
}
for lang, lang_results in results.items():
lang_name = languages.get(lang, lang)
print("=" * 50)
print(f"{lang_name}数据评估结果:")
print("=" * 50)
print_evaluation_results(lang_results)
def get_top_n_predictions(probabilities, class_names, n=3):
"""
获取概率最高的前N个预测结果
Args:
probabilities: 预测概率数组
class_names: 类别名称列表
n: 获取前N个结果
Returns:
包含前N个预测及其概率的字典
"""
# 找到前N个最高概率的索引
top_n_indices = np.argsort(probabilities)[-n:][::-1]
# 获取对应的类别名称和概率
top_n_classes = [class_names[i] for i in top_n_indices]
top_n_probs = [probabilities[i] for i in top_n_indices]
# 构建结果
result = {
'classes': top_n_classes,
'probabilities': top_n_probs
}
return result
def get_emotion_accuracy_by_speaker(y_true, y_pred, speaker_ids, emotions):
"""
计算每个说话者在每种情感上的准确率
Args:
y_true: 真实标签
y_pred: 预测标签
speaker_ids: 说话者ID数组
emotions: 情感标签数组
Returns:
说话者-情感准确率矩阵
"""
# 如果输入是独热编码,转换为类别索引
if len(y_true.shape) > 1 and y_true.shape[1] > 1:
y_true = np.argmax(y_true, axis=1)
# 如果输入是预测概率,转换为类别索引
if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
y_pred = np.argmax(y_pred, axis=1)
# 获取所有说话者ID和情感类别
unique_speakers = np.unique(speaker_ids)
unique_emotions = np.unique(emotions)
# 创建准确率矩阵
accuracy_matrix = np.zeros((len(unique_speakers), len(unique_emotions)))
# 计算每个说话者在每种情感上的准确率
for i, speaker in enumerate(unique_speakers):
for j, emotion in enumerate(unique_emotions):
# 筛选特定说话者和情感的样本
mask = (speaker_ids == speaker) & (emotions == emotion)
if np.sum(mask) == 0:
accuracy_matrix[i, j] = np.nan
continue
# 计算准确率
speaker_emotion_true = y_true[mask]
speaker_emotion_pred = y_pred[mask]
accuracy = accuracy_score(speaker_emotion_true, speaker_emotion_pred)
accuracy_matrix[i, j] = accuracy
# 创建DataFrame
accuracy_df = pd.DataFrame(
accuracy_matrix,
index=unique_speakers,
columns=unique_emotions
)
return accuracy_df

279
utils/visualize.py Normal file
View File

@@ -0,0 +1,279 @@
"""
可视化工具模块,用于可视化训练历史、模型预测结果等
"""
import matplotlib.pyplot as plt
import numpy as np
import librosa
import librosa.display
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import matplotlib.font_manager as fm
import os
import platform
# 为中文显示设置字体
def set_chinese_font():
"""
设置中文字体支持Windows、MacOS和Linux
"""
system = platform.system()
if system == 'Windows':
# Windows系统
font_paths = [
'C:/Windows/Fonts/simhei.ttf', # 黑体
'C:/Windows/Fonts/simsun.ttc', # 宋体
'C:/Windows/Fonts/simkai.ttf', # 楷体
'C:/Windows/Fonts/msyh.ttc' # 微软雅黑
]
elif system == 'Darwin':
# MacOS系统
font_paths = [
'/System/Library/Fonts/PingFang.ttc',
'/Library/Fonts/Arial Unicode.ttf'
]
else:
# Linux系统
font_paths = [
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
'/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
]
# 尝试加载字体
font_path = None
for path in font_paths:
if os.path.exists(path):
font_path = path
break
if font_path:
plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=font_path).get_name()]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
else:
print("警告:未找到中文字体,可能无法正确显示中文")
def plot_training_history(history, save_path=None):
"""
绘制训练历史(准确率和损失曲线)
Args:
history: 训练历史对象
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# 绘制准确率
ax1.plot(history.history['accuracy'], label='训练准确率')
ax1.plot(history.history['val_accuracy'], label='验证准确率')
ax1.set_title('模型准确率')
ax1.set_ylabel('准确率')
ax1.set_xlabel('轮次')
ax1.legend(loc='lower right')
ax1.grid(True)
# 绘制损失
ax2.plot(history.history['loss'], label='训练损失')
ax2.plot(history.history['val_loss'], label='验证损失')
ax2.set_title('模型损失')
ax2.set_ylabel('损失')
ax2.set_xlabel('轮次')
ax2.legend(loc='upper right')
ax2.grid(True)
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'):
"""
绘制混淆矩阵
Args:
y_true: 真实标签
y_pred: 预测标签
classes: 类别名称列表
save_path: 保存路径如果为None则显示图形
normalize: 是否归一化
title: 图表标题
"""
set_chinese_font()
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(10, 8))
# 使用seaborn绘制热图
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
xticklabels=classes, yticklabels=classes)
ax.set_ylabel('真实标签')
ax.set_xlabel('预测标签')
ax.set_title(title)
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None):
"""
可视化音频波形和特征
Args:
audio: 音频数据
sr: 采样率
features: 特征字典
feature_names: 要显示的特征名称列表
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
fig = plt.figure(figsize=(15, 10))
# 绘制波形
ax1 = plt.subplot(3, 1, 1)
librosa.display.waveshow(audio, sr=sr)
ax1.set_title('音频波形')
ax1.set_ylabel('振幅')
# 绘制声谱图
ax2 = plt.subplot(3, 1, 2)
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr)
ax2.set_title('声谱图')
plt.colorbar(format='%+2.0f dB')
# 绘制梅尔频谱图
ax3 = plt.subplot(3, 1, 3)
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
S_dB = librosa.power_to_db(S, ref=np.max)
librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr)
ax3.set_title('梅尔频谱图')
plt.colorbar(format='%+2.0f dB')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
# 如果提供了特征,绘制特征条形图
if features and feature_names:
selected_features = {k: features[k] for k in feature_names if k in features}
plt.figure(figsize=(12, 6))
plt.bar(selected_features.keys(), selected_features.values())
plt.xticks(rotation=90)
plt.title('选定特征值')
plt.tight_layout()
if save_path:
# 修改文件名,避免覆盖前一个图
feature_save_path = save_path.replace('.', '_features.')
plt.savefig(feature_save_path)
plt.close()
else:
plt.show()
def plot_feature_importance(feature_names, importances, top_n=20, save_path=None):
"""
绘制特征重要性
Args:
feature_names: 特征名称列表
importances: 特征重要性列表
top_n: 显示前N个最重要的特征
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
# 创建特征重要性数据框
feature_importance_df = pd.DataFrame({
'特征': feature_names,
'重要性': importances
})
# 排序并选择前N个特征
feature_importance_df = feature_importance_df.sort_values('重要性', ascending=False).head(top_n)
plt.figure(figsize=(12, 8))
sns.barplot(x='重要性', y='特征', data=feature_importance_df)
plt.title(f'{top_n}个最重要的特征')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None):
"""
可视化预测结果
Args:
audio_path: 音频文件路径
emotion: 真实情感标签
predicted_emotion: 预测的情感标签
probabilities: 预测概率
emotion_labels: 情感标签列表
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
# 加载音频
audio, sr = librosa.load(audio_path, sr=22050)
# 创建图形
fig = plt.figure(figsize=(15, 10))
# 绘制波形
ax1 = plt.subplot(2, 1, 1)
librosa.display.waveshow(audio, sr=sr)
ax1.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}')
ax1.set_ylabel('振幅')
# 绘制预测结果条形图
ax2 = plt.subplot(2, 1, 2)
sns.barplot(x=emotion_labels, y=probabilities)
ax2.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})')
ax2.set_ylabel('预测概率')
ax2.set_ylim([0, 1])
plt.xticks(rotation=45)
# 添加颜色高亮显示真实和预测标签
for i, label in enumerate(emotion_labels):
if label == emotion and label == predicted_emotion:
# 真实标签和预测标签相同(正确预测)
ax2.patches[i].set_facecolor('green')
elif label == emotion:
# 只是真实标签
ax2.patches[i].set_facecolor('blue')
elif label == predicted_emotion:
# 只是预测标签(错误预测)
ax2.patches[i].set_facecolor('red')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()