当前位置: 首页 > news >正文

深度学习-torch,全连接神经网路

3. 数据集加载案例

通过一些数据集的加载案例,真正了解数据类及数据加载器。

3.1 加载csv数据集

代码参考如下

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
​
​
class MyCsvDataset(Dataset):def __init__(self, filename):df = pd.read_csv(filename)# 删除文字列df = df.drop(["学号", "姓名"], axis=1)# 转换为tensordata = torch.tensor(df.values)# 最后一列以前的为data,最后一列为labelself.data = data[:, :-1]self.label = data[:, -1]self.len = len(self.data)
​def __len__(self):return self.len
​def __getitem__(self, index):idx = min(max(index, 0), self.len - 1)return self.data[idx], self.label[idx]
​
​
def test001():excel_path = r"./大数据答辩成绩表.csv"dataset = MyCsvDataset(excel_path)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)for i, (data, label) in enumerate(dataloader):print(i, data, label)
​
​
if __name__ == "__main__":test001()
​

练习:上述示例数据构建器改成TensorDataset

def build_dataset(filepath):df = pd.read_csv(filepath)df.drop(columns=['学号', '姓名'], inplace=True)data = df.iloc[..., :-1]labels = df.iloc[..., -1]
​x = torch.tensor(data.values, dtype=torch.float)y = torch.tensor(labels.values)
​dataset = TensorDataset(x, y)
​return dataset
​
​
def test001():filepath = r"./大数据答辩成绩表.csv"dataset = build_dataset(filepath)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)for i, (data, label) in enumerate(dataloader):print(i, data, label)

3.2 加载图片数据集

参考代码如下:只是用于文件读取测试

import torch
from torch.utils.data import Dataset, DataLoader
import os
​
# 导入opencv
import cv2
​
​
class MyImageDataset(Dataset):def __init__(self, folder):# 文件存储路径列表self.filepaths = []# 文件对应的目录序号列表self.labels = []# 指定图片大小self.imgsize = (112, 112)# 临时存储文件所在目录名dirnames = []
​# 递归遍历目录,root:根目录路径,dirs:子目录名称,files:子文件名称for root, dirs, files in os.walk(folder):# 如果dirs和files不同时有值,先遍历dirs,然后再以dirs的目录为路径遍历该dirs下的files# 这里需要在dirs不为空时保存目录名称列表if len(dirs) > 0:dirnames = dirs
​for file in files:# 文件路径filepath = os.path.join(root, file)self.filepaths.append(filepath)# 分割root中的dir目录名classname = os.path.split(root)[-1]# 根据目录名到临时目录列表中获取下标self.labels.append(dirnames.index(classname))self.len = len(self.filepaths)
​def __len__(self):return self.len
​def __getitem__(self, index):# 获取下标idx = min(max(index, 0), self.len - 1)# 根据下标获取文件路径filepath = self.filepaths[idx]# opencv读取图片img = cv2.imread(filepath)# 图片缩放,图片加载器要求同一批次的图片大小一致img = cv2.resize(img, self.imgsize)# 转换为tensorimg_tensor = torch.tensor(img)# 将图片HWC调整为CHWimg_tensor = torch.permute(img_tensor, (2, 0, 1))# 获取目录标签label = self.labels[idx]
​return img_tensor, label
​
​
def test02():path = os.path.join(os.path.dirname(__file__), 'dataset')# 转换为相对路径path = os.path.relpath(path)dataset = MyImageDataset(path)
​dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
​for img, label in dataloader:print(img.shape)print(label)
​
​
if __name__ == "__main__":test02()
​

练习:1.重写上述代码,如果不对图片进行缩放会产生什么结果?2.在遍历图片的代码中打印图片查看图片效果(打印一批次即可)

# 导入opencv
import cv2
​
​
class MyDataset(Dataset):def __init__(self, folder):
​dirnames = []self.filepaths = []self.labels = []
​for root, dirs, files in os.walk(folder):if len(dirs) > 0:dirnames = dirs
​for file in files:filepath = os.path.join(root, file)self.filepaths.append(filepath)classname = os.path.split(root)[-1]if classname in dirnames:self.labels.append(dirnames.index(classname))else:print(f'{classname} not in {dirnames}')
​self.len = len(self.filepaths)
​def __len__(self):return self.len
​def __getitem__(self, index):idx = min(max(index, 0), self.len - 1)filepath = self.filepaths[idx]img = cv2.imread(filepath)print(img.shape)# 不做图片缩放,报:RuntimeError: stack expects each tensor to be equal size, but got [3, 1333, 2000] at entry 0 and [3, 335, 600] at entry 1img = cv2.resize(img, (112, 112))t_img = torch.tensor(img)t_img = torch.permute(t_img, (2, 0, 1))
​label = self.labels[idx]return t_img, label
​
​
def test02():path = os.path.join(os.path.dirname(__file__), 'dataset')dataset = MyDataset(path)dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
​for img, label in dataloader:
​print(img.shape, label)for i in range(img.shape[0]):im = torch.permute(img[i], (1, 2, 0))plt.imshow(im)plt.show()
​break
​
​
if __name__ == "__main__":test02()

