当前位置: 首页 > news >正文

python pytorch tensorflow transforms 模型培训脚本

环境准备
https://www.doubao.com/thread/w5e26d6401c003bb2

执行培训脚本

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW
import numpy as np


# 自定义数据集类
class SentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


# 新增问答数据
qa_texts = ["李白是那个朝代的诗人?", "地球的卫星是什么?", "中国的首都是哪里?", "水的化学式是什么?", "苹果公司的创始人是谁?"]
qa_labels = [1, 2, 3, 4, 5]  # 为每个问题分配一个唯一的标签

# 合并数据
all_texts = qa_texts
all_labels = qa_labels

# 初始化 tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# 创建数据集和数据加载器
dataset = SentimentDataset(all_texts, all_labels, tokenizer, max_length=128)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 初始化模型,注意 num_labels 需要根据总标签数调整
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=8)

# 定义优化器,加入 weight_decay 进行正则化
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 早停策略相关参数
best_loss = float('inf')
patience = 3
counter = 0

for epoch in range(20):  # 增加训练轮数,但结合早停策略
    total_loss = 0
    model.train()
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, Loss: {avg_loss}')

    # 早停策略
    if avg_loss < best_loss:
        best_loss = avg_loss
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break

# 处理提问的函数
def handle_query(model, tokenizer, query_text, device):
    model.eval()
    with torch.no_grad():
        encoding = tokenizer(query_text, return_tensors='pt', padding=True, truncation=True, max_length=128).to(device)
        logits = model(**encoding).logits
        predictions = np.argmax(logits.cpu().numpy(), axis=1)
        return predictions

# 情感标签映射
label_mapping = {
    1: "唐朝",
    2: "月球",
    3: "北京",
    4: "H2O",
    5: "史蒂夫·乔布斯、史蒂夫·沃兹尼亚克和罗恩·韦恩"
}

# 针对新增的5条数据提问
query_text = ["李白是那个朝代的诗人?", "地球的卫星是什么?", "中国的首都是哪里?", "水的化学式是什么?", "苹果公司的创始人是谁?"]
result = handle_query(model, tokenizer, query_text, device)
readable_result = [label_mapping[pred] for pred in result]
print("Query result:", readable_result)


相关文章:

  • dijkstra算法——47. 参加科学大会
  • VulnHub-matrix-breakout-2-morpheus通关攻略
  • 2025年人工智能、数字媒体技术与社会计算国际学术会议
  • Python字符串格式化全面指南:f-string与常用方法详解
  • pyqt 按钮自动布局方案
  • Hadoop•常用命令
  • LS-NET-006-思科MDS 9148S 查看内存
  • Python:多态,静态方法和类方法
  • golang 生成单元测试报告
  • 目标检测——清洗数据
  • Java 填充 PDF 模版
  • Python个人学习笔记(18):模块(异常处理、traceback、日志记录)
  • MAC-在使用@Async注解的方法时,分布式锁管理和释放
  • STM32原理性知识
  • 一种基于大规模语言模型LLM的数据分析洞察生成方法
  • 如何在 Node.js 中使用 .env 文件管理环境变量 ?
  • Rust嵌入式开发环境搭建指南(基于Stm32+Vscode)
  • ASP3605同步降压调节器——满足汽车电子严苛要求的电源芯片方案
  • 数学之握手问题
  • Java替换jar包中class文件
  • 戴昕谈隐私、数据、声誉与法律现实主义
  • 公安部知识产权犯罪侦查局:侦破盗录传播春节档院线电影刑案25起
  • 国家发改委党组在《人民日报》发表署名文章:新时代新征程民营经济发展前景广阔大有可为
  • 三亚亚龙湾3.4公里岸线近岸海域使用权挂牌出让,起始价近九千万
  • 漫画阅读APP刊载1200余部侵权作品:20人获刑,案件罚金超千万元
  • 李公明︱一周书记:大学的价值、韧性以及……不相称的对抗