当前位置: 首页 > news >正文

RBF(径向基神经网络)基础解析与代码实例:拟合任意函数

目录

1. 前言

2. RBF神经网络原理

2.1 网络结构

2.2 激活函数

2.3 工作原理

3. RBF神经网络实例:拟合任意函数

3.1 导入所需的库 

3.2 定义RBF网络架构

3.3 初始化网络参数

3.4 定义损失函数和优化器

3.5 生成训练数据

3.6 训练模型

3.7 小批量梯度下降(可选)

3.8 可视化结果

3.9 完整代码

4. RBF应用领域

5. 总结


1. 前言

目前RBF这一强大工具已经被基本证实可以拟合任意函数!

在机器学习领域,神经网络是一种强大的工具,用于解决各种复杂问题。其中,径向基函数(RBF,Radial Basis Function)神经网络以其独特的结构和优异的性能,在非线性函数逼近、模式识别等领域表现尤为突出。RBF神经网络是一种前馈神经网络,其隐含层节点的激活函数采用径向基函数,具有更强的局部逼近能力和更快的收敛速度。

2. RBF神经网络原理

2.1 网络结构

BF神经网络通常由三层组成:输入层、隐含层和输出层。输入层接收输入数据,隐含层使用径向基函数对输入数据进行变换,输出层则对隐含层的输出进行线性组合,得到最终的输出结果。

与MLP神经网络不同的是,其中每个x输入的不是一个特征,而是一整个样本 。

2.2 激活函数

隐含层的激活函数通常选用高斯函数,其表达式为:

这里,x是输入向量,ci​是第i个隐含层节点的中心,σi​是第i个隐含层节点的宽度。

径向基函数是一个取值仅仅依赖于离原点距离的实值函数(RBF)方法,任意一个满足该特性的函数都叫做径向基函数,标准的一般使用欧式距离(也叫做欧式径向基函数)。

2.3 工作原理

RBF神经网络的工作原理可以分为两个阶段。在训练阶段,网络通过调整隐含层节点的中心和宽度来拟合训练数据;在预测阶段,网络使用训练好的参数对新数据进行预测。

3. RBF神经网络实例:拟合任意函数

3.1 导入所需的库 

首先,我们需要导入Pytorch和其他的基本库。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

3.2 定义RBF网络架构

接下来,我们定义RBF网络的结构。在此示例中,我们将使用一层RBF中心和一层输出层。