优化:使用ImageFolder加载图片集

ImageFolder 会根据文件夹的结构来加载图像数据。它假设每个子文件夹对应一个类别,文件夹名称即为类别名称。例如,一个典型的文件夹结构如下:

root/class1/img1.jpgimg2.jpg...class2/img1.jpgimg2.jpg......

在这个结构中:

  • root 是根目录。

  • class1class2 等是类别名称。

  • 每个类别文件夹中的图像文件会被加载为一个样本。

ImageFolder构造函数如下:

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, is_valid_file=None)

参数解释

  • root:字符串,指定图像数据集的根目录。

  • transform:可选参数,用于对图像进行预处理。通常是一个 torchvision.transforms 的组合。

  • target_transform:可选参数,用于对目标(标签)进行转换。

  • is_valid_file:可选参数,用于过滤无效文件。如果提供,只有返回 True 的文件才会被加载。

import torch
from torchvision import datasets, transforms
import os
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
​
torch.manual_seed(42)
​
def load():path = os.path.join(os.path.dirname(__file__), 'dataset')print(path)
​transform = transforms.Compose([transforms.Resize((112, 112)),transforms.ToTensor()])
​dataset = datasets.ImageFolder(path, transform=transform)dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
​for x,y in dataloader:x = x.squeeze(0).permute(1, 2, 0).numpy()plt.imshow(x)plt.show()print(y[0])break
​
​
if __name__ == '__main__':load()
​

3.3 加载官方数据集

在 PyTorch 中官方提供了一些经典的数据集,如 CIFAR-10、MNIST、ImageNet 等,可以直接使用这些数据集进行训练和测试。

数据集:Datasets — Torchvision 0.21 documentation

常见数据集:

  • MNIST: 手写数字数据集,包含 60,000 张训练图像和 10,000 张测试图像。

  • CIFAR10: 包含 10 个类别的 60,000 张 32x32 彩色图像,每个类别 6,000 张图像。

  • CIFAR100: 包含 100 个类别的 60,000 张 32x32 彩色图像,每个类别 600 张图像。

  • COCO: 通用对象识别数据集,包含超过 330,000 张图像,涵盖 80 个对象类别。

torchvision.transforms 和 torchvision.datasets 是 PyTorch 中处理计算机视觉任务的两个核心模块,它们为图像数据的预处理和标准数据集的加载提供了强大支持。

transforms 模块提供了一系列用于图像预处理的工具,可以将多个变换组合成处理流水线。

datasets 模块提供了多种常用计算机视觉数据集的接口,可以方便地下载和加载。

参考如下:

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets
​
​
def test():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.MNIST(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=8, shuffle=True)for x, y in trainloader:print(x.shape)print(y)break
​# 测试数据集data_test = datasets.MNIST(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=8, shuffle=True)for x, y in testloader:print(x.shape)print(y)break
​
​
def test006():transform = transforms.Compose([transforms.ToTensor(),])# 训练数据集data_train = datasets.CIFAR10(root="./data",train=True,download=True,transform=transform,)trainloader = DataLoader(data_train, batch_size=4, shuffle=True, num_workers=2)for x, y in trainloader:print(x.shape)print(y)break# 测试数据集data_test = datasets.CIFAR10(root="./data",train=False,download=True,transform=transform,)testloader = DataLoader(data_test, batch_size=4, shuffle=False, num_workers=2)for x, y in testloader:print(x.shape)print(y)break
​
​
if __name__ == "__main__":test()test006()
​

1. 神经网络基础

1.1 生物神经元与人工神经元

神经网络的设计灵感来源于生物神经元。生物神经元通过树突接收信号,细胞核处理信号,轴突传递信号,突触连接不同的神经元。人工神经元模仿了这一过程,接收多个输入信号,经过加权求和和非线性激活函数处理后,输出结果。

1.2 人工神经元的组成

人工神经元由以下几个部分组成:

  • 输入(Inputs)​:代表输入数据,通常用向量表示。
  • 权重(Weights)​:每个输入数据都有一个权重,表示该输入对最终结果的重要性。
  • 偏置(Bias)​:一个额外的可调参数,用于调整模型的输出。
  • 加权求和:将输入乘以对应的权重后求和,再加上偏置。
  • 激活函数(Activation Function)​:将加权求和后的结果转换为输出结果,引入非线性特性。

数学表示如下:

其中,σ(z) 是激活函数。


2. 神经网络结构

2.1 基本结构

神经网络由以下三层构成:

  • 输入层(Input Layer)​:接收外部数据,不进行计算。
  • 隐藏层(Hidden Layer)​:位于输入层和输出层之间,进行特征提取和转换。隐藏层可以有多层,每层包含多个神经元。
  • 输出层(Output Layer)​:产生最终的预测结果或分类结果。

2.2 全连接神经网络

