Files
yuyinfenxi/train.py
2025-07-02 13:54:05 +08:00

332 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
模型训练脚本
用于训练多语言语音情感识别模型
"""
import os
import argparse
import numpy as np
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
import matplotlib.pyplot as plt
import pickle
import time
from tqdm import tqdm
import sys
# 导入自定义模块
from data_utils.data_loader import load_all_data, prepare_data
from data_utils.feature_extractor import (
extract_features_batch, features_to_matrix,
normalize_features, reshape_for_lstm
)
from models.lstm_model import EmotionModel
from utils.visualize import plot_training_history, plot_confusion_matrix
from utils.evaluation import evaluate_model, print_evaluation_results
def parse_args(argv=None):
    """
    Parse the command-line options for the training script.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, so existing
            callers behave exactly as before; passing a list allows tests or
            other scripts to drive the trainer programmatically.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='训练多语言语音情感识别模型')

    # Dataset locations.
    # NOTE(review): the CASIA default directory is spelled "./CAISA" while the
    # dataset (and help text) is "CASIA" — confirm the on-disk folder name.
    parser.add_argument('--casia_path', type=str, default='./CAISA',
                        help='CASIA数据集路径')
    parser.add_argument('--savee_path', type=str, default='./SAVEE',
                        help='SAVEE数据集路径')
    parser.add_argument('--ravdess_path', type=str, default='./RAVDESS',
                        help='RAVDESS数据集路径')

    # Model / optimisation hyper-parameters.
    parser.add_argument('--lstm_units', type=int, default=128,
                        help='LSTM层的单元数')
    parser.add_argument('--dropout', type=float, default=0.3,
                        help='Dropout比率')
    parser.add_argument('--regularization', type=float, default=0.001,
                        help='L2正则化系数')
    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='学习率')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='批量大小')
    parser.add_argument('--epochs', type=int, default=100,
                        help='训练轮数')

    # Miscellaneous options.
    parser.add_argument('--include_language', action='store_true',
                        help='是否将语言作为额外特征')
    parser.add_argument('--output_dir', type=str, default='./output',
                        help='输出目录')
    parser.add_argument('--model_name', type=str, default='emotion_model',
                        help='模型名称')
    parser.add_argument('--seed', type=int, default=42,
                        help='随机种子')

    return parser.parse_args(argv)
def verify_paths(args):
    """
    Report which configured dataset directories are missing.

    ``args.casia_path``, ``args.savee_path`` and ``args.ravdess_path`` are
    checked with ``os.path.exists`` in that order, and a warning is printed
    for each one that is absent. If anything is missing, the expected
    locations and the command-line flags that override them are printed too.

    Args:
        args: parsed command-line namespace.

    Returns:
        bool: True only when all three paths exist, otherwise False.
    """
    checks = (
        (args.casia_path, f"警告: CASIA数据路径不存在: {args.casia_path}"),
        (args.savee_path, f"警告: SAVEE数据路径不存在: {args.savee_path}"),
        (args.ravdess_path, f"警告: RAVDESS数据路径不存在: {args.ravdess_path}"),
    )

    missing_any = False
    for dataset_path, warning in checks:
        if not os.path.exists(dataset_path):
            print(warning)
            missing_any = True

    if not missing_any:
        return True

    # Tell the user where each dataset was expected and how to override it.
    hint_lines = (
        "\n请确保以下数据集中至少有一个可用:",
        f"1. CASIA (中文): {args.casia_path}",
        f"2. SAVEE (英文): {args.savee_path}",
        f"3. RAVDESS (英文): {args.ravdess_path}",
        "\n要修正路径,请使用以下参数:",
        "--casia_path /正确/的/CASIA/路径",
        "--savee_path /正确/的/SAVEE/路径",
        "--ravdess_path /正确/的/RAVDESS/路径",
    )
    for line in hint_lines:
        print(line)
    return False
