#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Model training script.

Trains a multilingual speech emotion recognition model on the CASIA, SAVEE
and RAVDESS datasets.
"""

import os
import argparse
import numpy as np
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import matplotlib.pyplot as plt
import pickle
import time
from tqdm import tqdm
import sys

# Project-local modules
from data_utils.data_loader import load_all_data, prepare_data
from data_utils.feature_extractor import (
    extract_features_batch,
    features_to_matrix,
    normalize_features,
    reshape_for_lstm,
)
from models.lstm_model import EmotionModel
from utils.visualize import plot_training_history, plot_confusion_matrix
from utils.evaluation import evaluate_model, print_evaluation_results


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: dataset paths, model hyper-parameters and
        output/run options.
    """
    parser = argparse.ArgumentParser(description='训练多语言语音情感识别模型')

    # Dataset paths
    parser.add_argument('--casia_path', type=str, default='./CAISA', help='CASIA数据集路径')
    parser.add_argument('--savee_path', type=str, default='./SAVEE', help='SAVEE数据集路径')
    parser.add_argument('--ravdess_path', type=str, default='./RAVDESS', help='RAVDESS数据集路径')

    # Model / optimisation hyper-parameters
    parser.add_argument('--lstm_units', type=int, default=128, help='LSTM层的单元数')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout比率')
    parser.add_argument('--regularization', type=float, default=0.001, help='L2正则化系数')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='学习率')
    parser.add_argument('--batch_size', type=int, default=32, help='批量大小')
    parser.add_argument('--epochs', type=int, default=100, help='训练轮数')

    # Miscellaneous options
    parser.add_argument('--include_language', action='store_true', help='是否将语言作为额外特征')
    parser.add_argument('--output_dir', type=str, default='./output', help='输出目录')
    parser.add_argument('--model_name', type=str, default='emotion_model', help='模型名称')
    parser.add_argument('--seed', type=int, default=42, help='随机种子')

    args = parser.parse_args()
    return args


def verify_paths(args):
    """Check that every dataset path given on the command line exists.

    Prints a warning for each missing path. If any path is missing, it also
    prints which flags can be used to correct them.

    Args:
        args: Parsed command-line namespace; reads ``casia_path``,
            ``savee_path`` and ``ravdess_path``.

    Returns:
        bool: True if all three paths exist, otherwise False.
    """
    all_paths_exist = True

    dataset_paths = (
        ("CASIA", args.casia_path),
        ("SAVEE", args.savee_path),
        ("RAVDESS", args.ravdess_path),
    )
    for label, path in dataset_paths:
        if not os.path.exists(path):
            print(f"警告: {label}数据路径不存在: {path}")
            all_paths_exist = False

    if not all_paths_exist:
        print("\n请确保以下数据集中至少有一个可用:")
        print(f"1. CASIA (中文): {args.casia_path}")
        print(f"2. SAVEE (英文): {args.savee_path}")
        print(f"3. RAVDESS (英文): {args.ravdess_path}")
        print("\n要修正路径,请使用以下参数:")
        print("--casia_path /正确/的/CASIA/路径")
        print("--savee_path /正确/的/SAVEE/路径")
        print("--ravdess_path /正确/的/RAVDESS/路径")

    return all_paths_exist