全连接神经网络(Fully Connected Neural Network,FCNN)是前馈神经网络的一种,每一层的神经元与上一层的所有神经元全连接。全连接神经网络常用于图像分类、文本分类等任务。

2.2.1 特点
  • 权重数量大:由于全连接的特点,权重数量较大,计算量大。
  • 学习能力强:能够学习输入数据的全局特征,但对高维数据的局部特征捕捉能力较弱。
2.2.2 计算步骤
  1. 数据传递:输入数据逐层传递到输出层。
  2. 激活函数:每一层的输出通过激活函数处理。
  3. 损失计算:计算预测值与真实值之间的差距。
  4. 反向传播:通过反向传播算法更新权重以最小化损失。

3. 激活函数

激活函数在神经网络中引入非线性,使网络能够处理复杂的任务。以下是几种常见的激活函数及其特点。

3.1 Sigmoid

3.1.1 公式

3.1.2 特点
  • 将输入映射到 (0, 1) 之间,适合处理概率问题。
  • 梯度消失问题严重,容易导致训练速度变慢。
  • 计算成本较高。
3.1.3 应用场景
  • 一般用于二分类问题的输出层。

3.2 Tanh

3.2.1 公式

3.2.2 特点
  • 输出范围为 (-1, 1),是零中心的,有助于加速收敛。
  • 对称性较好,适合隐藏层。
  • 仍然存在梯度消失问题。
3.2.3 应用场景
  • 适用于隐藏层,但不如 ReLU 常用。

3.3 ReLU

3.3.1 公式

3.3.2 特点
  • 计算简单,适合大规模数据训练。
  • 缓解梯度消失问题,适合深层网络。
  • 存在神经元死亡问题,即某些神经元可能永远不被激活。
3.3.3 应用场景
  • 深度学习中最常用的激活函数,适用于隐藏层。

3.4 Leaky ReLU

3.4.1 公式

3.4.2 特点
  • 解决了 ReLU 的神经元死亡问题。
  • 计算简单,但需要调整超参数 α。
3.4.3 应用场景
  • 适用于隐藏层,尤其是 ReLU 效果不佳时。

3.5 Softmax

3.5.1 公式

3.5.2 特点
  • 将输出转化为概率分布,适合多分类问题。
  • 放大差异,使概率最大的类别更突出。
  • 存在数值不稳定性问题,需进行数值调整。
3.5.3 应用场景
  • 用于多分类问题的输出层。

4. 激活函数的选择

4.1 隐藏层

  1. 优先选择 ReLU。
  2. 如果 ReLU 效果不佳,尝试 Leaky ReLU 或其他激活函数。
  3. 避免使用 Sigmoid,可以尝试 Tanh。

4.2 输出层

  1. 二分类问题选择 Sigmoid。
  2. 多分类问题选择 Softmax。

5. 总结

神经网络是深度学习的核心,理解其结构和激活函数的作用至关重要。人工神经元是神经网络的基本单元,通过加权求和和激活函数实现非线性变换。全连接神经网络是最基本的神经网络结构,广泛应用于各类任务。激活函数在神经网络中引入非线性,增强了网络的表达能力。不同激活函数适用于不同的场景,合理选择激活函数可以显著提升模型性能。

相关文章:

  • 《实战AI智能体》——邮件转工单的AI自动化
  • 区块链如何成为智能城市的底层引擎?从数据透明到自动化治理
  • Cursor 生成java测试用例
  • Sa-Token使用指南
  • 微服务调用中的“大对象陷阱”:CPU飙高问题解析与优化
  • qt QGroupButton 实现两个QPushButton的互斥
  • 游戏引擎学习第232天
  • 解决 pip install tts 报错问题-—SadTalker的AI数字人视频—未来之窗超算中心
  • tomcat 的安装与启动
  • FPGA HR Bank如何支持ODELAY问题分析
  • text-decoration: underline;不生效
  • 土建施工员备考经验分享
  • 《软件设计师》复习笔记(14.3)——设计模式
  • Android12 ServiceManager::addService源码解读
  • Django 结合 Vue 实现简单管理系统的详解
  • JDBC 与 MyBatis 详解:从基础到实践
  • 7、生命周期:魔法的呼吸节奏——React 19 新版钩子
  • Qt 入门 5 之其他窗口部件
  • webgl入门实例-11WebGL 视图矩阵 (View Matrix)基本概念
  • 6.6 “3步调用ChatGPT打造高可靠Python调度器,零依赖实现定时任务自动化“
  • 接续驰援,中国政府援缅卫生防疫队出发赴缅
  • 泸州市长余先河已任四川省委统战部常务副部长
  • 重庆网红景点“莲花茶摊”被市民投诉,官方:采纳意见,整改!
  • OpenAI推出全新推理模型o3、o4-mini,以及一个编程智能体
  • 首映|达菲鸭和猪小弟领衔,《乐一通大电影》主打复古
  • 3月份一线城市商品住宅销售价格环比上涨,二三线城市环比总体降幅收窄