def _save_pickle(obj, path):
    """Serialize ``obj`` to ``path`` with pickle (binary mode, overwrites)."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _one_hot_splits(encoder, splits):
    """
    Label-encode each split with an already-fitted encoder and one-hot it.

    ``num_classes`` is pinned to ``len(encoder.classes_)`` so every split has
    the same number of columns. Without it, ``to_categorical`` infers the
    width from ``max(label) + 1`` and a split that happens to lack the
    highest-indexed class gets a narrower matrix than the model expects.

    Args:
        encoder: fitted ``LabelEncoder``.
        splits: iterable of label sequences (e.g. train/val/test).

    Returns:
        list: one one-hot array per split, in the same order as ``splits``.
    """
    num_classes = len(encoder.classes_)
    return [to_categorical(encoder.transform(split), num_classes=num_classes)
            for split in splits]


def _prepare_lstm_inputs(X_train, X_val, X_test, model_dir):
    """
    Turn raw samples from each split into normalized, LSTM-shaped inputs.

    Extracts features, converts them to matrices, saves the training feature
    names to ``model_dir/feature_names.pkl``, normalizes all three splits and
    reshapes them with ``reshape_for_lstm``.

    Returns:
        tuple: (X_train_reshaped, X_val_reshaped, X_test_reshaped, feature_names)
    """
    print("提取特征...")
    print("从训练集提取特征...")
    features_train = extract_features_batch(X_train)
    print("从验证集提取特征...")
    features_val = extract_features_batch(X_val)
    print("从测试集提取特征...")
    features_test = extract_features_batch(X_test)

    print("转换为特征矩阵...")
    X_train_matrix, feature_names = features_to_matrix(features_train)
    X_val_matrix, _ = features_to_matrix(features_val)
    X_test_matrix, _ = features_to_matrix(features_test)
    _save_pickle(feature_names, os.path.join(model_dir, 'feature_names.pkl'))

    print("标准化特征...")
    # NOTE(review): normalize_features returns only the transformed arrays, so
    # any statistics it fits (presumably on the training split — confirm) are
    # not saved alongside the model; check that inference can reproduce them.
    X_train_norm, X_val_norm, X_test_norm = normalize_features(
        X_train_matrix, X_val_matrix, X_test_matrix
    )

    print("重塑数据为LSTM输入格式...")
    return (reshape_for_lstm(X_train_norm),
            reshape_for_lstm(X_val_norm),
            reshape_for_lstm(X_test_norm),
            feature_names)


def train_model(args):
    """
    Run the full pipeline: load data, build features, train, evaluate, save.

    Artifacts are written to ``<output_dir>/<model_name>/``: feature_names.pkl,
    emotion_encoder.pkl, language_encoder.pkl (only with --include_language),
    best_model_weights.h5 (checkpoint path passed to ``model.train``),
    emotion_model.h5, training_history.png, confusion_matrix.png, y_true.npy,
    y_pred.npy and config.pkl.

    Args:
        args: namespace produced by ``parse_args``.

    Returns:
        tuple: (model, history, results) on success.

    On any exception the error and traceback are printed and the process
    exits with status 1 via ``sys.exit(1)`` (callers see ``SystemExit``).
    """
    # Seed NumPy and TensorFlow so runs are reproducible.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # Missing dataset paths are only a warning; the run continues with
    # whatever load_all_data can find.
    if not verify_paths(args):
        print("警告: 一些数据集路径不存在,但会尝试继续处理可用的数据集...")

    os.makedirs(args.output_dir, exist_ok=True)
    model_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(model_dir, exist_ok=True)

    try:
        print("加载数据...")
        # Six emotions shared by all datasets.
        selected_emotions = ['angry', 'fear', 'happy', 'neutral', 'sad', 'surprise']
        X, y_emotion, y_language = load_all_data(
            args.casia_path,
            args.savee_path,
            args.ravdess_path,
            selected_emotions=selected_emotions
        )
        # len() works for lists and NumPy arrays alike; `not X` would raise
        # "truth value is ambiguous" on a multi-element array.
        if X is None or len(X) == 0:
            raise ValueError("没有加载到任何有效数据。请检查数据路径是否正确,以及音频文件格式是否支持。")

        print("划分数据集...")
        (X_train, y_emotion_train, y_language_train,
         X_val, y_emotion_val, y_language_val,
         X_test, y_emotion_test, y_language_test) = prepare_data(X, y_emotion, y_language)

        X_train_reshaped, X_val_reshaped, X_test_reshaped, feature_names = \
            _prepare_lstm_inputs(X_train, X_val, X_test, model_dir)

        print("编码标签...")
        emotion_encoder = LabelEncoder()
        # Fit on the fixed emotion list so every class has an index even if
        # it never appears in the loaded data.
        emotion_encoder.fit(selected_emotions)
        num_emotions = len(emotion_encoder.classes_)
        (y_emotion_train_categorical,
         y_emotion_val_categorical,
         y_emotion_test_categorical) = _one_hot_splits(
            emotion_encoder, (y_emotion_train, y_emotion_val, y_emotion_test))
        _save_pickle(emotion_encoder, os.path.join(model_dir, 'emotion_encoder.pkl'))

        if args.include_language:
            language_encoder = LabelEncoder()
            # Unpack into a single list: `+` would add element-wise (or fail)
            # if prepare_data returned NumPy arrays instead of lists.
            language_encoder.fit([*y_language_train, *y_language_val, *y_language_test])
            (y_language_train_categorical,
             y_language_val_categorical,
             y_language_test_categorical) = _one_hot_splits(
                language_encoder, (y_language_train, y_language_val, y_language_test))
            _save_pickle(language_encoder, os.path.join(model_dir, 'language_encoder.pkl'))

        print("构建模型...")
        time_steps = X_train_reshaped.shape[1]    # time steps
        num_features = X_train_reshaped.shape[2]  # features per step
        model_kwargs = {
            'input_shape': (time_steps, num_features),
            'num_emotions': num_emotions,
        }
        if args.include_language:
            model_kwargs['num_languages'] = len(language_encoder.classes_)
            model_kwargs['include_language'] = True
        model = EmotionModel(**model_kwargs)
        model.build_model(
            lstm_units=args.lstm_units,
            dropout_rate=args.dropout,
            regularization_rate=args.regularization
        )
        model.compile_model(learning_rate=args.learning_rate)
        model.model.summary()

        print("训练模型...")
        checkpoint_path = os.path.join(model_dir, 'best_model_weights.h5')
        train_kwargs = {}
        if args.include_language:
            train_kwargs = {
                'language_train': y_language_train_categorical,
                'language_val': y_language_val_categorical,
            }
        start_time = time.time()
        history = model.train(
            X_train_reshaped, y_emotion_train_categorical,
            X_val_reshaped, y_emotion_val_categorical,
            epochs=args.epochs,
            batch_size=args.batch_size,
            checkpoint_path=checkpoint_path,
            **train_kwargs
        )
        training_time = time.time() - start_time
        print(f"训练完成! 耗时: {training_time:.2f}")

        model.save(os.path.join(model_dir, 'emotion_model.h5'))

        history_plot_path = os.path.join(model_dir, 'training_history.png')
        plot_training_history(history, save_path=history_plot_path)

        print("评估模型...")
        if args.include_language:
            # The language one-hot is fed as a second model input.
            test_inputs = [X_test_reshaped, y_language_test_categorical]
        else:
            test_inputs = X_test_reshaped
        y_pred = model.predict(test_inputs)
        emotion_classes = emotion_encoder.classes_

        results = evaluate_model(y_emotion_test_categorical, y_pred, emotion_classes)
        print_evaluation_results(results)

        cm_plot_path = os.path.join(model_dir, 'confusion_matrix.png')
        y_true = np.argmax(y_emotion_test_categorical, axis=1)
        y_pred_classes = np.argmax(y_pred, axis=1)
        plot_confusion_matrix(y_true, y_pred_classes, emotion_classes,
                              save_path=cm_plot_path, normalize=True)

        np.save(os.path.join(model_dir, 'y_true.npy'), y_true)
        np.save(os.path.join(model_dir, 'y_pred.npy'), y_pred)

        # Copy: vars() returns the namespace's live __dict__, and the extra
        # keys below must not leak back into the caller's args.
        config = dict(vars(args))
        config['num_features'] = len(feature_names)
        config['num_emotions'] = num_emotions
        config['training_time'] = training_time
        config['accuracy'] = float(results['accuracy'])
        _save_pickle(config, os.path.join(model_dir, 'config.pkl'))

        print(f"模型和结果已保存到 {model_dir}")
        return model, history, results

    except Exception as e:
        print(f"训练过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and hand them straight to the
    # training pipeline.
    train_model(parse_args())