当前位置: 首页 > news >正文

基于BERT的序列到序列(Seq2Seq)模型,生成文本摘要或标题

  1. 数据预处理

    • 使用DataGenerator类加载并预处理数据,处理变长序列的padding。
    • 输入为内容(content),目标为标题(title)。
  2. 模型构建

    • 基于BERT构建Seq2Seq模型,使用交叉熵损失。
    • 采用Beam Search进行生成:每一步为每条候选序列扩展概率最高的beam_size个词,并按长度归一化后的得分保留最优的beam_size条序列。
  3. 训练与评估

    • 使用Adam优化器进行训练。
    • 每个epoch结束时通过AdvancedEvaluate回调生成示例标题以观察效果,并在验证集上计算ROUGE-L。
import numpy as np
import pandas as pd
from tqdm import tqdm
from bert4keras.bert import build_bert_model
from bert4keras.tokenizer import Tokenizer, load_vocab
from keras.layers import *
from keras.models import Model
from keras import backend as K
from bert4keras.snippets import parallel_apply
from keras.optimizers import Adam
import keras
import math
from sklearn.model_selection import train_test_split
from rouge import Rouge  # 需要安装rouge包

# Configuration

# Pretrained BERT files: vocabulary, model config and checkpoint weights.
dict_path = 'bert/chinese_L-12_H-768_A-12/vocab.txt'
config_path = 'bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'bert/chinese_L-12_H-768_A-12/bert_model.ckpt'

# Length limits for the source text (content) and the generated title.
max_input_len, max_output_len = 256, 32

# Training settings.
epochs = 10
batch_size = 16
learning_rate = 2e-5
val_split = 0.1  # fraction of the data held out for validation

# Number of hypotheses kept during beam-search decoding.
beam_size = 3

