""" 可视化工具模块,用于可视化训练历史、模型预测结果等 """ import matplotlib.pyplot as plt import numpy as np import librosa import librosa.display import seaborn as sns from sklearn.metrics import confusion_matrix import pandas as pd import matplotlib.font_manager as fm import os import platform # 为中文显示设置字体 def set_chinese_font(): """ 设置中文字体,支持Windows、MacOS和Linux """ system = platform.system() if system == 'Windows': # Windows系统 font_paths = [ 'C:/Windows/Fonts/simhei.ttf', # 黑体 'C:/Windows/Fonts/simsun.ttc', # 宋体 'C:/Windows/Fonts/simkai.ttf', # 楷体 'C:/Windows/Fonts/msyh.ttc' # 微软雅黑 ] elif system == 'Darwin': # MacOS系统 font_paths = [ '/System/Library/Fonts/PingFang.ttc', '/Library/Fonts/Arial Unicode.ttf' ] else: # Linux系统 font_paths = [ '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc' ] # 尝试加载字体 font_path = None for path in font_paths: if os.path.exists(path): font_path = path break if font_path: plt.rcParams['font.sans-serif'] = [fm.FontProperties(fname=font_path).get_name()] plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 else: print("警告:未找到中文字体,可能无法正确显示中文") def plot_training_history(history, save_path=None): """ 绘制训练历史(准确率和损失曲线) Args: history: 训练历史对象 save_path: 保存路径,如果为None则显示图形 """ set_chinese_font() fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) # 绘制准确率 ax1.plot(history.history['accuracy'], label='训练准确率') ax1.plot(history.history['val_accuracy'], label='验证准确率') ax1.set_title('模型准确率') ax1.set_ylabel('准确率') ax1.set_xlabel('轮次') ax1.legend(loc='lower right') ax1.grid(True) # 绘制损失 ax2.plot(history.history['loss'], label='训练损失') ax2.plot(history.history['val_loss'], label='验证损失') ax2.set_title('模型损失') ax2.set_ylabel('损失') ax2.set_xlabel('轮次') ax2.legend(loc='upper right') ax2.grid(True) plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() def plot_confusion_matrix(y_true, y_pred, classes, save_path=None, normalize=False, title='混淆矩阵'): """ 绘制混淆矩阵 Args: y_true: 真实标签 y_pred: 预测标签 classes: 类别名称列表 save_path: 保存路径,如果为None则显示图形 normalize: 是否归一化 title: 图表标题 """ set_chinese_font() # 计算混淆矩阵 cm = confusion_matrix(y_true, y_pred) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] fig, ax = plt.subplots(figsize=(10, 8)) # 使用seaborn绘制热图 sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues', xticklabels=classes, yticklabels=classes) ax.set_ylabel('真实标签') ax.set_xlabel('预测标签') ax.set_title(title) plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() def plot_audio_features(audio, sr, features=None, feature_names=None, save_path=None): """ 可视化音频波形和特征 Args: audio: 音频数据 sr: 采样率 features: 特征字典 feature_names: 要显示的特征名称列表 save_path: 保存路径,如果为None则显示图形 """ set_chinese_font() fig = plt.figure(figsize=(15, 10)) # 绘制波形 ax1 = plt.subplot(3, 1, 1) librosa.display.waveshow(audio, sr=sr) ax1.set_title('音频波形') ax1.set_ylabel('振幅') # 绘制声谱图 ax2 = plt.subplot(3, 1, 2) D = librosa.amplitude_to_db(np.abs(librosa.stft(audio)), ref=np.max) librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sr) ax2.set_title('声谱图') plt.colorbar(format='%+2.0f dB') # 绘制梅尔频谱图 ax3 = plt.subplot(3, 1, 3) S = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128) S_dB = librosa.power_to_db(S, ref=np.max) librosa.display.specshow(S_dB, y_axis='mel', x_axis='time', sr=sr) ax3.set_title('梅尔频谱图') plt.colorbar(format='%+2.0f dB') plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() # 如果提供了特征,绘制特征条形图 if features and feature_names: selected_features = {k: features[k] for k in feature_names if k in features} plt.figure(figsize=(12, 6)) plt.bar(selected_features.keys(), selected_features.values()) plt.xticks(rotation=90) plt.title('选定特征值') plt.tight_layout() if save_path: # 修改文件名,避免覆盖前一个图 feature_save_path = save_path.replace('.', '_features.') plt.savefig(feature_save_path) plt.close() else: plt.show() def plot_feature_importance(feature_names, importances, top_n=20, save_path=None): """ 绘制特征重要性 Args: feature_names: 特征名称列表 importances: 特征重要性列表 top_n: 显示前N个最重要的特征 save_path: 保存路径,如果为None则显示图形 """ set_chinese_font() # 创建特征重要性数据框 feature_importance_df = pd.DataFrame({ '特征': feature_names, '重要性': importances }) # 排序并选择前N个特征 feature_importance_df = feature_importance_df.sort_values('重要性', ascending=False).head(top_n) plt.figure(figsize=(12, 8)) sns.barplot(x='重要性', y='特征', data=feature_importance_df) plt.title(f'前{top_n}个最重要的特征') plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() def plot_prediction_result(audio_path, emotion, predicted_emotion, probabilities, emotion_labels, save_path=None): """ 可视化预测结果 Args: audio_path: 音频文件路径 emotion: 真实情感标签 predicted_emotion: 预测的情感标签 probabilities: 预测概率 emotion_labels: 情感标签列表 save_path: 保存路径,如果为None则显示图形 """ set_chinese_font() # 加载音频 audio, sr = librosa.load(audio_path, sr=22050) # 创建图形 fig = plt.figure(figsize=(15, 10)) # 绘制波形 ax1 = plt.subplot(2, 1, 1) librosa.display.waveshow(audio, sr=sr) ax1.set_title(f'音频波形 - 文件: {os.path.basename(audio_path)}') ax1.set_ylabel('振幅') # 绘制预测结果条形图 ax2 = plt.subplot(2, 1, 2) sns.barplot(x=emotion_labels, y=probabilities) ax2.set_title(f'情感预测结果 (真实: {emotion}, 预测: {predicted_emotion})') ax2.set_ylabel('预测概率') ax2.set_ylim([0, 1]) plt.xticks(rotation=45) # 添加颜色高亮显示真实和预测标签 for i, label in enumerate(emotion_labels): if label == emotion and label == predicted_emotion: # 真实标签和预测标签相同(正确预测) ax2.patches[i].set_facecolor('green') elif label == emotion: # 只是真实标签 ax2.patches[i].set_facecolor('blue') elif label == predicted_emotion: # 只是预测标签(错误预测) ax2.patches[i].set_facecolor('red') plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show()