当前位置: 首页 > news >正文

Tensorflow实现用接口调用模型训练和停止训练功能

语言:Python
框架:Flask、Tensorflow
功能描述:存在两个接口,一个接口实现开始训练模型的功能,一个接口实现停止训练的功能。
实现:用一个全局变量存储在训练中的模型。

# 存储所有训练任务
training_tasks = {}
# 训练模型的接口
@train_model.route("/train", methods=["POST"])
def train():try:data = request.get_data()data = json.loads(data)print(data)modelId = data["modelId"]if modelId in training_tasks:return {"success": False, "message": f"{modelId} 已经在训练中"}stop_event = threading.Event()# 在新线程中启动训练train_thread = threading.Thread(target=start_train,args=(data, stop_event))training_tasks[modelId] = {'thread': train_thread,'stop_event': stop_event}train_thread.start()return {"success": "success", "message": "开始训练"}except Exception as e:return  {"success": False, "message": str(e)}
def start_train(data, stop_event):try:# 获取任务参数modelId = data["modelId"]except Exception as e:response_data = {"success": False, "message": str(e)}return response_data
class StopTrainingCallback(keras.callbacks.Callback):def __init__(self, model, modelId, stop_event):super().__init__()self.model = modelself.modelId = modelIdself.stop_event = stop_eventdef on_train_begin(self, logs=None):if self.stop_event.is_set():self.model.stop_training = True # 设置此标志会使model.fit提前终止print(f"训练在开始前被停止")def on_batch_begin(self, batch, logs=None):if self.stop_event.is_set():self.model.stop_training = True # 设置此标志会使model.fit提前终止print(f"训练在批次被停止")# 强制抛出一个异常以确保立即停止raise KeyboardInterrupt("训练被用户停止")
# 模型真正训练的函数
def start_train(data, stop_event):# 定义模型及训练数据model = "xxx"modelId = "xxx"train_dataset = "xxx"test_dataset = "xxx"train_steps = len(list(train_dataset))test_steps = len(list(test_dataset))epochs = "xxx"stoptrainingcallback = StopTrainingCallback(model, modelId, stop_event)try:# 在开始训练前立即检查停止事件if stop_event.is_set():log.info(f"训练 {modelId} 在开始前被停止")callback_log.info("模型训练在开始前被停止")raise KeyboardInterrupt("Training stopped before start")model.fit(train_dataset,steps_per_epoch=train_steps,epochs=epochs,verbose=2,shuffle=True,validation_data=test_dataset,validation_steps=test_steps,callbacks=[stoptrainingcallback])response_data = {"success": True, "message": "Success"}except KeyboardInterrupt:response_data = {"success": False, "message": "模型训练被用户停止."}except tf.errors.ResourceExhaustedError as e:# 显存不足错误response_data = {"success": False, "message": "GPU内存不足,请调整训练参数."}except Exception as e:print("模型训练失败")response_data = {"success": False, "message": str(e)}finally:if data["modelId"] in training_tasks:del training_tasks[data["modelId"]]return response_data
# 停止训练的接口
@stop_train.route('/stop', methods=['POST'])
def stop():data = request.get_data()try:data = json.loads(data)modelId = data.get("modelId",'') # 每个模型有一个唯一的UUIDif modelId == '':return jsonify({"success": False, "message": "modelId为空,无法停止训练.", "data": ''})except Exception as e:print("停止模型训练接口请求数据出错:", str(e))return jsonify({"success": False, "message": "参数错误.", "data": ''})# 调用服务层停止训练result = stop_train_service(modelId)print(result["message"])# 返回响应return jsonify(result)
# 调用服务层停止训练
def stop_train_service(modelId):# 检查模型是否存在if modelId not in training_tasks:return {"success": "error", "message": f"没有找到模型 {modelId} 的训练任务"}# 获取停止事件并设置stop_event = training_tasks[modelId].get('stop_event')if stop_event:stop_event.set()# 清理任务记录del training_tasks[modelId]return {"success": "success", "message": f"停止 {modelId} 模型训练的请求已发送"}else:return {"success": "error", "message": f"模型 {modelId} 的停止训练事件不存在"}

相关文章:

  • openGauss基于PITR恢复测试
  • 第五章 制作工具优化
  • VUE简介
  • 【 图像梯度处理,图像边缘检测】图像处理(OpenCv)-part6
  • C++(17):通过filesystem获取文件的大小
  • electron 渲染进程按钮创建新window,报BrowserWindow is not a constructor错误;
  • 【go】什么是Go语言的GPM模型?工作流程?为什么Go语言中的GMP模型需要有P?
  • 好数对的数目
  • MySQL事务详解
  • C#如何动态生成实体类?5种方法详解与实战演示
  • 《TIME-LLM: TIME SERIES FORECASTINGBY REPROGRAMMING LARGE LANGUAGE MODELS》
  • 51单片机实验三:数码管动态显示
  • 游戏引擎学习第233天
  • 基于Redis的4种延时队列实现方式
  • AI数据分析与BI可视化结合:解锁企业决策新境界
  • HTML新标签与核心 API 实战
  • 杂记-LeetCode中部分题思路详解与笔记-HOT100篇-其四
  • LVGL学习(二)——控件
  • ArcPy工具箱制作(下)
  • 【Hot100】41. 缺失的第一个正数
  • 从高铁到住房:“富足议程”能否拯救美国的进步主义?
  • 中央和国家机关工委建立健全整治形式主义为基层减负长效机制
  • 姜宏出任康复大学分管日常工作的副校长,明确为正厅级
  • 分析|开门红:一季度GDP增长5.4%超预期,市场活力信心增强
  • 阿坝州市监局公布一批典型案例,有加油站篡改加油枪计量器
  • 履新正部级的李成钢,现已担任商务部党组副书记