class RBFNetwork(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(RBFNetwork, self).__init__()self.hidden_dim = hidden_dimself.centers = nn.Parameter(torch.randn(hidden_dim, input_dim))self.sigmas = nn.Parameter(torch.randn(hidden_dim))self.linear = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.unsqueeze(0)centers = self.centers.unsqueeze(1)distances = torch.sum((x - centers) ** 2, dim=2)rbf_activations = torch.exp(-distances / (2 * self.sigmas ** 2))return self.linear(rbf_activations)

增加一个维度,将输入张量的形状从 (batch_size, input_dim) 转换为 (1, batch_size, input_dim)

增加一个维度,将中心点张量的形状从 (hidden_dim, input_dim) 转换为 (hidden_dim, 1, input_dim)。  

通过广播机制,计算输入数据与每个中心点之间的差值。

torch.sum(..., dim=2):沿着第三个维度(dim=2)求和,得到每个输入点与每个中心点之间的欧几里得距离的平方。

rbf_activations 的形状是 (hidden_dim, batch_size),表示每个输入点在每个RBF单元的激活值。 

self.linear 是一个全连接层,将输入从 hidden_dim 维度转换为 output_dim 维度。

结果:返回一个形状为 (batch_size, output_dim) 的张量,表示网络的最终输出。

3.3 初始化网络参数

创建一个RBF网络实例,并初始化网络参数。

input_dim = 1
hidden_dim = 10
output_dim = 1rbf_net = RBFNetwork(input_dim, hidden_dim, output_dim)

3.4 定义损失函数和优化器

选择一个损失函数和优化器,常用的有均方误差损失和Adam优化器。

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(rbf_net.parameters(), lr=0.01)

3.5 生成训练数据

生成一些简单的训练数据,用于训练RBF网络。

x = torch.linspace(-5, 5, 100).reshape(-1, 1)
y = 2 * x ** 2 + 3 * x + 1

3.6 训练模型

使用生成的数据对模型进行训练,并记录损失值。

num_epochs = 1000for epoch in range(num_epochs):optimizer.zero_grad()outputs = rbf_net(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

这种训练方式称为 全批量梯度下降。 

3.7 小批量梯度下降(可选)

如果你希望每次迭代只使用一部分数据(小批量)进行训练,可以使用 小批量梯度下降。这通常通过 DataLoader 来实现,将数据集分成多个小批次,每次迭代只传递一个小批次的数据。

以下是使用 DataLoader 实现小批量梯度下降的示例:

from torch.utils.data import DataLoader, TensorDataset# 创建数据集
dataset = TensorDataset(x, y)
# 创建数据加载器,设置批量大小为32
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)for epoch in range(num_epochs):for batch_x, batch_y in dataloader:optimizer.zero_grad()outputs = rbf_net(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()

在这个示例中,每次迭代只传递一个小批次的数据(32个样本),并且数据会在每个 epoch 开始时被打乱(shuffle=True),以确保每次迭代使用不同的数据。

3.8 可视化结果

通过可视化训练数据和预测结果,评估模型的性能。

plt.scatter(x.detach().numpy(), y.detach().numpy(), label='True function')
plt.plot(x.detach().numpy(), rbf_net(x).detach().numpy(), label='Approximated function')
plt.xlabel('x')
plt.ylabel('y')
plt.title('RBF Neural Network for Function Approximation')
plt.legend()
plt.show()

3.9 完整代码

完整代码如下方便调试:

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as pltclass RBFNetwork(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(RBFNetwork, self).__init__()self.hidden_dim = hidden_dimself.centers = nn.Parameter(torch.randn(hidden_dim, input_dim))self.sigmas = nn.Parameter(torch.randn(hidden_dim))self.linear = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.unsqueeze(0)  # 形状变为 (1, batch_size, input_dim)centers = self.centers.unsqueeze(1)  # 形状变为 (hidden_dim, 1, input_dim)distances = torch.sum((x - centers) ** 2, dim=2)  # 形状变为 (hidden_dim, batch_size)# 调整分母的形状以正确广播sigmas_sq = self.sigmas ** 2sigmas_sq = sigmas_sq.unsqueeze(1)  # 形状变为 (hidden_dim, 1)denominator = 2 * sigmas_sqrbf_activations = torch.exp(-distances / denominator)rbf_activations = rbf_activations.t()  # 转置为 (batch_size, hidden_dim)return self.linear(rbf_activations)input_dim = 1
hidden_dim = 10
output_dim = 1rbf_net = RBFNetwork(input_dim, hidden_dim, output_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(rbf_net.parameters(), lr=0.01)x = torch.linspace(-5, 5, 100).reshape(-1, 1)
y = 2 * x ** 2 + 3 * x + 1num_epochs = 1000for epoch in range(num_epochs):optimizer.zero_grad()outputs = rbf_net(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')plt.scatter(x.detach().numpy(), y.detach().numpy(), label='True function')
plt.plot(x.detach().numpy(), rbf_net(x).detach().numpy(), label='Approximated function')
plt.xlabel('x')
plt.ylabel('y')
plt.title('RBF Neural Network for Function Approximation')
plt.legend()
plt.show()

4. RBF应用领域

1. 函数逼近

RBF网络在非线性函数逼近方面表现出色,可以用于拟合复杂的函数关系。例如:

  • 回归分析:预测连续值,如房价、股票价格等。

  • 数学建模:拟合复杂的数学函数。

2. 模式识别

RBF网络可以用于分类任务,识别模式或类别。例如:

  • 图像识别:识别手写数字、物体分类等。

  • 语音识别:识别语音信号中的模式。

  • 生物特征识别:如指纹识别、面部识别等。

3. 时间序列预测

RBF网络可以用于预测时间序列数据,例如:

  • 股票市场预测:预测股票价格的未来走势。

  • 天气预测:预测未来的天气状况。

  • 销售预测:预测产品销量。

4. 控制系统

RBF网络可以用于设计智能控制系统,例如:

  • 机器人控制:控制机器人的运动和行为。

  • 自动驾驶:用于自动驾驶汽车的路径规划和决策。

  • 工业过程控制:优化工业生产过程。

5. 数据分类与聚类

RBF网络可以用于数据分类和聚类任务,例如:

  • 客户分类:根据客户行为或特征进行分类。

  • 异常检测:识别异常数据点,如信用卡欺诈检测。

6. 信号处理

RBF网络可以用于信号处理任务,例如:

  • 噪声滤波:去除信号中的噪声。

  • 信号分类:识别不同类型的信号,如心电信号(ECG)分类。

7. 生物信息学

RBF网络在生物信息学中有广泛应用,例如:

  • 基因表达分析:分析基因表达数据。

  • 蛋白质结构预测:预测蛋白质的三维结构。

8. 金融领域

RBF网络在金融领域也有广泛应用,例如:

  • 信用评分:评估客户的信用风险。

  • 投资组合优化:优化投资组合的配置。

9. 自然语言处理

RBF网络可以用于自然语言处理任务,例如:

  • 情感分析:分析文本中的情感倾向。

  • 文本分类:对文本进行分类,如新闻分类。

10. 游戏AI

RBF网络可以用于游戏AI,例如:

  • 行为预测:预测玩家的行为。

  • 策略优化:优化游戏中的策略和决策。

5. 总结

在这篇文章里,我们详细地介绍了RBF神经网络的原理,并通过Pytorch实现了一个简单的RBF网络。通过每一步的代码示例和注释,我们展示了如何构建和训练RBF网络。RBF网络具有独特的结构和优异的性能,在非线性函数逼近和模式识别等领域有着广泛的应用。希望这篇文章能帮助大家顺利理解和实现RBF神经网络。我是橙色小博,关注我,一起在人工智能领域学习进步!

相关文章:

  • Java从入门到“放弃”(精通)之旅——类和对象全面解析⑦
  • HBuilder X:前端开发的终极生产力工具
  • 【C语言】srand() rand seed其实是设置一个初始值
  • 百级Function架构集成DeepSeek实践:Go语言超大规模AI工具系统设计
  • kotlin知识体系(五) :Android 协程全解析,从作用域到异常处理的全面指南
  • 深入理解组合实体模式(Composite Entity Pattern)在 C# 中的应用与实现
  • 基于SpringAI Alibaba实现RAG架构的深度解析与实践指南
  • 【数据结构_12】二叉树(4)
  • C 语言的未来:在变革中坚守与前行
  • Windows串口通信
  • 进程管理,关闭进程
  • PCA——主成分分析数学原理及代码
  • 【图像处理基石】什么是去马赛克算法?
  • springboot+vue3+mysql+websocket实现的即时通讯软件
  • 热门算法面试题第19天|Leetcode39. 组合总和40.组合总和II131.分割回文串
  • PyTorch基础笔记
  • 【笔记】SpringBoot实现图片上传和获取图片接口
  • MAC-从es中抽取数据存入表中怎么实现
  • 23种设计模式-结构型模式之适配器模式(Java版本)
  • 23种设计模式-结构型模式之装饰器模式(Java版本)
  • 印度空军计划增购40架法制“阵风”战机,此前已购买36架
  • 商务部24日下午将举行发布会,介绍近期商务领域重点工作情况
  • 从南宋遗韵到海派风情,解码江南服饰美学基因
  • 全国登记在册民营企业超过5700万户
  • “你是做什么的?”——人们能否对工作说不?
  • 玉渊谭天丨先爆视频再爆订单,美关税影响下企业因短视频火出圈