# Data pipeline: converts (content, title) rows into padded token-id batches.
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields padded seq2seq batches from a DataFrame.

    Every batch has the form
    ``([encoder_ids, decoder_input_ids], decoder_target_ids)``.

    The class reads the module-level ``tokenizer``, ``max_input_len`` and
    ``max_output_len``, so those must be defined before the first batch is
    requested.
    """

    def __init__(self, data, batch_size=8, mode='train'):
        """Store the data and build the row order.

        Args:
            data: DataFrame with string columns ``content`` and ``title``.
            batch_size: number of rows per batch.
            mode: ``'train'`` reshuffles the row order after every epoch;
                any other value keeps the order fixed.
        """
        super().__init__()
        self.batch_size = batch_size
        self.mode = mode
        self.data = data
        self.indices = np.arange(len(data))

    def __len__(self):
        """Return the number of batches per epoch (last one may be short)."""
        return math.ceil(len(self.data) / self.batch_size)

    def __getitem__(self, index):
        """Return batch number ``index`` as ``([x, dec_in], dec_out)``."""
        start = index * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        return self._process_batch(self.data.iloc[batch_indices])

    def on_epoch_end(self):
        """Reshuffle the row order between epochs when in training mode."""
        if self.mode == 'train':
            np.random.shuffle(self.indices)

    def _process_batch(self, batch):
        """Tokenize and pad one slice of the DataFrame.

        Args:
            batch: DataFrame slice with ``content`` and ``title`` columns.

        Returns:
            ``([padded_x, padded_decoder_input], padded_decoder_output)``;
            ``padded_x`` has width ``max_input_len`` and both decoder arrays
            have width ``max_output_len``.
        """
        encoder_ids, decoder_inputs, decoder_outputs = [], [], []
        for content, title in zip(batch['content'], batch['title']):
            content = content[:max_input_len]
            # Leave room for the two special tokens ([CLS] and [SEP]).
            title = title[:max_output_len - 2]

            x, _ = tokenizer.encode(content, max_length=max_input_len)
            y, _ = tokenizer.encode(title, max_length=max_output_len)

            encoder_ids.append(x)
            # Teacher forcing: the input is the target shifted right by one
            # position, with the start id in front.
            decoder_inputs.append([tokenizer._token_start_id] + y[:-1])
            decoder_outputs.append(y)

        # Pad on the right. beam_search() feeds the encoder unpadded ids that
        # start at position 0; left-padding here would place the same tokens
        # at different positions during training than during inference.
        padded_x = self._pad_sequences(
            encoder_ids, maxlen=max_input_len, padding='post'
        )
        padded_decoder_input = self._pad_sequences(
            decoder_inputs, maxlen=max_output_len, padding='post'
        )
        padded_decoder_output = self._pad_sequences(
            decoder_outputs, maxlen=max_output_len, padding='post'
        )

        return [padded_x, padded_decoder_input], padded_decoder_output

    def _pad_sequences(self, sequences, maxlen, padding='pre'):
        """Zero-pad (or truncate) every sequence to exactly ``maxlen`` items.

        Args:
            sequences: list of id sequences.
            maxlen: width of the returned array; longer sequences keep only
                their first ``maxlen`` items.
            padding: ``'pre'`` right-aligns each sequence (zeros in front);
                any other value left-aligns it (zeros behind).

        Returns:
            Array of shape ``(len(sequences), maxlen)``.
        """
        padded = np.zeros((len(sequences), maxlen))
        for i, seq in enumerate(sequences):
            seq = seq[:maxlen]
            if len(seq) == 0:
                # ``-0:`` would address the whole row and the assignment
                # would fail, so an empty sequence just stays all zeros.
                continue
            if padding == 'pre':
                padded[i, -len(seq):] = seq
            else:
                padded[i, :len(seq)] = seq
        return padded

# Model definition: BERT encoder feeding a BERT language-model decoder.
def build_seq2seq_model():
    """Build and compile the encoder-decoder model.

    Both halves are created with ``build_bert_model`` from the files named by
    the module-level ``config_path`` / ``checkpoint_path``. The decoder is
    called on ``[decoder_inputs, encoder_outputs]``.

    Returns:
        A compiled Keras ``Model`` whose inputs are
        ``[Encoder-Input, Decoder-Input]`` and whose output is the decoder
        output, trained with Adam at the module-level ``learning_rate``.
    """
    # Encoder
    encoder_inputs = Input(shape=(None,), name='Encoder-Input')
    encoder = build_bert_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model='encoder',
        return_keras_model=False,
    )
    encoder_outputs = encoder(encoder_inputs)

    # Decoder
    decoder_inputs = Input(shape=(None,), name='Decoder-Input')
    decoder = build_bert_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model='decoder',
        application='lm',
        return_keras_model=False,
    )
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])

    # Full model
    model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

    def seq2seq_loss(y_true, y_pred):
        """Mean cross-entropy over target positions whose id is not 0."""
        # Id 0 is what the data generator pads with, so it is masked out.
        y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
        # The decoder output is consumed as probabilities elsewhere in this
        # file (beam_search takes np.log of it directly), so no softmax is
        # applied again here.
        # NOTE(review): assumes the 'lm' head emits softmax probabilities --
        # confirm against the installed bert4keras version.
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        # Clamp the count so an all-padding batch gives 0 instead of NaN.
        return K.sum(loss * y_mask) / K.maximum(K.sum(y_mask), 1.0)

    model.compile(Adam(learning_rate), loss=seq2seq_loss)
    return model

# Beam-search decoding.
def beam_search(model, input_seq, beam_size=3):
    """Generate a title for ``input_seq`` with beam search.

    Uses the module-level ``tokenizer``, ``max_input_len`` and
    ``max_output_len``. The encoder is looked up as layer ``'bert'`` and the
    decoder as layer ``'bert_1'`` of ``model``.

    Args:
        model: the trained seq2seq model.
        input_seq: source text to summarise.
        beam_size: number of hypotheses kept after every step, and number of
            continuations tried per hypothesis. Must be at least 1.

    Returns:
        The decoded text of the hypothesis with the best length-normalised
        log-probability.

    Raises:
        ValueError: if ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        # ``[-0:]`` below would otherwise select the entire vocabulary.
        raise ValueError(f"beam_size must be >= 1, got {beam_size}")

    start_id = tokenizer._token_start_id
    end_id = tokenizer._token_end_id

    # Truncate exactly as the training generator does, so inference never
    # sees a longer input than the model was trained on.
    encoder_input = tokenizer.encode(
        input_seq[:max_input_len], max_length=max_input_len
    )[0]
    encoder_output = model.get_layer('bert').predict(np.array([encoder_input]))
    decoder = model.get_layer('bert_1')

    sequences = [([start_id], 0.0)]
    for _ in range(max_output_len):
        # Nothing left to extend once every hypothesis has ended.
        if all(seq[-1] == end_id for seq, _score in sequences):
            break

        all_candidates = []
        for seq, score in sequences:
            if seq[-1] == end_id:
                # Finished hypotheses compete unchanged.
                all_candidates.append((seq, score))
                continue

            # Distribution over the vocabulary for the next position.
            step_probs = decoder.predict(
                [np.array([seq]), encoder_output]
            )[0, -1, :]

            for token in np.argsort(step_probs)[-beam_size:]:
                # Floor the probability so a zero does not yield -inf.
                log_prob = np.log(np.maximum(step_probs[token], 1e-12))
                all_candidates.append((seq + [int(token)], score + log_prob))

        # Length normalisation: rank by average log-probability per token.
        all_candidates.sort(key=lambda c: c[1] / len(c[0]), reverse=True)
        sequences = all_candidates[:beam_size]

    # Drop the leading start id; drop the end id only when it is really
    # there (a hypothesis cut off by the length limit has none).
    best_seq = sequences[0][0][1:]
    if best_seq and best_seq[-1] == end_id:
        best_seq = best_seq[:-1]
    return tokenizer.decode(best_seq)

