chuban
This commit is contained in:
279
utils/visualize.py
Normal file
279
utils/visualize.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
可视化工具模块,用于可视化训练历史、模型预测结果等
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import librosa
|
||||
import librosa.display
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import confusion_matrix
|
||||
import pandas as pd
|
||||
import matplotlib.font_manager as fm
|
||||
import os
|
||||
import platform
|
||||
|
||||
# 为中文显示设置字体
|
||||
def set_chinese_font():
|
||||
"""
|
||||
设置中文字体,支持Windows、MacOS和Linux
|
||||
"""
|
||||
system = platform.system()
|
||||
|
||||
if system == 'Windows':
|
||||
# Windows系统
|
||||
font_paths = [
|
||||
'C:/Windows/Fonts/simhei.ttf', # 黑体
|
||||
'C:/Windows/Fonts/simsun.ttc', # 宋体
|
||||
'C:/Windows/Fonts/simkai.ttf', # 楷体
|
||||
'C:/Windows/Fonts/msyh.ttc' # 微软雅黑
|
||||
]
|
||||
elif system == 'Darwin':
|
||||
# MacOS系统
|
||||
font_paths = [
|
||||
'/System/Library/Fonts/PingFang.ttc',
|
||||
'/Library/Fonts/Arial Unicode.ttf'
|
||||
]
|
||||
else:
|
||||
# Linux系统
|
||||
font_paths = [
|
||||
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
|
||||
'/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'
|
||||
]
|
||||
|
||||
# 尝试加载字体
|
||||
font_path = None
|
||||
for path in font_paths:
|
||||
if os.path.exists(path):
|
||||
font_path = path
|
||||
break
|
||||
|
||||
if font_path:
|
||||
plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=font_path).get_name()]
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
||||
else:
|
||||
print("警告:未找到中文字体,可能无法正确显示中文")
|
||||
|
||||
def plot_training_history(history, save_path=None):
|
||||
"""
|
||||
绘制训练历史(准确率和损失曲线)
|
||||
|
||||
Args:
|
||||
history: 训练历史对象
|
||||
save_path: 保存路径,如果为None则显示图形
|
||||
"""
|
||||
set_chinese_font()
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
||||
|
||||
# 绘制准确率
|
||||
ax1.plot(history.history['accuracy'], label='训练准确率')
|
||||
ax1.plot(history.history['val_accuracy'], label='验证准确率')
|
||||
ax1.set_title('模型准确率')
|
||||
ax1.set_ylabel('准确率')
|
||||
ax1.set_xlabel('轮次')
|
||||
ax1.legend(loc='lower right')
|
||||
ax1.grid(True)
|
||||
|
||||
# 绘制损失
|
||||
ax2.plot(history.history['loss'], label='训练损失')
|
||||
ax2.plot(history.history['val_loss'], label='验证损失')
|
||||
ax2.set_title('模型损失')
|
||||
ax2.set_ylabel('损失')
|
||||
ax2.set_xlabel('轮次')
|
||||
ax2.legend(loc='upper right')
|
||||
ax2.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'):
|
||||
"""
|
||||
绘制混淆矩阵
|
||||
|
||||
Args:
|
||||
y_true: 真实标签
|
||||
y_pred: 预测标签
|
||||
classes: 类别名称列表
|
||||
save_path: 保存路径,如果为None则显示图形
|
||||
normalize: 是否归一化
|
||||
title: 图表标题
|
||||
"""
|
||||
set_chinese_font()
|
||||
|
||||
# 计算混淆矩阵
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
|
||||
if normalize:
|
||||
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
|
||||
# 使用seaborn绘制热图
|
||||
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
|
||||
xticklabels=classes, yticklabels=classes)
|
||||
|
||||
ax.set_ylabel('真实标签')
|
||||
ax.set_xlabel('预测标签')
|
||||
ax.set_title(title)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None):
|
||||
"""
|
||||
可视化音频波形和特征
|
||||
|
||||
Args:
|
||||
audio: 音频数据
|
||||
sr: 采样率
|
||||
features: 特征字典
|
||||
feature_names: 要显示的特征名称列表
|
||||
save_path: 保存路径,如果为None则显示图形
|
||||
"""
|
||||
set_chinese_font()
|
||||
|
||||
fig = plt.figure(figsize=(15, 10))
|
||||
|
||||
# 绘制波形
|
||||
ax1 = plt.subplot(3, 1, 1)
|
||||
librosa.display.waveshow(audio, sr=sr)
|
||||
ax1.set_title('音频波形')
|
||||
ax1.set_ylabel('振幅')
|
||||
|
||||
# 绘制声谱图
|
||||
ax2 = plt.subplot(3, 1, 2)
|
||||
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
|
||||
librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr)
|
||||
ax2.set_title('声谱图')
|
||||
plt.colorbar(format='%+2.0f dB')
|
||||
|
||||
# 绘制梅尔频谱图
|
||||
ax3 = plt.subplot(3, 1, 3)
|
||||
S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
|
||||
S_dB = librosa.power_to_db(S, ref=np.max)
|
||||
librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr)
|
||||
ax3.set_title('梅尔频谱图')
|
||||
plt.colorbar(format='%+2.0f dB')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
# 如果提供了特征,绘制特征条形图
|
||||
if features and feature_names:
|
||||
selected_features = {k: features[k] for k in feature_names if k in features}
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.bar(selected_features.keys(), selected_features.values())
|
||||
plt.xticks(rotation=90)
|
||||
plt.title('选定特征值')
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
# 修改文件名,避免覆盖前一个图
|
||||
feature_save_path = save_path.replace('.', '_features.')
|
||||
plt.savefig(feature_save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_feature_importance(feature_names, importances, top_n=20, save_path=None):
|
||||
"""
|
||||
绘制特征重要性
|
||||
|
||||
Args:
|
||||
feature_names: 特征名称列表
|
||||
importances: 特征重要性列表
|
||||
top_n: 显示前N个最重要的特征
|
||||
save_path: 保存路径,如果为None则显示图形
|
||||
"""
|
||||
set_chinese_font()
|
||||
|
||||
# 创建特征重要性数据框
|
||||
feature_importance_df = pd.DataFrame({
|
||||
'特征': feature_names,
|
||||
'重要性': importances
|
||||
})
|
||||
|
||||
# 排序并选择前N个特征
|
||||
feature_importance_df = feature_importance_df.sort_values('重要性', ascending=False).head(top_n)
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
sns.barplot(x='重要性', y='特征', data=feature_importance_df)
|
||||
plt.title(f'前{top_n}个最重要的特征')
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None):
|
||||
"""
|
||||
可视化预测结果
|
||||
|
||||
Args:
|
||||
audio_path: 音频文件路径
|
||||
emotion: 真实情感标签
|
||||
predicted_emotion: 预测的情感标签
|
||||
probabilities: 预测概率
|
||||
emotion_labels: 情感标签列表
|
||||
save_path: 保存路径,如果为None则显示图形
|
||||
"""
|
||||
set_chinese_font()
|
||||
|
||||
# 加载音频
|
||||
audio, sr = librosa.load(audio_path, sr=22050)
|
||||
|
||||
# 创建图形
|
||||
fig = plt.figure(figsize=(15, 10))
|
||||
|
||||
# 绘制波形
|
||||
ax1 = plt.subplot(2, 1, 1)
|
||||
librosa.display.waveshow(audio, sr=sr)
|
||||
ax1.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}')
|
||||
ax1.set_ylabel('振幅')
|
||||
|
||||
# 绘制预测结果条形图
|
||||
ax2 = plt.subplot(2, 1, 2)
|
||||
sns.barplot(x=emotion_labels, y=probabilities)
|
||||
ax2.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})')
|
||||
ax2.set_ylabel('预测概率')
|
||||
ax2.set_ylim([0, 1])
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
# 添加颜色高亮显示真实和预测标签
|
||||
for i, label in enumerate(emotion_labels):
|
||||
if label == emotion and label == predicted_emotion:
|
||||
# 真实标签和预测标签相同(正确预测)
|
||||
ax2.patches[i].set_facecolor('green')
|
||||
elif label == emotion:
|
||||
# 只是真实标签
|
||||
ax2.patches[i].set_facecolor('blue')
|
||||
elif label == predicted_emotion:
|
||||
# 只是预测标签(错误预测)
|
||||
ax2.patches[i].set_facecolor('red')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user