332 lines
12 KiB
Python
332 lines
12 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
模型训练脚本
|
||
用于训练多语言语音情感识别模型
|
||
"""
|
||
|
||
import os
|
||
import argparse
|
||
import numpy as np
|
||
from sklearn.preprocessing import LabelEncoder
|
||
from tensorflow.keras.utils import to_categorical
|
||
import tensorflow as tf
|
||
import matplotlib.pyplot as plt
|
||
import pickle
|
||
import time
|
||
from tqdm import tqdm
|
||
import sys
|
||
|
||
# 导入自定义模块
|
||
from data_utils.data_loader import load_all_data, prepare_data
|
||
from data_utils.feature_extractor import (
|
||
extract_features_batch, features_to_matrix,
|
||
normalize_features, reshape_for_lstm
|
||
)
|
||
from models.lstm_model import EmotionModel
|
||
from utils.visualize import plot_training_history, plot_confusion_matrix
|
||
from utils.evaluation import evaluate_model, print_evaluation_results
|
||
|
||
def parse_args(argv=None):
    """
    Parse the command-line options for the training script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, which is the original
            command-line behavior; passing a list allows programmatic calls
            and testing.

    Returns:
        argparse.Namespace: the parsed options (dataset paths, model
        hyper-parameters, output settings and random seed).
    """
    parser = argparse.ArgumentParser(description='训练多语言语音情感识别模型')

    # Dataset locations.
    # NOTE(review): the CASIA default is './CAISA' (letters swapped relative to
    # the dataset name). Left unchanged in case the on-disk directory really
    # uses that spelling; confirm and fix if it is a typo.
    parser.add_argument('--casia_path', type=str, default='./CAISA',
                        help='CASIA数据集路径')
    parser.add_argument('--savee_path', type=str, default='./SAVEE',
                        help='SAVEE数据集路径')
    parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
                        help='RAVDESS数据集路径')

    # Model / optimisation hyper-parameters.
    parser.add_argument('--lstm_units', type=int, default=128,
                        help='LSTM层的单元数')
    parser.add_argument('--dropout', type=float, default=0.3,
                        help='Dropout比率')
    parser.add_argument('--regularization', type=float, default=0.001,
                        help='L2正则化系数')
    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='学习率')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='批量大小')
    parser.add_argument('--epochs', type=int, default=100,
                        help='训练轮数')

    # Miscellaneous options.
    parser.add_argument('--include_language', action='store_true',
                        help='是否将语言作为额外特征')
    parser.add_argument('--output_dir', type=str, default='./output',
                        help='输出目录')
    parser.add_argument('--model_name', type=str, default='emotion_model',
                        help='模型名称')
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')

    return parser.parse_args(argv)
|
||
|
||
def verify_paths(args):
    """
    Report whether every configured dataset directory exists.

    A warning line is printed for each dataset path that is missing. If at
    least one is missing, a hint follows that lists all three configured
    paths and the command-line flags used to override them.

    Args:
        args: Parsed command-line options; ``casia_path``, ``savee_path``
            and ``ravdess_path`` are read.

    Returns:
        bool: True when all three paths exist, False otherwise.
    """
    dataset_locations = (
        ('CASIA', args.casia_path),
        ('SAVEE', args.savee_path),
        ('RAVDESS', args.ravdess_path),
    )

    everything_present = True
    for label, location in dataset_locations:
        if os.path.exists(location):
            continue
        print(f"警告: {label}数据路径不存在: {location}")
        everything_present = False

    if everything_present:
        return True

    hint_lines = (
        "\n请确保以下数据集中至少有一个可用:",
        f"1. CASIA (中文): {args.casia_path}",
        f"2. SAVEE (英文): {args.savee_path}",
        f"3. RAVDESS (英文): {args.ravdess_path}",
        "\n要修正路径,请使用以下参数:",
        "--casia_path /正确/的/CASIA/路径",
        "--savee_path /正确/的/SAVEE/路径",
        "--ravdess_path /正确/的/RAVDESS/路径",
    )
    for line in hint_lines:
        print(line)

    return False
|
||
|
||
def _save_pickle(obj, path):
    """Serialize ``obj`` to ``path`` with pickle (binary mode)."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _prepare_features(X_train, X_val, X_test, model_dir):
    """
    Turn the raw train/validation/test splits into LSTM-ready arrays.

    Runs feature extraction, conversion to feature matrices, normalization
    and reshaping for each split, and saves the feature names (taken from
    the training split) to ``feature_names.pkl`` in ``model_dir``.

    Returns:
        tuple: (X_train_reshaped, X_val_reshaped, X_test_reshaped, feature_names)
    """
    print("提取特征...")
    print("从训练集提取特征...")
    features_train = extract_features_batch(X_train)
    print("从验证集提取特征...")
    features_val = extract_features_batch(X_val)
    print("从测试集提取特征...")
    features_test = extract_features_batch(X_test)

    print("转换为特征矩阵...")
    X_train_matrix, feature_names = features_to_matrix(features_train)
    X_val_matrix, _ = features_to_matrix(features_val)
    X_test_matrix, _ = features_to_matrix(features_test)

    _save_pickle(feature_names, os.path.join(model_dir, 'feature_names.pkl'))

    print("标准化特征...")
    X_train_norm, X_val_norm, X_test_norm = normalize_features(
        X_train_matrix, X_val_matrix, X_test_matrix
    )

    print("重塑数据为LSTM输入格式...")
    return (
        reshape_for_lstm(X_train_norm),
        reshape_for_lstm(X_val_norm),
        reshape_for_lstm(X_test_norm),
        feature_names,
    )


