当前位置: 首页 > news >正文

基础的贝叶斯神经网络(BNN)回归

下面是一个最基础的贝叶斯神经网络(BNN)回归示例,采用PyTorch实现,适合入门理解。
这个例子用BNN拟合 y = x + 噪声 的一维回归问题,输出均值和不确定性(方差)。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1. 生成数据
np.random.seed(0)
x = np.linspace(-3, 3, 100)
y = x + np.random.normal(0, 0.5, size=x.shape)# 转为torch tensor
x_train = torch.tensor(x, dtype=torch.float32).unsqueeze(1)
y_train = torch.tensor(y, dtype=torch.float32).unsqueeze(1)# 2. 定义贝叶斯回归网络(输出均值和log方差)
class BayesianRegressor(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 32), nn.ReLU(),nn.Linear(32, 32), nn.ReLU(),nn.Linear(32, 2) # 输出均值和log方差)def forward(self, x):out = self.net(x)mean = out[:, 0:1]logvar = out[:, 1:2]return mean, logvar# 3. 贝叶斯损失函数(负对数似然)
def bayesian_loss(mean, logvar, target):# 对应N(y|mean, exp(logvar))return (0.5 * torch.exp(-logvar) * (target - mean) ** 2 + 0.5 * logvar).mean()# 4. 训练网络
model = BayesianRegressor()
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(2000):mean, logvar = model(x_train)loss = bayesian_loss(mean, logvar, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 200 == 0:print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")# 5. 预测与可视化
x_test = torch.linspace(-3, 3, 100).unsqueeze(1)
mean_pred, logvar_pred = model(x_test)
mean_pred = mean_pred.detach().numpy().flatten()
std_pred = torch.exp(0.5 * logvar_pred).detach().numpy().flatten()plt.figure(figsize=(8, 5))
plt.scatter(x, y, label='Data', color='gray', s=10)
plt.plot(x, x, 'g--', label='True function')
plt.plot(x_test, mean_pred, 'b-', label='BNN mean')
plt.fill_between(x_test.flatten(), mean_pred-2*std_pred, mean_pred+2*std_pred, color='orange', alpha=0.3, label='BNN ±2std')
plt.legend()
plt.title("Simple Bayesian Neural Network Regression")
plt.show()

相关文章:

  • 零基础小白如何上岸数模国奖
  • 大学之大:伦敦政治经济学院2025.4.27
  • 【音视频】FFmpeg过滤器框架分析
  • 人工智能—— K-means 聚类算法
  • Spring Cloud Alibaba 整合 Sentinel:实现微服务高可用防护
  • Awesome-Embodied-AI: 具身机器人的资源库
  • [论文梳理] 足式机器人规划控制流程 - 接触碰撞的控制 - 模型误差 - 自动驾驶车的安全合规(4个课堂讨论问题)
  • 【读写视频】MATLAB详细代码
  • 简述删除一个Pod流程?
  • 【计算机组成原理实验】实验一 运算部件实验_加法器及计算机性能指标
  • Redis超详细入门教程(基础篇)
  • 【每日随笔】文化属性 ② ( 高维度信息处理 | 强者思维形成 | 认知重构 | 资源捕获 | 进化路径 )
  • Spark SQL核心概念与编程实战:从DataFrame到DataSet的结构化数据处理
  • Spark-Streaming核心编程(四)总结
  • 关于堆栈指针的那些事 | bootloader 如何跳转app
  • 如何解决无训练数据问题:一种更为智能化的解决方案
  • k8s学习记录(五):Pod亲和性详解
  • AI提示词(Prompt)终极指南:从入门到精通(附实战案例)
  • STM32:看门狗
  • Leetcode刷题记录20——找到字符串中所有字母异位词
  • 上海“生育友好岗”已让4000余人受益,今年将推产假社保补贴政策
  • 新干式二尖瓣瓣膜国内上市,专家:重视瓣膜病全生命周期管理
  • 上海首个航空前置货站落户松江综合保税区,通关效率可提升30%
  • 在上海生活8年,13岁英国女孩把城市记忆写进歌里
  • 高璞任中国第一汽车集团有限公司党委常委、副总经理
  • 旧衣服旧纸箱不舍得扔?可能是因为“囤物障碍”