This commit is contained in:
lzc
2025-07-02 13:54:05 +08:00
commit 4b3870440c
6351 changed files with 282880 additions and 0 deletions

332
train.py Normal file
View File

@@ -0,0 +1,332 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型训练脚本
用于训练多语言语音情感识别模型
"""
import os
import argparse
import numpy as np
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import matplotlib.pyplot as plt
import pickle
import time
from tqdm import tqdm
import sys
# 导入自定义模块
from data_utils.data_loader import load_all_data, prepare_data
from data_utils.feature_extractor import (
extract_features_batch, features_to_matrix,
normalize_features, reshape_for_lstm
)
from models.lstm_model import EmotionModel
from utils.visualize import plot_training_history, plot_confusion_matrix
from utils.evaluation import evaluate_model, print_evaluation_results
def parse_args():
"""
解析命令行参数
"""
parser = argparse.ArgumentParser(description='训练多语言语音情感识别模型')
# 数据路径参数
parser.add_argument('--casia_path', type=str, default='./CAISA',
help='CASIA数据集路径')
parser.add_argument('--savee_path', type=str, default='./SAVEE',
help='SAVEE数据集路径')
parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
help='RAVDESS数据集路径')
# 模型参数
parser.add_argument('--lstm_units', type=int, default=128,
help='LSTM层的单元数')
parser.add_argument('--dropout', type=float, default=0.3,
help='Dropout比率')
parser.add_argument('--regularization', type=float, default=0.001,
help='L2正则化系数')
parser.add_argument('--learning_rate', type=float, default=0.001,
help='学习率')
parser.add_argument('--batch_size', type=int, default=32,
help='批量大小')
parser.add_argument('--epochs', type=int, default=100,
help='训练轮数')
# 其他参数
parser.add_argument('--include_language', action='store_true',
help='是否将语言作为额外特征')
parser.add_argument('--output_dir', type=str, default='./output',
help='输出目录')
parser.add_argument('--model_name', type=str, default='emotion_model',
help='模型名称')
parser.add_argument('--seed', type=int, default=42,
help='随机种子')
args = parser.parse_args()
return args
def verify_paths(args):
"""
验证数据路径是否存在
Args:
args: 命令行参数
Returns:
bool: 所有路径都存在返回True否则返回False
"""
all_paths_exist = True
# 检查CASIA路径
if not os.path.exists(args.casia_path):
print(f"警告: CASIA数据路径不存在: {args.casia_path}")
all_paths_exist = False
# 检查SAVEE路径
if not os.path.exists(args.savee_path):
print(f"警告: SAVEE数据路径不存在: {args.savee_path}")
all_paths_exist = False
# 检查RAVDESS路径
if not os.path.exists(args.ravdess_path):
print(f"警告: RAVDESS数据路径不存在: {args.ravdess_path}")
all_paths_exist = False
if not all_paths_exist:
print("\n请确保以下数据集中至少有一个可用:")
print(f"1. CASIA (中文): {args.casia_path}")
print(f"2. SAVEE (英文): {args.savee_path}")
print(f"3. RAVDESS (英文): {args.ravdess_path}")
print("\n要修正路径,请使用以下参数:")
print("--casia_path /正确/的/CASIA/路径")
print("--savee_path /正确/的/SAVEE/路径")
print("--ravdess_path /正确/的/RAVDESS/路径")
return all_paths_exist
def train_model(args):
"""
训练模型的主函数
Args:
args: 命令行参数
"""
# 设置随机种子,确保结果可重复
np.random.seed(args.seed)
tf.random.set_seed(args.seed)
# 验证路径
paths_ok = verify_paths(args)
if not paths_ok:
print("警告: 一些数据集路径不存在,但会尝试继续处理可用的数据集...")
# 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
model_dir = os.path.join(args.output_dir, args.model_name)
os.makedirs(model_dir, exist_ok=True)
try:
# 加载数据
print("加载数据...")
# 选择所有数据集共有的六种情感
selected_emotions = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
X, y_emotion, y_language = load_all_data(
args.casia_path,
args.savee_path,
args.ravdess_path,
selected_emotions=selected_emotions
)
if not X:
raise ValueError("没有加载到任何有效数据。请检查数据路径是否正确,以及音频文件格式是否支持。")
# 划分数据集
print("划分数据集...")
X_train, y_emotion_train, y_language_train, \
X_val, y_emotion_val, y_language_val, \
X_test, y_emotion_test, y_language_test = prepare_data(X, y_emotion, y_language)
# 提取特征
print("提取特征...")
print("从训练集提取特征...")
features_train = extract_features_batch(X_train)
print("从验证集提取特征...")
features_val = extract_features_batch(X_val)
print("从测试集提取特征...")
features_test = extract_features_batch(X_test)
# 转换为特征矩阵
print("转换为特征矩阵...")
X_train_matrix, feature_names = features_to_matrix(features_train)
X_val_matrix, _ = features_to_matrix(features_val)
X_test_matrix, _ = features_to_matrix(features_test)
# 保存特征名称
with open(os.path.join(model_dir, 'feature_names.pkl'), 'wb') as f:
pickle.dump(feature_names, f)
# 标准化特征
print("标准化特征...")
X_train_norm, X_val_norm, X_test_norm = normalize_features(
X_train_matrix, X_val_matrix, X_test_matrix
)
# 重塑数据为LSTM输入格式
print("重塑数据为LSTM输入格式...")
X_train_reshaped = reshape_for_lstm(X_train_norm)
X_val_reshaped = reshape_for_lstm(X_val_norm)
X_test_reshaped = reshape_for_lstm(X_test_norm)
# 对情感标签进行编码
print("编码标签...")
emotion_encoder = LabelEncoder()
emotion_encoder.fit(selected_emotions) # 确保所有类别都有
y_emotion_train_encoded = emotion_encoder.transform(y_emotion_train)
y_emotion_val_encoded = emotion_encoder.transform(y_emotion_val)
y_emotion_test_encoded = emotion_encoder.transform(y_emotion_test)
# 转换为独热编码
y_emotion_train_categorical = to_categorical(y_emotion_train_encoded)
y_emotion_val_categorical = to_categorical(y_emotion_val_encoded)
y_emotion_test_categorical = to_categorical(y_emotion_test_encoded)
# 保存标签编码器
with open(os.path.join(model_dir, 'emotion_encoder.pkl'), 'wb') as f:
pickle.dump(emotion_encoder, f)
# 如果包含语言特征,对语言标签进行编码
if args.include_language:
language_encoder = LabelEncoder()
language_encoder.fit(y_language_train + y_language_val + y_language_test)
y_language_train_encoded = language_encoder.transform(y_language_train)
y_language_val_encoded = language_encoder.transform(y_language_val)
y_language_test_encoded = language_encoder.transform(y_language_test)
y_language_train_categorical = to_categorical(y_language_train_encoded)
y_language_val_categorical = to_categorical(y_language_val_encoded)
y_language_test_categorical = to_categorical(y_language_test_encoded)
# 保存语言编码器
with open(os.path.join(model_dir, 'language_encoder.pkl'), 'wb') as f:
pickle.dump(language_encoder, f)
# 构建模型
print("构建模型...")
time_steps = X_train_reshaped.shape[1] # 时间步数
features = X_train_reshaped.shape[2] # 特征数
num_emotions = len(emotion_encoder.classes_)
if args.include_language:
num_languages = len(language_encoder.classes_)
model = EmotionModel(
input_shape=(time_steps, features),
num_emotions=num_emotions,
num_languages=num_languages,
include_language=True
)
else:
model = EmotionModel(
input_shape=(time_steps, features),
num_emotions=num_emotions
)
model.build_model(
lstm_units=args.lstm_units,
dropout_rate=args.dropout,
regularization_rate=args.regularization
)
model.compile_model(learning_rate=args.learning_rate)
# 输出模型摘要
model.model.summary()
# 训练模型
print("训练模型...")
checkpoint_path = os.path.join(model_dir, 'best_model_weights.h5')
start_time = time.time()
if args.include_language:
history = model.train(
X_train_reshaped, y_emotion_train_categorical,
X_val_reshaped, y_emotion_val_categorical,
epochs=args.epochs,
batch_size=args.batch_size,
checkpoint_path=checkpoint_path,
language_train=y_language_train_categorical,
language_val=y_language_val_categorical
)
else:
history = model.train(
X_train_reshaped, y_emotion_train_categorical,
X_val_reshaped, y_emotion_val_categorical,
epochs=args.epochs,
batch_size=args.batch_size,
checkpoint_path=checkpoint_path
)
training_time = time.time() - start_time
print(f"训练完成! 耗时: {training_time:.2f}")
# 保存完整模型
model.save(os.path.join(model_dir, 'emotion_model.h5'))
# 绘制训练历史
history_plot_path = os.path.join(model_dir, 'training_history.png')
plot_training_history(history, save_path=history_plot_path)
# 评估模型
print("评估模型...")
if args.include_language:
y_pred = model.predict([X_test_reshaped, y_language_test_categorical])
else:
y_pred = model.predict(X_test_reshaped)
emotion_classes = emotion_encoder.classes_
# 计算评估指标
results = evaluate_model(y_emotion_test_categorical, y_pred, emotion_classes)
print_evaluation_results(results)
# 绘制混淆矩阵
cm_plot_path = os.path.join(model_dir, 'confusion_matrix.png')
y_true = np.argmax(y_emotion_test_categorical, axis=1)
y_pred_classes = np.argmax(y_pred, axis=1)
plot_confusion_matrix(y_true, y_pred_classes, emotion_classes,
save_path=cm_plot_path, normalize=True)
# 保存测试结果
np.save(os.path.join(model_dir, 'y_true.npy'), y_true)
np.save(os.path.join(model_dir, 'y_pred.npy'), y_pred)
# 保存训练配置
config = vars(args)
config['num_features'] = len(feature_names)
config['num_emotions'] = num_emotions
config['training_time'] = training_time
config['accuracy'] = float(results['accuracy'])
with open(os.path.join(model_dir, 'config.pkl'), 'wb') as f:
pickle.dump(config, f)
print(f"模型和结果已保存到 {model_dir}")
return model, history, results
except Exception as e:
print(f"训练过程中出现错误: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
if __name__ == "__main__":
args = parse_args()
train_model(args)