def _one_hot(encoder, labels, num_classes):
    """
    Integer-encode ``labels`` with a fitted ``encoder`` and one-hot them.

    ``num_classes`` is passed explicitly so every split gets the same
    one-hot width. Without it, ``to_categorical`` infers the width from
    ``max(label) + 1``, so a split lacking the highest-index class would be
    narrower than the model output.
    """
    return to_categorical(encoder.transform(labels), num_classes=num_classes)


def train_model(args):
    """
    Run the full pipeline: load data, extract features, train, evaluate, save.

    Everything is written under ``<output_dir>/<model_name>/``:
    - pickled feature names, label encoder(s) and run config;
    - the checkpoint path ``best_model_weights.h5`` handed to ``EmotionModel.train``;
    - the full model (``emotion_model.h5``);
    - training-history and confusion-matrix plots;
    - the test-set ``y_true`` / ``y_pred`` arrays.

    Args:
        args: Parsed command-line options (see ``parse_args``).

    Returns:
        tuple: (model, history, results). These are the trained
        ``EmotionModel``, the object returned by ``model.train``, and the
        dict returned by ``evaluate_model``.

    Note:
        Any exception raised inside the pipeline is printed with its
        traceback, and the process then exits with status 1 (``sys.exit(1)``).
    """
    # Seed NumPy and TensorFlow so runs are reproducible.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # Missing datasets only produce a warning; loading continues with the rest.
    if not verify_paths(args):
        print("警告: 一些数据集路径不存在,但会尝试继续处理可用的数据集...")

    os.makedirs(args.output_dir, exist_ok=True)
    model_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        print("加载数据...")
        # The six emotions shared by all datasets.
        selected_emotions = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']

        X, y_emotion, y_language = load_all_data(
            args.casia_path,
            args.savee_path,
            args.ravdess_path,
            selected_emotions=selected_emotions,
        )

        # Explicit length check: `not X` raises on multi-element NumPy arrays.
        if X is None or len(X) == 0:
            raise ValueError("没有加载到任何有效数据。请检查数据路径是否正确,以及音频文件格式是否支持。")

        print("划分数据集...")
        (X_train, y_emotion_train, y_language_train,
         X_val, y_emotion_val, y_language_val,
         X_test, y_emotion_test, y_language_test) = prepare_data(X, y_emotion, y_language)

        X_train_reshaped, X_val_reshaped, X_test_reshaped, feature_names = _prepare_features(
            X_train, X_val, X_test, model_dir
        )

        print("编码标签...")
        emotion_encoder = LabelEncoder()
        # Fit on the fixed emotion list so every class is known even if a split lacks it.
        emotion_encoder.fit(selected_emotions)
        num_emotions = len(emotion_encoder.classes_)

        y_emotion_train_categorical = _one_hot(emotion_encoder, y_emotion_train, num_emotions)
        y_emotion_val_categorical = _one_hot(emotion_encoder, y_emotion_val, num_emotions)
        y_emotion_test_categorical = _one_hot(emotion_encoder, y_emotion_test, num_emotions)

        _save_pickle(emotion_encoder, os.path.join(model_dir, 'emotion_encoder.pkl'))

        if args.include_language:
            language_encoder = LabelEncoder()
            # list() makes the concatenation work whether the splits are lists or
            # NumPy arrays; `+` on NumPy arrays does not concatenate.
            language_encoder.fit(
                list(y_language_train) + list(y_language_val) + list(y_language_test)
            )
            num_languages = len(language_encoder.classes_)

            y_language_train_categorical = _one_hot(language_encoder, y_language_train, num_languages)
            y_language_val_categorical = _one_hot(language_encoder, y_language_val, num_languages)
            y_language_test_categorical = _one_hot(language_encoder, y_language_test, num_languages)

            _save_pickle(language_encoder, os.path.join(model_dir, 'language_encoder.pkl'))

        print("构建模型...")
        time_steps = X_train_reshaped.shape[1]
        features = X_train_reshaped.shape[2]

        model_kwargs = {
            'input_shape': (time_steps, features),
            'num_emotions': num_emotions,
        }
        if args.include_language:
            model_kwargs['num_languages'] = num_languages
            model_kwargs['include_language'] = True
        model = EmotionModel(**model_kwargs)

        model.build_model(
            lstm_units=args.lstm_units,
            dropout_rate=args.dropout,
            regularization_rate=args.regularization,
        )
        model.compile_model(learning_rate=args.learning_rate)
        model.model.summary()

        print("训练模型...")
        checkpoint_path = os.path.join(model_dir, 'best_model_weights.h5')
        train_kwargs = {
            'epochs': args.epochs,
            'batch_size': args.batch_size,
            'checkpoint_path': checkpoint_path,
        }
        if args.include_language:
            train_kwargs['language_train'] = y_language_train_categorical
            train_kwargs['language_val'] = y_language_val_categorical

        start_time = time.time()
        history = model.train(
            X_train_reshaped, y_emotion_train_categorical,
            X_val_reshaped, y_emotion_val_categorical,
            **train_kwargs,
        )
        training_time = time.time() - start_time
        print(f"训练完成! 耗时: {training_time:.2f}秒")

        model.save(os.path.join(model_dir, 'emotion_model.h5'))

        history_plot_path = os.path.join(model_dir, 'training_history.png')
        plot_training_history(history, save_path=history_plot_path)

        print("评估模型...")
        # With include_language, the one-hot language labels are fed as a second input.
        if args.include_language:
            test_inputs = [X_test_reshaped, y_language_test_categorical]
        else:
            test_inputs = X_test_reshaped
        y_pred = model.predict(test_inputs)

        emotion_classes = emotion_encoder.classes_

        results = evaluate_model(y_emotion_test_categorical, y_pred, emotion_classes)
        print_evaluation_results(results)

        cm_plot_path = os.path.join(model_dir, 'confusion_matrix.png')
        y_true = np.argmax(y_emotion_test_categorical, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)
        plot_confusion_matrix(y_true, y_pred_classes, emotion_classes,
                              save_path=cm_plot_path, normalize=True)

        np.save(os.path.join(model_dir, 'y_true.npy'), y_true)
        np.save(os.path.join(model_dir, 'y_pred.npy'), y_pred)

        # Copy the namespace dict so the caller's `args` is not mutated.
        config = dict(vars(args))
        config['num_features'] = len(feature_names)
        config['num_emotions'] = num_emotions
        config['training_time'] = training_time
        config['accuracy'] = float(results['accuracy'])

        _save_pickle(config, os.path.join(model_dir, 'config.pkl'))

        print(f"模型和结果已保存到 {model_dir}")

        return model, history, results

    except Exception as e:
        print(f"训练过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the training pipeline.
    train_model(parse_args())