def _save_pickle(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _encode_one_hot(encoder, labels, num_classes):
    """Map raw labels to ids with a fitted ``LabelEncoder``, then one-hot them.

    ``num_classes`` is passed explicitly so that every split gets the same
    number of columns. Otherwise ``to_categorical`` infers ``max(id) + 1``,
    which is too small when a split lacks the highest-index class.
    """
    return to_categorical(encoder.transform(labels), num_classes=num_classes)


def train_model(args):
    """Run the full training pipeline and save every artifact.

    Steps: seed the RNGs, load and split the data, extract and normalize
    features, encode labels, then build, compile and train an
    ``EmotionModel``. Finally, evaluate it on the test split and write the
    model, encoders, plots, predictions and config to
    ``<output_dir>/<model_name>``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Returns:
        tuple: ``(model, history, results)``. These are the trained
        ``EmotionModel``, the value returned by ``model.train``, and the
        metrics returned by ``evaluate_model``.

    Note:
        Any exception raised inside the pipeline is printed together with its
        traceback, and the process then exits with status 1
        (``sys.exit(1)``).
    """
    # Seed NumPy and TensorFlow so runs are reproducible.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # Missing paths only produce a warning; loading continues with whatever exists.
    if not verify_paths(args):
        print("警告: 一些数据集路径不存在,但会尝试继续处理可用的数据集...")

    os.makedirs(args.output_dir, exist_ok=True)
    model_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        # ---- Load data -------------------------------------------------
        print("加载数据...")
        # The six emotions shared by all datasets.
        selected_emotions = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
        X, y_emotion, y_language = load_all_data(
            args.casia_path, args.savee_path, args.ravdess_path,
            selected_emotions=selected_emotions,
        )
        # Explicit length check: `not X` raises for a multi-element NumPy array.
        if X is None or len(X) == 0:
            raise ValueError("没有加载到任何有效数据。请检查数据路径是否正确,以及音频文件格式是否支持。")

        # ---- Split -----------------------------------------------------
        print("划分数据集...")
        (X_train, y_emotion_train, y_language_train,
         X_val, y_emotion_val, y_language_val,
         X_test, y_emotion_test, y_language_test) = prepare_data(X, y_emotion, y_language)

        # ---- Features --------------------------------------------------
        print("提取特征...")
        print("从训练集提取特征...")
        features_train = extract_features_batch(X_train)
        print("从验证集提取特征...")
        features_val = extract_features_batch(X_val)
        print("从测试集提取特征...")
        features_test = extract_features_batch(X_test)

        print("转换为特征矩阵...")
        X_train_matrix, feature_names = features_to_matrix(features_train)
        X_val_matrix, _ = features_to_matrix(features_val)
        X_test_matrix, _ = features_to_matrix(features_test)

        # Only the training split's feature names are kept.
        _save_pickle(feature_names, os.path.join(model_dir, 'feature_names.pkl'))

        print("标准化特征...")
        X_train_norm, X_val_norm, X_test_norm = normalize_features(
            X_train_matrix, X_val_matrix, X_test_matrix
        )

        print("重塑数据为LSTM输入格式...")
        X_train_reshaped = reshape_for_lstm(X_train_norm)
        X_val_reshaped = reshape_for_lstm(X_val_norm)
        X_test_reshaped = reshape_for_lstm(X_test_norm)

        # ---- Labels ----------------------------------------------------
        print("编码标签...")
        emotion_encoder = LabelEncoder()
        # Fit on the fixed emotion list so every class gets an id, even if
        # it never appears in the loaded data.
        emotion_encoder.fit(selected_emotions)
        num_emotions = len(emotion_encoder.classes_)

        y_emotion_train_categorical = _encode_one_hot(emotion_encoder, y_emotion_train, num_emotions)
        y_emotion_val_categorical = _encode_one_hot(emotion_encoder, y_emotion_val, num_emotions)
        y_emotion_test_categorical = _encode_one_hot(emotion_encoder, y_emotion_test, num_emotions)

        _save_pickle(emotion_encoder, os.path.join(model_dir, 'emotion_encoder.pkl'))

        if args.include_language:
            language_encoder = LabelEncoder()
            # Convert to lists so this concatenates whether prepare_data returns
            # lists or NumPy arrays (`+` on arrays would be element-wise).
            language_encoder.fit(
                list(y_language_train) + list(y_language_val) + list(y_language_test)
            )
            num_languages = len(language_encoder.classes_)

            y_language_train_categorical = _encode_one_hot(language_encoder, y_language_train, num_languages)
            y_language_val_categorical = _encode_one_hot(language_encoder, y_language_val, num_languages)
            y_language_test_categorical = _encode_one_hot(language_encoder, y_language_test, num_languages)

            _save_pickle(language_encoder, os.path.join(model_dir, 'language_encoder.pkl'))

        # ---- Model -----------------------------------------------------
        print("构建模型...")
        # Axes 1 and 2 of the reshaped training data are used as
        # (time steps, features per step).
        time_steps = X_train_reshaped.shape[1]
        n_features = X_train_reshaped.shape[2]

        model_kwargs = {}
        if args.include_language:
            model_kwargs = {'num_languages': num_languages, 'include_language': True}
        model = EmotionModel(
            input_shape=(time_steps, n_features),
            num_emotions=num_emotions,
            **model_kwargs,
        )

        model.build_model(
            lstm_units=args.lstm_units,
            dropout_rate=args.dropout,
            regularization_rate=args.regularization,
        )
        model.compile_model(learning_rate=args.learning_rate)
        model.model.summary()

        # ---- Train -----------------------------------------------------
        print("训练模型...")
        checkpoint_path = os.path.join(model_dir, 'best_model_weights.h5')

        train_kwargs = {}
        if args.include_language:
            train_kwargs = {
                'language_train': y_language_train_categorical,
                'language_val': y_language_val_categorical,
            }

        start_time = time.time()
        history = model.train(
            X_train_reshaped, y_emotion_train_categorical,
            X_val_reshaped, y_emotion_val_categorical,
            epochs=args.epochs,
            batch_size=args.batch_size,
            checkpoint_path=checkpoint_path,
            **train_kwargs,
        )
        training_time = time.time() - start_time
        print(f"训练完成! 耗时: {training_time:.2f}秒")

        model.save(os.path.join(model_dir, 'emotion_model.h5'))

        history_plot_path = os.path.join(model_dir, 'training_history.png')
        plot_training_history(history, save_path=history_plot_path)

        # ---- Evaluate --------------------------------------------------
        print("评估模型...")
        if args.include_language:
            # With language enabled, the language one-hot is a second model input.
            test_inputs = [X_test_reshaped, y_language_test_categorical]
        else:
            test_inputs = X_test_reshaped
        y_pred = model.predict(test_inputs)
        emotion_classes = emotion_encoder.classes_

        results = evaluate_model(y_emotion_test_categorical, y_pred, emotion_classes)
        print_evaluation_results(results)

        cm_plot_path = os.path.join(model_dir, 'confusion_matrix.png')
        y_true = np.argmax(y_emotion_test_categorical, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)
        plot_confusion_matrix(y_true, y_pred_classes, emotion_classes,
                              save_path=cm_plot_path, normalize=True)

        # y_true holds class ids; y_pred holds the raw model.predict output.
        np.save(os.path.join(model_dir, 'y_true.npy'), y_true)
        np.save(os.path.join(model_dir, 'y_pred.npy'), y_pred)

        # Copy first: vars() returns the namespace's own __dict__, so adding
        # keys to it directly would also add attributes to `args`.
        config = dict(vars(args))
        config['num_features'] = len(feature_names)
        config['num_emotions'] = num_emotions
        config['training_time'] = training_time
        config['accuracy'] = float(results['accuracy'])
        _save_pickle(config, os.path.join(model_dir, 'config.pkl'))

        print(f"模型和结果已保存到 {model_dir}")

        return model, history, results

    except Exception as e:
        # Top-level boundary of the script: report the error and exit non-zero.
        print(f"训练过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    args = parse_args()
    train_model(args)