278 lines
9.3 KiB
Python
278 lines
9.3 KiB
Python
"""
|
||
LSTM模型定义模块,用于构建语音情感识别模型
|
||
"""
|
||
|
||
import tensorflow as tf
|
||
from tensorflow.keras.models import Sequential, Model
|
||
from tensorflow.keras.layers import Input, Dense, LSTM, Dropout, BatchNormalization, Bidirectional
|
||
from tensorflow.keras.regularizers import l2
|
||
from tensorflow.keras.optimizers import Adam
|
||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
||
import numpy as np
|
||
import os
|
||
|
||
class EmotionModel:
    """Speech emotion recognition classifier built on bidirectional LSTM layers.

    The instance wraps a Keras model stored in ``self.model``. The model is
    created by :meth:`build_model`; until then ``self.model`` is ``None`` and
    every method that needs it raises ``RuntimeError``.

    Two architectures are available:

    * audio only (a ``Sequential`` stack), used by default;
    * audio plus a language vector (functional API, two inputs), used when
      ``include_language`` is true and ``num_languages`` is not ``None``.
    """

    def __init__(self, input_shape, num_emotions, num_languages=None, include_language=False):
        """Store the configuration; no Keras objects are created here.

        Args:
            input_shape: Shape of one audio feature sample, e.g. ``(1, 60)``
                for 1 time step with 60 features.
            num_emotions: Number of emotion classes (size of the softmax output).
            num_languages: Width of the language input vector. Only used when
                ``include_language`` is true.
            include_language: Whether the language is fed as an extra input.
        """
        self.input_shape = input_shape
        self.num_emotions = num_emotions
        self.num_languages = num_languages
        self.include_language = include_language
        self.model = None
        # Checkpoints written during training contain weights only by default.
        self.save_weights_only = True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _require_model(self):
        """Return ``self.model`` or raise if it has not been created yet."""
        if self.model is None:
            raise RuntimeError(
                "The model has not been built yet; call build_model() first."
            )
        return self.model

    def _feeds_language(self, context, *language_arrays):
        """Decide whether language arrays are passed to the Keras model.

        Language data is used only when ``include_language`` is true and every
        array in ``language_arrays`` is provided. If the instance is configured
        for the two-input architecture (``num_languages`` is set) but language
        data is missing, a ``ValueError`` is raised instead of silently feeding
        a single input to a two-input model.

        Args:
            context: Name of the calling operation, used in the error message.
            *language_arrays: The language arrays supplied by the caller.

        Returns:
            True if the language arrays should be fed alongside the features.

        Raises:
            ValueError: Language data is required by the configuration but at
                least one array is ``None``.
        """
        if not self.include_language:
            return False
        if all(array is not None for array in language_arrays):
            return True
        if self.num_languages is not None:
            raise ValueError(
                f"{context}: include_language is enabled with "
                f"num_languages={self.num_languages}, so language data must be provided."
            )
        # include_language without num_languages builds the audio-only model,
        # so running without language data is the consistent choice.
        return False

    def _build_language_model(self, lstm_units, dropout_rate, regularizer):
        """Build the two-input (audio + language) functional model."""
        audio_input = Input(shape=self.input_shape, name='audio_input')
        language_input = Input(shape=(self.num_languages,), name='language_input')

        # A single BiLSTM summarises the audio sequence into one vector.
        audio_features = Bidirectional(
            LSTM(
                lstm_units,
                return_sequences=False,
                kernel_regularizer=regularizer,
                recurrent_regularizer=regularizer,
            )
        )(audio_input)
        audio_features = BatchNormalization()(audio_features)
        audio_features = Dropout(dropout_rate)(audio_features)

        # The language vector is appended to the audio summary.
        merged = tf.keras.layers.concatenate([audio_features, language_input])

        hidden = Dense(128, activation='relu', kernel_regularizer=regularizer)(merged)
        hidden = BatchNormalization()(hidden)
        hidden = Dropout(dropout_rate)(hidden)

        output = Dense(self.num_emotions, activation='softmax', name='emotion_output')(hidden)
        return Model(inputs=[audio_input, language_input], outputs=output)

    def _build_audio_model(self, lstm_units, dropout_rate, regularizer):
        """Build the audio-only sequential model (two BiLSTM + two dense blocks)."""
        return Sequential([
            # Explicit Input layer instead of the legacy ``input_shape=`` kwarg.
            Input(shape=self.input_shape),

            Bidirectional(
                LSTM(
                    lstm_units,
                    return_sequences=True,
                    kernel_regularizer=regularizer,
                    recurrent_regularizer=regularizer,
                )
            ),
            BatchNormalization(),
            Dropout(dropout_rate),

            Bidirectional(
                LSTM(
                    lstm_units,
                    return_sequences=False,
                    kernel_regularizer=regularizer,
                    recurrent_regularizer=regularizer,
                )
            ),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(128, activation='relu', kernel_regularizer=regularizer),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(64, activation='relu', kernel_regularizer=regularizer),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(self.num_emotions, activation='softmax'),
        ])

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def build_model(self, lstm_units=128, dropout_rate=0.3, regularization_rate=0.001):
        """Build the LSTM model and store it in ``self.model``.

        The two-input architecture is chosen when ``include_language`` is true
        and ``num_languages`` is not ``None``; otherwise the audio-only
        architecture is built.

        Args:
            lstm_units: Number of units in each LSTM layer.
            dropout_rate: Drop rate of every Dropout layer.
            regularization_rate: L2 regularization coefficient.

        Returns:
            The newly built (uncompiled) Keras model.
        """
        regularizer = l2(regularization_rate)

        if self.include_language and self.num_languages is not None:
            self.model = self._build_language_model(lstm_units, dropout_rate, regularizer)
        else:
            self.model = self._build_audio_model(lstm_units, dropout_rate, regularizer)

        return self.model

    def compile_model(self, learning_rate=0.001):
        """Compile the model with Adam, categorical cross-entropy and accuracy.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.

        Raises:
            RuntimeError: The model has not been built yet.
        """
        model = self._require_model()
        model.compile(
            optimizer=Adam(learning_rate=learning_rate),
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )

    def get_callbacks(self, checkpoint_path=None):
        """Create the list of training callbacks.

        Early stopping and learning-rate reduction (both monitoring
        ``val_loss``) are always included. A best-only ``ModelCheckpoint`` is
        added when ``checkpoint_path`` is given.

        Args:
            checkpoint_path: Where to save the checkpoint. When
                ``self.save_weights_only`` is true and the path does not end
                with ``.weights.h5``, its extension is replaced by
                ``.weights.h5`` (``best_model.h5`` becomes
                ``best_model.weights.h5``).

        Returns:
            The list of callbacks.
        """
        callbacks = [
            # Stop once validation loss stalls and roll back to the best epoch.
            EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
            ),
            # Shrink the learning rate on a plateau.
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=5,
                min_lr=1e-6,
            ),
        ]

        if checkpoint_path:
            # Weights-only checkpoints get the ``.weights.h5`` suffix.
            if self.save_weights_only and not checkpoint_path.endswith('.weights.h5'):
                checkpoint_path = os.path.splitext(checkpoint_path)[0] + '.weights.h5'

            callbacks.append(
                ModelCheckpoint(
                    filepath=checkpoint_path,
                    monitor='val_loss',
                    save_best_only=True,
                    save_weights_only=self.save_weights_only,
                    mode='min',
                    verbose=1,
                )
            )

        return callbacks

    def train(self, X_train, y_train, X_val, y_val, epochs=100, batch_size=32,
              checkpoint_path='best_model.h5', language_train=None, language_val=None):
        """Train the model.

        Args:
            X_train: Training features.
            y_train: Training labels.
            X_val: Validation features.
            y_val: Validation labels.
            epochs: Maximum number of epochs.
            batch_size: Batch size.
            checkpoint_path: Checkpoint path, see :meth:`get_callbacks`.
            language_train: Training language data (used when
                ``include_language`` is true).
            language_val: Validation language data (used when
                ``include_language`` is true).

        Returns:
            The history object returned by ``model.fit``.

        Raises:
            RuntimeError: The model has not been built yet.
            ValueError: The two-input configuration is active but
                ``language_train`` or ``language_val`` is missing.
        """
        model = self._require_model()
        # Validate before creating callbacks so bad input fails immediately.
        with_language = self._feeds_language('train', language_train, language_val)
        callbacks = self.get_callbacks(checkpoint_path)

        if with_language:
            train_inputs = [X_train, language_train]
            val_inputs = [X_val, language_val]
        else:
            train_inputs = X_train
            val_inputs = X_val

        history = model.fit(
            train_inputs, y_train,
            validation_data=(val_inputs, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1,
        )
        return history

    def evaluate(self, X_test, y_test, language_test=None):
        """Evaluate the model on a test set.

        Args:
            X_test: Test features.
            y_test: Test labels.
            language_test: Test language data (used when ``include_language``
                is true).

        Returns:
            The result of ``model.evaluate`` (loss and accuracy for a model
            compiled by :meth:`compile_model`).

        Raises:
            RuntimeError: The model has not been built yet.
            ValueError: The two-input configuration is active but
                ``language_test`` is missing.
        """
        model = self._require_model()
        if self._feeds_language('evaluate', language_test):
            return model.evaluate([X_test, language_test], y_test)
        return model.evaluate(X_test, y_test)

    def predict(self, X, language=None):
        """Predict emotion class probabilities.

        Args:
            X: Features.
            language: Language data (used when ``include_language`` is true).

        Returns:
            The result of ``model.predict``.

        Raises:
            RuntimeError: The model has not been built yet.
            ValueError: The two-input configuration is active but ``language``
                is missing.
        """
        model = self._require_model()
        if self._feeds_language('predict', language):
            return model.predict([X, language])
        return model.predict(X)

    def save(self, filepath):
        """Save the full model with ``model.save``.

        ``self.save_weights_only`` is not consulted here; it only affects the
        checkpoints created by :meth:`get_callbacks`.

        Args:
            filepath: Destination path.

        Raises:
            RuntimeError: The model has not been built yet.
        """
        self._require_model().save(filepath)

    @classmethod
    def load(cls, filepath, custom_objects=None):
        """Load a saved model from disk.

        Note that this returns the raw Keras model produced by
        ``tf.keras.models.load_model``, not an ``EmotionModel`` instance.

        Args:
            filepath: Path of the saved model file.
            custom_objects: Optional custom objects forwarded to Keras.

        Returns:
            The loaded Keras model.
        """
        model = tf.keras.models.load_model(filepath, custom_objects=custom_objects)
        return model