chuban
This commit is contained in:
278
models/lstm_model.py
Normal file
278
models/lstm_model.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
LSTM模型定义模块,用于构建语音情感识别模型
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential, Model
|
||||
from tensorflow.keras.layers import Input, Dense, LSTM, Dropout, BatchNormalization, Bidirectional
|
||||
from tensorflow.keras.regularizers import l2
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class EmotionModel:
    """Speech emotion recognition model built around bidirectional LSTM layers.

    ``build_model`` creates one of two networks depending on the configuration:

    * an audio-only ``Sequential`` network, or
    * a two-input functional network (audio features + language vector) when
      ``include_language`` is true and ``num_languages`` is set.

    The remaining methods are thin wrappers around the underlying Keras model
    stored in ``self.model``.
    """

    def __init__(self, input_shape, num_emotions, num_languages=None, include_language=False):
        """Store the configuration; the network itself is created by ``build_model``.

        Args:
            input_shape: Shape of one input sample, e.g. ``(1, 60)`` for one
                time step with 60 features.
            num_emotions: Number of emotion classes (size of the softmax output).
            num_languages: Number of language classes; only used when
                ``include_language`` is true.
            include_language: Whether the language is fed as an extra feature.
        """
        self.input_shape = input_shape
        self.num_emotions = num_emotions
        self.num_languages = num_languages
        self.include_language = include_language
        self.model = None
        # Checkpoints written by get_callbacks() store weights only by default.
        self.save_weights_only = True

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _uses_language_input(self):
        """Return True when ``build_model`` creates the two-input network."""
        return bool(self.include_language and self.num_languages is not None)

    def _require_model(self):
        """Return ``self.model`` or raise a clear error if it is still unset."""
        if self.model is None:
            raise RuntimeError(
                "The model has not been built yet; call build_model() first."
            )
        return self.model

    def _wants_language(self, *language_arrays):
        """Decide whether language arrays are fed alongside the audio features.

        Returns True when ``include_language`` is set and every given array is
        present. Raises ``ValueError`` when the configuration produces a
        two-input model but language data is missing, instead of silently
        feeding a single input to it.
        """
        all_present = all(arr is not None for arr in language_arrays)
        if self.include_language and all_present:
            return True
        if self._uses_language_input():
            raise ValueError(
                "This model is configured with a language input "
                "(include_language=True and num_languages set), "
                "but the language data was not provided."
            )
        return False

    def _build_language_model(self, lstm_units, dropout_rate, regularizer):
        """Build the functional two-input (audio + language) network."""
        # Audio feature input.
        audio_input = Input(shape=self.input_shape, name='audio_input')

        # A single bidirectional LSTM summarises the audio sequence.
        x = Bidirectional(LSTM(lstm_units, return_sequences=False,
                               kernel_regularizer=regularizer,
                               recurrent_regularizer=regularizer))(audio_input)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

        # Language input: a vector with num_languages entries per sample.
        language_input = Input(shape=(self.num_languages,), name='language_input')

        # Merge the audio representation with the language vector.
        merged = tf.keras.layers.concatenate([x, language_input])

        # Fully connected head.
        x = Dense(128, activation='relu', kernel_regularizer=regularizer)(merged)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

        # Output layer.
        output = Dense(self.num_emotions, activation='softmax', name='emotion_output')(x)

        return Model(inputs=[audio_input, language_input], outputs=output)

    def _build_audio_model(self, lstm_units, dropout_rate, regularizer):
        """Build the audio-only ``Sequential`` network."""
        return Sequential([
            # An explicit Input layer replaces the deprecated ``input_shape=``
            # keyword on the first layer.
            Input(shape=self.input_shape),

            Bidirectional(LSTM(lstm_units, return_sequences=True,
                               kernel_regularizer=regularizer,
                               recurrent_regularizer=regularizer)),
            BatchNormalization(),
            Dropout(dropout_rate),

            Bidirectional(LSTM(lstm_units, return_sequences=False,
                               kernel_regularizer=regularizer,
                               recurrent_regularizer=regularizer)),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(128, activation='relu', kernel_regularizer=regularizer),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(64, activation='relu', kernel_regularizer=regularizer),
            BatchNormalization(),
            Dropout(dropout_rate),

            Dense(self.num_emotions, activation='softmax'),
        ])

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def build_model(self, lstm_units=128, dropout_rate=0.3, regularization_rate=0.001):
        """Build the LSTM network and store it in ``self.model``.

        Args:
            lstm_units: Number of units in each LSTM layer.
            dropout_rate: Drop rate of the Dropout layers.
            regularization_rate: L2 regularization coefficient.

        Returns:
            The constructed (not yet compiled) model.
        """
        # One regularizer instance is shared by all regularized layers.
        regularizer = l2(regularization_rate)

        if self._uses_language_input():
            self.model = self._build_language_model(lstm_units, dropout_rate, regularizer)
        else:
            self.model = self._build_audio_model(lstm_units, dropout_rate, regularizer)

        return self.model

    def compile_model(self, learning_rate=0.001):
        """Compile the model with Adam, categorical cross-entropy and accuracy.

        Args:
            learning_rate: Learning rate passed to the Adam optimizer.

        Raises:
            RuntimeError: If ``build_model`` has not been called yet.
        """
        model = self._require_model()
        optimizer = Adam(learning_rate=learning_rate)
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )

    def get_callbacks(self, checkpoint_path=None):
        """Create the list of training callbacks.

        Early stopping and learning-rate reduction (both monitoring
        ``val_loss``) are always included. A ``ModelCheckpoint`` is added when
        ``checkpoint_path`` is truthy.

        Args:
            checkpoint_path: Where to save the best checkpoint. When
                ``self.save_weights_only`` is true and the path does not end in
                ``.weights.h5``, its extension is replaced by ``.weights.h5``.

        Returns:
            The list of callbacks.
        """
        callbacks = [
            # Early stopping.
            EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
            ),
            # Learning-rate decay on plateau.
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=5,
                min_lr=1e-6,
            ),
        ]

        if checkpoint_path:
            # Accept os.PathLike values as well as plain strings.
            checkpoint_path = os.fspath(checkpoint_path)

            # Weights-only checkpoints must use the ``.weights.h5`` suffix.
            if self.save_weights_only and not checkpoint_path.endswith('.weights.h5'):
                checkpoint_path = os.path.splitext(checkpoint_path)[0] + '.weights.h5'

            callbacks.append(
                ModelCheckpoint(
                    filepath=checkpoint_path,
                    monitor='val_loss',
                    save_best_only=True,
                    save_weights_only=self.save_weights_only,
                    mode='min',
                    verbose=1,
                )
            )

        return callbacks

    def train(self, X_train, y_train, X_val, y_val, epochs=100, batch_size=32,
              checkpoint_path='best_model.h5', language_train=None, language_val=None):
        """Train the model.

        Args:
            X_train: Training features.
            y_train: Training labels.
            X_val: Validation features.
            y_val: Validation labels.
            epochs: Number of training epochs.
            batch_size: Batch size.
            checkpoint_path: Checkpoint path forwarded to ``get_callbacks``.
            language_train: Training language data (when language is included).
            language_val: Validation language data (when language is included).

        Returns:
            The history object returned by ``model.fit``.

        Raises:
            RuntimeError: If ``build_model`` has not been called yet.
            ValueError: If the model is configured with a language input but
                ``language_train`` or ``language_val`` is missing.
        """
        model = self._require_model()
        # Validate the inputs first so misuse fails before any training setup.
        with_language = self._wants_language(language_train, language_val)
        callbacks = self.get_callbacks(checkpoint_path)

        if with_language:
            train_inputs = [X_train, language_train]
            val_inputs = [X_val, language_val]
        else:
            train_inputs = X_train
            val_inputs = X_val

        return model.fit(
            train_inputs, y_train,
            validation_data=(val_inputs, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1,
        )

    def evaluate(self, X_test, y_test, language_test=None):
        """Evaluate the model on a test set.

        Args:
            X_test: Test features.
            y_test: Test labels.
            language_test: Test language data (when language is included).

        Returns:
            Whatever ``model.evaluate`` returns (loss and accuracy for a model
            compiled with ``compile_model``).

        Raises:
            RuntimeError: If ``build_model`` has not been called yet.
            ValueError: If the model is configured with a language input but
                ``language_test`` is missing.
        """
        model = self._require_model()
        if self._wants_language(language_test):
            return model.evaluate([X_test, language_test], y_test)
        return model.evaluate(X_test, y_test)

    def predict(self, X, language=None):
        """Predict emotion probabilities.

        Args:
            X: Input features.
            language: Language data (when language is included).

        Returns:
            Whatever ``model.predict`` returns.

        Raises:
            RuntimeError: If ``build_model`` has not been called yet.
            ValueError: If the model is configured with a language input but
                ``language`` is missing.
        """
        model = self._require_model()
        if self._wants_language(language):
            return model.predict([X, language])
        return model.predict(X)

    def save(self, filepath):
        """Save the underlying Keras model via ``model.save``.

        Args:
            filepath: Destination path.

        Raises:
            RuntimeError: If ``build_model`` has not been called yet.
        """
        self._require_model().save(filepath)

    @classmethod
    def load(cls, filepath, custom_objects=None):
        """Load a saved model from disk.

        Note that this returns the object produced by
        ``tf.keras.models.load_model`` directly, not an ``EmotionModel``
        instance.

        Args:
            filepath: Path of the saved model file.
            custom_objects: Optional custom objects forwarded to the loader.

        Returns:
            The loaded Keras model.
        """
        model = tf.keras.models.load_model(filepath, custom_objects=custom_objects)
        return model
|
||||
Reference in New Issue
Block a user