# Evaluation callback: prints sample generations and validation ROUGE-L.
class AdvancedEvaluate(keras.callbacks.Callback):
    """After every epoch, print example titles and the validation ROUGE-L.

    Generation goes through the module-level ``beam_search``. The examples
    use the module-level ``beam_size``; the full validation pass uses a beam
    of 1.
    """

    def __init__(self, val_data, sample_size=5):
        """Pick the fixed set of example rows.

        Args:
            val_data: DataFrame with ``content`` and ``title`` columns.
            sample_size: number of rows to show as examples; capped at the
                number of rows available.
        """
        super().__init__()
        self.val_data = val_data
        self.rouge = Rouge()
        # sample() raises when asked for more rows than exist.
        self.samples = val_data.sample(min(sample_size, len(val_data)))

    @staticmethod
    def _to_rouge_text(text):
        """Return ``text`` as space-separated characters.

        The rouge package tokenises on whitespace, so an unsegmented Chinese
        string would count as one single token.
        """
        return ' '.join(''.join(str(text).split()))

    def on_epoch_end(self, epoch, logs=None):
        """Print example generations, then the mean ROUGE-L F-score."""
        print("\n生成示例:")
        for _, row in self.samples.iterrows():
            generated = beam_search(self.model, row['content'], beam_size)
            print(f"真实标题: {row['title']}")
            print(f"生成标题: {generated}\n")

        # Score the whole validation set.
        references = []
        hypotheses = []
        for _, row in self.val_data.iterrows():
            generated = beam_search(self.model, row['content'], beam_size=1)
            references.append(self._to_rouge_text(row['title']))
            hypotheses.append(self._to_rouge_text(generated))

        # rouge raises on empty strings, so such pairs are left out of the
        # call and counted as a score of 0 in the average instead.
        scorable = [(h, r) for h, r in zip(hypotheses, references) if h and r]
        if scorable:
            per_pair = self.rouge.get_scores(
                [h for h, _ in scorable], [r for _, r in scorable]
            )
            rouge_l = sum(s['rouge-l']['f'] for s in per_pair) / len(hypotheses)
        else:
            rouge_l = 0.0
        print(f"验证集ROUGE-L: {rouge_l:.4f}")

# Entry point: load the data, build the model and train it.
if __name__ == "__main__":
    # Load the data. Every field is read as text because the generator
    # slices title/content as strings; rows missing either field are dropped
    # for the same reason.
    full_data = pd.read_csv(
        'train.tsv', sep='\t', names=['title', 'content'], dtype=str
    )
    full_data = full_data.dropna(subset=['title', 'content'])
    train_data, val_data = train_test_split(full_data, test_size=val_split)

    # The tokenizer is a module-level name on purpose: DataGenerator and
    # beam_search look it up as a global.
    tokenizer = Tokenizer(dict_path, do_lower_case=True)

    # Build the model
    model = build_seq2seq_model()
    model.summary()

    # Batch generators
    train_gen = DataGenerator(train_data, batch_size, mode='train')
    val_gen = DataGenerator(val_data, batch_size, mode='val')

    # Callbacks
    callbacks = [
        AdvancedEvaluate(val_data),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1),
        keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
    ]

    # Train. Worker threads are used instead of worker processes: a process
    # started with the "spawn" method re-imports this file without running
    # this block, so the global ``tokenizer`` would not exist there.
    model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        workers=4,
        use_multiprocessing=False
    )

相关文章:

  • vue3,element-plus 表格搜索过滤数据
  • 高效团队开发的工具与方法 引言
  • MySQL是怎么保障ACID特性的
  • Docker 容器基础技术:namespace
  • Python----计算机视觉处理(Opencv:直方图均衡化)
  • 本周安全速报(2025.3.18~3.24)
  • GeoServer与MapServer:两款常用的开源地理空间服务器
  • 通过 ECNWP 洋流、海浪可视化开发的方法和架构
  • 破局AI落地困局 亚信科技“四位一体手术刀“切开产业智能三重枷锁
  • 【嵌入式学习2】内存管理
  • Android Compose 框架的 ViewModel 委托深入剖析(二十)
  • 小试牛刀-Turbine数据分发
  • FPGA_YOLO(二)
  • Python Web 框架之 Flask
  • 全文通读:126页华为IPD集成产品开发与DFX实战【文末附可编辑PPT下载链接】
  • spring+k8s 功能说明
  • Android Compose 框架派生状态(derivedStateOf、rememberCoroutineScope)深入剖析(十五)
  • Qt进程间通信:QSharedMemory 使用详解
  • 2025年陕西省各市秦创原产业创新聚集区(机器人、羊乳、苹果)“四链”融合项目申报补贴要求和时间流程
  • 【STM32】第一个工程的创建
  • 新华每日电讯:从上海街区经济看账面、市面、人面、基本面
  • 新造古镇丨上海古镇朱家角一年接待164万境外游客,凭啥?
  • 上海市政府常务会议研究抓好稳就业稳企业稳市场稳预期工作,让企业感受温度
  • 演员孙俪:中年人没有脆弱的时间,学习胡曼黎不内耗
  • 生态环境法典草案拟初审:应对气候变化等问题将作原则性规定
  • 体坛联播|卡马文加预计伤缺三个月,阿尔卡拉斯因伤退赛