This commit is contained in:
lzc
2025-07-02 13:54:05 +08:00
commit 4b3870440c
6351 changed files with 282880 additions and 0 deletions

279
utils/visualize.py Normal file
View File

@@ -0,0 +1,279 @@
"""
可视化工具模块,用于可视化训练历史、模型预测结果等
"""
import matplotlib.pyplot as plt
import numpy as np
import librosa
import librosa.display
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import matplotlib.font_manager as fm
import os
import platform
# 为中文显示设置字体
def set_chinese_font():
"""
设置中文字体支持Windows、MacOS和Linux
"""
system = platform.system()
if system == 'Windows':
# Windows系统
font_paths = [
'C:/Windows/Fonts/simhei.ttf', # 黑体
'C:/Windows/Fonts/simsun.ttc', # 宋体
'C:/Windows/Fonts/simkai.ttf', # 楷体
'C:/Windows/Fonts/msyh.ttc' # 微软雅黑
]
elif system == 'Darwin':
# MacOS系统
font_paths = [
'/System/Library/Fonts/PingFang.ttc',
'/Library/Fonts/Arial Unicode.ttf'
]
else:
# Linux系统
font_paths = [
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
'/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
]
# 尝试加载字体
font_path = None
for path in font_paths:
if os.path.exists(path):
font_path = path
break
if font_path:
plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=font_path).get_name()]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
else:
print("警告:未找到中文字体,可能无法正确显示中文")
def plot_training_history(history, save_path=None):
"""
绘制训练历史(准确率和损失曲线)
Args:
history: 训练历史对象
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
# 绘制准确率
ax1.plot(history.history['accuracy'], label='训练准确率')
ax1.plot(history.history['val_accuracy'], label='验证准确率')
ax1.set_title('模型准确率')
ax1.set_ylabel('准确率')
ax1.set_xlabel('轮次')
ax1.legend(loc='lower right')
ax1.grid(True)
# 绘制损失
ax2.plot(history.history['loss'], label='训练损失')
ax2.plot(history.history['val_loss'], label='验证损失')
ax2.set_title('模型损失')
ax2.set_ylabel('损失')
ax2.set_xlabel('轮次')
ax2.legend(loc='upper right')
ax2.grid(True)
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'):
"""
绘制混淆矩阵
Args:
y_true: 真实标签
y_pred: 预测标签
classes: 类别名称列表
save_path: 保存路径如果为None则显示图形
normalize: 是否归一化
title: 图表标题
"""
set_chinese_font()
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(10, 8))
# 使用seaborn绘制热图
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
xticklabels=classes, yticklabels=classes)
ax.set_ylabel('真实标签')
ax.set_xlabel('预测标签')
ax.set_title(title)
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None):
"""
可视化音频波形和特征
Args:
audio: 音频数据
sr: 采样率
features: 特征字典
feature_names: 要显示的特征名称列表
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
fig = plt.figure(figsize=(15, 10))
# 绘制波形
ax1 = plt.subplot(3, 1, 1)
librosa.display.waveshow(audio, sr=sr)
ax1.set_title('音频波形')
ax1.set_ylabel('振幅')
# 绘制声谱图
ax2 = plt.subplot(3, 1, 2)
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr)
ax2.set_title('声谱图')
plt.colorbar(format='%+2.0f dB')
# 绘制梅尔频谱图
ax3 = plt.subplot(3, 1, 3)
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
S_dB = librosa.power_to_db(S, ref=np.max)
librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr)
ax3.set_title('梅尔频谱图')
plt.colorbar(format='%+2.0f dB')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
# 如果提供了特征,绘制特征条形图
if features and feature_names:
selected_features = {k: features[k] for k in feature_names if k in features}
plt.figure(figsize=(12, 6))
plt.bar(selected_features.keys(), selected_features.values())
plt.xticks(rotation=90)
plt.title('选定特征值')
plt.tight_layout()
if save_path:
# 修改文件名,避免覆盖前一个图
feature_save_path = save_path.replace('.', '_features.')
plt.savefig(feature_save_path)
plt.close()
else:
plt.show()
def plot_feature_importance(feature_names, importances, top_n=20, save_path=None):
"""
绘制特征重要性
Args:
feature_names: 特征名称列表
importances: 特征重要性列表
top_n: 显示前N个最重要的特征
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
# 创建特征重要性数据框
feature_importance_df = pd.DataFrame({
'特征': feature_names,
'重要性': importances
})
# 排序并选择前N个特征
feature_importance_df = feature_importance_df.sort_values('重要性', ascending=False).head(top_n)
plt.figure(figsize=(12, 8))
sns.barplot(x='重要性', y='特征', data=feature_importance_df)
plt.title(f'{top_n}个最重要的特征')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()
def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None):
"""
可视化预测结果
Args:
audio_path: 音频文件路径
emotion: 真实情感标签
predicted_emotion: 预测的情感标签
probabilities: 预测概率
emotion_labels: 情感标签列表
save_path: 保存路径如果为None则显示图形
"""
set_chinese_font()
# 加载音频
audio, sr = librosa.load(audio_path, sr=22050)
# 创建图形
fig = plt.figure(figsize=(15, 10))
# 绘制波形
ax1 = plt.subplot(2, 1, 1)
librosa.display.waveshow(audio, sr=sr)
ax1.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}')
ax1.set_ylabel('振幅')
# 绘制预测结果条形图
ax2 = plt.subplot(2, 1, 2)
sns.barplot(x=emotion_labels, y=probabilities)
ax2.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})')
ax2.set_ylabel('预测概率')
ax2.set_ylim([0, 1])
plt.xticks(rotation=45)
# 添加颜色高亮显示真实和预测标签
for i, label in enumerate(emotion_labels):
if label == emotion and label == predicted_emotion:
# 真实标签和预测标签相同(正确预测)
ax2.patches[i].set_facecolor('green')
elif label == emotion:
# 只是真实标签
ax2.patches[i].set_facecolor('blue')
elif label == predicted_emotion:
# 只是预测标签(错误预测)
ax2.patches[i].set_facecolor('red')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.close()
else:
plt.show()