
Week T7: Coffee Bean Recognition

  •    🍨 This article is a learning-record blog post from the 🔗365-day deep learning training camp
  •    🍖 Original author: K同学啊

Analysis of the strengths and weaknesses of the VGG network:

● Strengths:

  • Simple, uniform structure: the entire network uses only 3×3 convolution kernels and 2×2 max pooling, making it easy to understand and implement.

  • Stable, reliable performance: it performs well on many image recognition tasks and remains one of the classic architectures for deep learning beginners and industrial deployment.

● Weaknesses:

  1. Large parameter count: VGG-16 has over 100 million parameters and a large model size (the weight file is roughly 500 MB), so it is poorly suited to embedded or mobile deployment (see the parameter-count check after this list).

  2. Long training time: because the network is deep, training takes a long time and demands substantial computing resources.

  3. Hard to tune: there are no skip connections, so deep variants can suffer from vanishing gradients and tuning is relatively involved.
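As a quick sanity check of the parameter-count claim above, one can instantiate Keras's built-in VGG-16 and count its parameters (a minimal sketch; the standard 1000-class ImageNet head is used purely for comparison and is not part of this week's code):

from tensorflow.keras.applications import VGG16 as KerasVGG16

# Build the reference VGG-16 architecture; weights are not needed just to count parameters
ref_model = KerasVGG16(weights=None, include_top=True)  # default 224x224x3 input, 1000-class head
print("VGG-16 total parameters:", ref_model.count_params())  # roughly 138 million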

I. Preliminary Work

1. Configure the GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  # Allocate GPU memory on demand
    tf.config.set_visible_devices([gpus[0]], "GPU")
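To confirm that TensorFlow can actually see the GPU after this configuration, a quick optional check is:

print(tf.config.list_logical_devices("GPU"))  # should list at least one logical GPU device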

2. Import the Data

 

from tensorflow       import keras
from tensorflow.keras import layers, models
import numpy             as np
import matplotlib.pyplot as plt
import os, PIL, pathlib

data_dir = "../data/end_data"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.png')))
print("Total number of images:", image_count)

II. Data Preprocessing

1. Load the Data

batch_size = 8
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)
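To double-check the 80/20 split, one can optionally print how many batches ended up in each subset (a small sketch using tf.data's cardinality; not part of the original walkthrough):

print("Training batches:  ", tf.data.experimental.cardinality(train_ds).numpy())
print("Validation batches:", tf.data.experimental.cardinality(val_ds).numpy())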

2. Visualize the Data

plt.figure(figsize=(10, 4))  # Figure width 10, height 4

for images, labels in train_ds.take(1):
    for i in range(10):
        ax = plt.subplot(2, 5, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

3. Configure the Dataset

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))
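After the Rescaling layer is applied, pixel values should fall in [0, 1]; a quick optional check on one batch (assuming the normalized train_ds above):

image_batch, labels_batch = next(iter(train_ds))
first_image = image_batch[0]
print(np.min(first_image), np.max(first_image))  # expected to be roughly 0.0 and 1.0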

III. Build the VGG-16 Network

1. Build the Model Manually

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

def VGG16(nb_classes, input_shape):
    input_tensor = Input(shape=input_shape)

    # 1st block
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # 2nd block
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # 3rd block
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # 4th block
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # 5th block
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    # fully connected layers
    x = Flatten()(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)

    model = Model(input_tensor, output_tensor)
    return model

model = VGG16(len(class_names), (img_width, img_height, 3))
model.summary()

IV. Compile the Model

# Set the initial learning rate
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=30,      # Note: decay_steps counts training steps, not epochs
    decay_rate=0.92,     # After each decay the lr becomes decay_rate * lr
    staircase=True)

# Set the optimizer
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
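As written, lr_schedule is created but the optimizer is given the fixed initial_learning_rate, so the exponential decay never takes effect. If you want the schedule applied, one option (a small sketch, not the original code) is to pass the schedule object directly:

opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)  # optimizer now follows the decay schedule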

V. Train the Model

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
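If you want to keep the best weights seen during training, a ModelCheckpoint callback can be added to the fit call (an optional sketch; the file name best_model.h5 is a hypothetical path, not part of the original walkthrough):

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "best_model.h5",          # hypothetical output path
    monitor="val_accuracy",
    save_best_only=True)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[checkpoint_cb])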

VI. Visualize the Results

from datetime import datetime
current_time = datetime.now()  # Get the current time

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)  # Include a timestamp when checking in, otherwise the code screenshot is invalid

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
