Files
yuyinfenxi/utils/visualize.py
2025-07-02 13:54:05 +08:00

279 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
可视化工具模块,用于可视化训练历史、模型预测结果等
"""
import matplotlib.pyplot as plt
import numpy as np
import librosa
import librosa.display
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pandas as pd
import matplotlib.font_manager as fm
import os
import platform
# 为中文显示设置字体
def set_chinese_font():
    """
    Configure matplotlib so Chinese text renders on Windows, macOS and Linux.

    Picks the first font file that exists from a per-OS list of candidate
    paths, registers it with matplotlib's font manager and puts its family
    name at the front of ``font.sans-serif``. If no candidate exists, a
    warning is printed and rcParams are left unchanged.
    """
    system = platform.system()
    if system == 'Windows':
        font_paths = [
            'C:/Windows/Fonts/simhei.ttf',  # SimHei
            'C:/Windows/Fonts/simsun.ttc',  # SimSun
            'C:/Windows/Fonts/simkai.ttf',  # KaiTi
            'C:/Windows/Fonts/msyh.ttc',    # Microsoft YaHei
        ]
    elif system == 'Darwin':
        # macOS
        font_paths = [
            '/System/Library/Fonts/PingFang.ttc',
            '/Library/Fonts/Arial Unicode.ttf',
        ]
    else:
        # Linux (and any other non-Windows, non-macOS system)
        font_paths = [
            '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
            '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc',
            # Noto Sans CJK location used by the Debian/Ubuntu fonts-noto-cjk package.
            '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',
        ]

    # First candidate that actually exists on this machine, in priority order.
    font_path = next((path for path in font_paths if os.path.exists(path)), None)
    if font_path is None:
        print("警告:未找到中文字体,可能无法正确显示中文")
        return

    # Register the file explicitly: setting only the family name works only if
    # matplotlib's cached font list already happens to contain this file.
    fm.fontManager.addfont(font_path)
    font_name = fm.FontProperties(fname=font_path).get_name()
    # Prepend instead of replacing, so previously configured fonts remain
    # available as fallbacks.
    fallbacks = [name for name in plt.rcParams['font.sans-serif'] if name != font_name]
    plt.rcParams['font.sans-serif'] = [font_name] + fallbacks
    # Draw negative numbers with an ASCII hyphen-minus (fixes minus-sign display).
    plt.rcParams['axes.unicode_minus'] = False
def plot_training_history(history, save_path=None):
    """
    Plot accuracy and loss curves (training vs. validation) side by side.

    Args:
        history: Training history. Either an object with a ``history`` dict
            attribute (e.g. a Keras ``History``) or such a dict directly,
            mapping metric names to per-epoch values. 'accuracy' /
            'val_accuracy' are looked up first, then the legacy names
            'acc' / 'val_acc'. 'loss' and 'val_loss' feed the loss panel.
            Any series that is missing is skipped instead of raising.
        save_path: File to save the figure to. If None, the figure is
            shown instead.
    """
    set_chinese_font()
    metrics = getattr(history, 'history', history)

    def first_present(*keys):
        # Values of the first key found in ``metrics``, or None if none is present.
        for key in keys:
            if key in metrics:
                return metrics[key]
        return None

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy panel
    accuracy_series = (
        (first_present('accuracy', 'acc'), '训练准确率'),
        (first_present('val_accuracy', 'val_acc'), '验证准确率'),
    )
    for values, label in accuracy_series:
        if values is not None:
            ax1.plot(values, label=label)
    ax1.set_title('模型准确率')
    ax1.set_ylabel('准确率')
    ax1.set_xlabel('轮次')
    if ax1.lines:  # avoid matplotlib's "no artists with labels" warning
        ax1.legend(loc='lower right')
    ax1.grid(True)

    # Loss panel
    loss_series = (
        (first_present('loss'), '训练损失'),
        (first_present('val_loss'), '验证损失'),
    )
    for values, label in loss_series:
        if values is not None:
            ax2.plot(values, label=label)
    ax2.set_title('模型损失')
    ax2.set_ylabel('损失')
    ax2.set_xlabel('轮次')
    if ax2.lines:
        ax2.legend(loc='upper right')
    ax2.grid(True)

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()
def _confusion_labels(y_true, y_pred, classes):
    """Choose the ``labels`` argument for confusion_matrix so rows/columns follow ``classes``."""
    class_list = list(classes)
    observed = (set(np.asarray(y_true).ravel().tolist())
                | set(np.asarray(y_pred).ravel().tolist()))
    if observed.issubset(class_list):
        # Labels are the class names themselves: keep the caller's order.
        return class_list
    if observed.issubset(range(len(class_list))):
        # Labels are integer indices into ``classes``.
        return list(range(len(class_list)))
    # Unrecognised encoding: fall back to sklearn's default (sorted observed labels).
    return None


def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'):
    """
    Plot a confusion matrix as an annotated heatmap.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Class names used as tick labels. When the labels in
            ``y_true``/``y_pred`` are these names, or integer indices
            ``0..len(classes)-1``, the matrix is built over exactly these
            classes in this order. Classes that never occur then still get
            a row and column, so the tick labels always line up.
        save_path: File to save the figure to. If None, the figure is
            shown instead.
        normalize: If True, divide each row by its total (the number of
            true samples of that class). Rows with no samples stay at 0.
        title: Plot title.
    """
    set_chinese_font()
    cm = confusion_matrix(y_true, y_pred, labels=_confusion_labels(y_true, y_pred, classes))
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has an all-zero row; leave it at 0
        # rather than producing NaN from 0/0.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_ylabel('真实标签')
    ax.set_xlabel('预测标签')
    ax.set_title(title)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()
def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None):
    """
    Plot an audio signal's waveform, log-frequency spectrogram and mel
    spectrogram, plus an optional bar chart of selected features.

    Args:
        audio: Audio samples as accepted by librosa (presumably a 1-D float
            array; confirm with callers).
        sr: Sampling rate.
        features: Mapping of feature name -> value. Values are used as bar
            heights, so they are presumably scalars.
        feature_names: Names of the features to include in the bar chart.
            Names missing from ``features`` are skipped. The chart is only
            drawn when both ``features`` and ``feature_names`` are non-empty.
        save_path: File to save the first figure to. If given, the feature
            chart is saved alongside it as ``<stem>_features<ext>``. If None,
            both figures are shown instead.
    """
    set_chinese_font()
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(15, 10))

    # Waveform
    librosa.display.waveshow(audio, sr=sr, ax=ax1)
    ax1.set_title('音频波形')
    ax1.set_ylabel('振幅')

    # STFT magnitude in dB relative to the peak, log-frequency axis
    D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max)
    spec_img = librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr, ax=ax2)
    ax2.set_title('声谱图')
    fig.colorbar(spec_img, ax=ax2, format='%+2.0f dB')

    # 128-band mel spectrogram in dB relative to the peak
    S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
    S_dB = librosa.power_to_db(S, ref=np.max)
    mel_img = librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr, ax=ax3)
    ax3.set_title('梅尔频谱图')
    fig.colorbar(mel_img, ax=ax3, format='%+2.0f dB')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()

    # Optional bar chart of the requested features
    if features and feature_names:
        selected = {name: features[name] for name in feature_names if name in features}
        feat_fig, feat_ax = plt.subplots(figsize=(12, 6))
        # Pass plain lists rather than dict views.
        feat_ax.bar(list(selected.keys()), list(selected.values()))
        feat_ax.tick_params(axis='x', labelrotation=90)
        feat_ax.set_title('选定特征值')
        feat_fig.tight_layout()
        if save_path:
            # Insert the suffix before the extension only, so the first figure
            # is never overwritten. str.replace('.', ...) would rewrite every
            # dot in the path, and would leave an extensionless path unchanged.
            stem, ext = os.path.splitext(save_path)
            feat_fig.savefig(f'{stem}_features{ext}')
            plt.close(feat_fig)
        else:
            plt.show()
def plot_feature_importance(feature_names, importances, top_n=20, save_path=None):
    """
    Plot a horizontal bar chart of the highest-scoring features.

    Args:
        feature_names: Feature names, one per importance value.
        importances: Importance scores aligned with ``feature_names``. The
            array-like is flattened first, so a ``(1, n)`` array is accepted.
        top_n: Maximum number of features to show, highest scores first.
        save_path: File to save the figure to. If None, the figure is
            shown instead.
    """
    set_chinese_font()
    feature_importance_df = pd.DataFrame({
        '特征': feature_names,
        '重要性': np.ravel(importances),
    })
    top_features = feature_importance_df.sort_values('重要性', ascending=False).head(top_n)

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.barplot(x='重要性', y='特征', data=top_features, ax=ax)
    # Report the number of bars actually drawn, which is fewer than top_n
    # when there are fewer features.
    ax.set_title(f'{len(top_features)}个最重要的特征')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()
def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None):
    """
    Plot an audio file's waveform above a bar chart of predicted emotion
    probabilities, highlighting the true and predicted labels.

    Bar colours: green = true label that was predicted correctly,
    blue = true label when the prediction was wrong,
    red = wrongly predicted label.

    Args:
        audio_path: Path of the audio file; it is loaded with librosa at 22050 Hz.
        emotion: True emotion label.
        predicted_emotion: Predicted emotion label.
        probabilities: Predicted probability for each entry of
            ``emotion_labels``, in the same order. The array-like is flattened
            first, so a ``(1, n)`` model output is accepted.
        emotion_labels: Emotion label names, one bar each.
        save_path: File to save the figure to. If None, the figure is
            shown instead.
    """
    set_chinese_font()
    audio, sr = librosa.load(audio_path, sr=22050)
    probs = np.ravel(probabilities)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

    # Waveform
    librosa.display.waveshow(audio, sr=sr, ax=ax1)
    ax1.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}')
    ax1.set_ylabel('振幅')

    # Prediction probabilities
    sns.barplot(x=list(emotion_labels), y=probs, ax=ax2)
    ax2.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})')
    ax2.set_ylabel('预测概率')
    ax2.set_ylim([0, 1])
    ax2.tick_params(axis='x', labelrotation=45)

    # Colour the true/predicted bars. zip() stops at the shorter sequence, so a
    # label/bar count mismatch cannot raise IndexError.
    for label, bar in zip(emotion_labels, ax2.patches):
        if label == emotion and label == predicted_emotion:
            bar.set_facecolor('green')  # correct prediction
        elif label == emotion:
            bar.set_facecolor('blue')   # true label only
        elif label == predicted_emotion:
            bar.set_facecolor('red')    # wrong prediction

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()