当前位置: 首页 > news >正文

【动手学强化学习】番外7-MAPPO应用框架2学习与复现

文章目录

  • 一、待解决问题
    • 1.1 问题描述
    • 1.2 解决方法
  • 二、方法详述
    • 2.1 必要说明
    • 2.2 应用步骤
      • 2.2.1 搭建基础环境
      • 2.2.2 代码复现及结果展示
      • 2.2.3 代码框架理解
  • 三、疑问
  • 四、总结


一、待解决问题

1.1 问题描述

为了在自建环境上使用mappo算法,偶然发现了大佬使用的mappo框架,对其学习理解并复现。

1.2 解决方法

(1)搭建基础环境
(2)代码复现及结果展示
(3)代码框架理解

二、方法详述

2.1 必要说明

此处mappo应用框架的来源为,名为simple_mappo
【史上最简】300行代码搞定 MAPPO算法原理+代码讲解

Github的源码分享为:
https://github.com/Guowei-Zou/simple_mappo.git

✅非常感谢大佬的分享!!!

2.2 应用步骤

2.2.1 搭建基础环境

基于conda创建虚拟环境,安装torch,并安装框架中实例应用的多智能体环境ma-gym

具体而言,ma-gym的源码为:
https://github.com/koulanurag/ma-gym

其中,对于gym版本和pipsetuptoolswheel版本有要求。

conda create -n simple_mappo python=3.11
conda activate simple_mappo

pip install pip==24.0
pip install setuptools==65.5.0
pip install wheel==0.38.4

pip3 install torch torchvision torchaudio
pip install tqdm
conda install matplotlib

将ma-gym源码下载到本地后

cd ma-gym-master/
pip install -e .
conda list | grep gym
conda list | grep ma-gym
conda install spyder

查询gymma-gym的版本分别为:0.20.0,0.0.14

2.2.2 代码复现及结果展示

打开python 的IDE(这里选用spyder),将simple_mappo 的源码运行。

import os
import shutil
from time import sleep
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
#sys.path.append("./ma-gym-master")
from ma_gym.envs.combat.combat import Combat

# 清空日志文件和绘图文件夹
log_file = "training_log_metrics_weight.txt"
plot_dir = "plots_metrics_weight"

if os.path.exists(log_file):
    open(log_file, "w").close()

if os.path.exists(plot_dir):
    shutil.rmtree(plot_dir)
os.makedirs(plot_dir)

# 日志记录函数
def log_message(message):
    with open(log_file, "a") as f:
        f.write(message + "\n")


def plot_all_metrics(metrics_dict, episode):
    """
    将所有指标绘制到一个包含多个子图的图表中
    - 对曲线进行平滑处理
    - 添加误差带显示
    
    参数:
    metrics_dict: 包含所有指标数据的字典,格式为 {metric_name: values_list}
    episode: 当前的episode数
    """
    # 创建一个2x3的子图布局
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'Training Metrics (Up to Episode {episode})', fontsize=16)
    
    # 压平axes数组以便迭代
    axes = axes.flatten()
    
    # 为每个指标获取x轴值
    any_metric = list(metrics_dict.values())[0]
    x_values = [50 * (i + 1) for i in range(len(any_metric))]
    
    # 平滑参数 - 窗口大小
    window_size = min(5, len(x_values)) if len(x_values) > 0 else 1
    
    # 在每个子图中绘制一个指标
    for i, (metric_name, values) in enumerate(metrics_dict.items()):
        if i >= 5:  # 我们只有5个指标
            break
            
        ax = axes[i]
        values_array = np.array(values)
        
        # 应用平滑处理
        if len(values) > window_size:
            # 创建平滑曲线
            smoothed = np.convolve(values_array, np.ones(window_size)/window_size, mode='valid')
            
            # 计算滚动标准差用于误差带
            std_values = []
            for j in range(len(values) - window_size + 1):
                std_values.append(np.std(values_array[j:j+window_size]))
            std_values = np.array(std_values)
            
            # 调整x轴以匹配平滑后的数据长度
            smoothed_x = x_values[window_size-1:]
            
            # 绘制平滑曲线和原始散点
            ax.plot(smoothed_x, smoothed, '-', linewidth=2, label='Smoothed')
            ax.scatter(x_values, values, alpha=0.3, label='Original')
            
            # 添加误差带
            ax.fill_between(smoothed_x, smoothed-std_values, smoothed+std_values, 
                           alpha=0.2, label='±1 StdDev')
        else:
            # 如果数据点太少,只绘制原始数据
            ax.plot(x_values, values, 'o-', label='Data')
        
        ax.set_title(metric_name.replace('_', ' '))
        ax.set_xlabel('Episodes')
        ax.set_ylabel(metric_name.replace('_', ' '))
        ax.grid(True, alpha=0.3)
        ax.legend()
    
    # 删除未使用的子图
    if len(metrics_dict) < 6:
        fig.delaxes(axes[5])
    
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(os.path.join(plot_dir, f'training_metrics.png'))
    plt.close(fig)

def compute_entropy(probs):
    dist = torch.distributions.Categorical(probs)
    return dist.entropy().mean().item()

def compute_advantage(gamma, lmbda, td_delta):
    td_delta = td_delta.detach().cpu().numpy()
    advantage_list = []
    advantage = 0.0
    for delta in td_delta[::-1]:
        advantage = gamma * lmbda * advantage + delta
        advantage_list.append(advantage)
    advantage_list.reverse()
    return torch.tensor(advantage_list, dtype=torch.float)


# 策略网络(Actor)
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return F.softmax(self.fc3(x), dim=1)

# 全局价值网络(CentralValueNet)
# 输入: 所有智能体的状态拼接 (team_size * state_dim)
# 输出: 对每个智能体的价值估计 (team_size维向量)
class CentralValueNet(torch.nn.Module):
    def __init__(self, total_state_dim, hidden_dim, team_size):
        super(CentralValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(total_state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, team_size)  # 输出为每个智能体一个价值

    def forward(self, x):
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(x)  # [batch, team_size]


class MAPPO:
    def __init__(self, team_size, state_dim, hidden_dim, action_dim,
                 actor_lr, critic_lr, lmbda, eps, gamma, device):
        self.team_size = team_size
        self.gamma = gamma
        self.lmbda = lmbda
        self.eps = eps
        self.device = device

        # 为每个智能体一个独立的actor
        self.actors = [PolicyNet(state_dim, hidden_dim, action_dim).to(device)
                       for _ in range(team_size)]

        # 一个全局critic,输入为所有智能体状态拼接
        self.critic = CentralValueNet(team_size * state_dim, hidden_dim, team_size).to(device)

        self.actor_optimizers = [torch.optim.Adam(actor.parameters(), actor_lr) 
                                 for actor in self.actors]
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), critic_lr)

    def save_model(self, path="mappo_weights"):
        if not os.path.exists(path):
            os.makedirs(path)
        for i, actor in enumerate(self.actors):
            torch.save(actor.state_dict(), os.path.join(path, f"actor_{i}.pth"))
        torch.save(self.critic.state_dict(), os.path.join(path, "critic.pth"))

    def load_model(self, path="mappo_weights"):
        for i, actor in enumerate(self.actors):
            actor_path = os.path.join(path, f"actor_{i}.pth")
            if os.path.exists(actor_path):
                actor.load_state_dict(torch.load(actor_path))
        critic_path = os.path.join(path, "critic.pth")
        if os.path.exists(critic_path):
            self.critic.load_state_dict(torch.load(critic_path))

    def take_action(self, state_per_agent):
        # state_per_agent: list of shape [team_size, state_dim]
        actions = []
        action_probs = []
        for i, actor in enumerate(self.actors):
            s = torch.tensor([state_per_agent[i]], dtype=torch.float).to(self.device)
            probs = actor(s)
            action_dist = torch.distributions.Categorical(probs)
            action = action_dist.sample()
            actions.append(action.item())
            action_probs.append(probs.detach().cpu().numpy()[0])
        return actions, action_probs

    def update(self, transition_dicts, state_dim):

        # 拼接所有智能体的数据,用于全局critic
        # 首先统一长度T,假设所有智能体长度相同(因为同步环境步)
        T = len(transition_dicts[0]['states'])
        # print(f"T: {T}")
        # 将所有智能体在同一时间步的state拼接起来,得到 [T, team_size*state_dim]
        states_all = []
        next_states_all = []
        for t in range(T):
            concat_state = []
            concat_next_state = []
            for i in range(self.team_size):
                concat_state.append(transition_dicts[i]['states'][t])
                concat_next_state.append(transition_dicts[i]['next_states'][t])
            states_all.append(np.concatenate(concat_state))
            next_states_all.append(np.concatenate(concat_next_state))

        states_all = torch.tensor(states_all, dtype=torch.float).to(self.device)  # [T, team_size*state_dim]
        next_states_all = torch.tensor(next_states_all, dtype=torch.float).to(self.device) # [T, team_size*state_dim]

        rewards_all = torch.tensor([ [transition_dicts[i]['rewards'][t] for i in range(self.team_size)] 
                                     for t in range(T)], dtype=torch.float).to(self.device) # [T, team_size]
        dones_all = torch.tensor([ [transition_dicts[i]['dones'][t] for i in range(self.team_size)] 
                                   for t in range(T)], dtype=torch.float).to(self.device) # [T, team_size]

        # 从critic计算价值和TD-target
        values = self.critic(states_all) # [T, team_size]    
        next_values = self.critic(next_states_all) # [T, team_size]

        td_target = rewards_all + self.gamma * next_values * (1 - dones_all) # [T, team_size]
        td_delta = td_target - values # [T, team_size]

        # 为每个智能体计算其优势
        advantages = []
        for i in range(self.team_size):
            adv_i = compute_advantage(self.gamma, self.lmbda, td_delta[:, i])
            advantages.append(adv_i.to(self.device))  # [T]

        # 更新critic
        # critic的loss是所有智能体的均方误差平均
        critic_loss = F.mse_loss(values, td_target.detach())
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 更新每个智能体的actor
        action_losses = []
        entropies = []

        for i in range(self.team_size):
            states = torch.tensor(transition_dicts[i]['states'], dtype=torch.float).to(self.device)
            actions = torch.tensor(transition_dicts[i]['actions']).view(-1, 1).to(self.device)
            old_probs = torch.tensor(transition_dicts[i]['action_probs'], dtype=torch.float).to(self.device)

            current_probs = self.actors[i](states) # [T, action_dim]
            log_probs = torch.log(current_probs.gather(1, actions))
            old_log_probs = torch.log(old_probs.gather(1, actions)).detach()

            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantages[i].unsqueeze(-1)
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantages[i].unsqueeze(-1)

            action_loss = torch.mean(-torch.min(surr1, surr2))
            entropy_val = compute_entropy(current_probs)

            self.actor_optimizers[i].zero_grad()
            action_loss.backward()
            self.actor_optimizers[i].step()

            action_losses.append(action_loss.item())
            entropies.append(entropy_val)

        return np.mean(action_losses), critic_loss.item(), np.mean(entropies)


# 参数设置
actor_lr = 3e-4
critic_lr = 1e-3
total_episodes = 1000
hidden_dim = 64
gamma = 0.99
lmbda = 0.97
eps = 0.3
team_size = 2
grid_size = (20, 20)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# 创建环境
env = Combat(grid_shape=grid_size, n_agents=team_size, n_opponents=team_size)
state_dim = env.observation_space[0].shape[0]
action_dim = env.action_space[0].n

# 创建MAPPO智能体(共有team_size个actor, 一个共享critic)
mappo = MAPPO(team_size, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, eps, gamma, device)

# 用于统计指标的列表
total_rewards_per_episode = []
episode_lengths = []
policy_losses = []
value_losses = []
entropies = []

# 每50个episode的平均值列表
avg_total_rewards_per_50 = []
avg_episode_length_per_50 = []
avg_policy_loss_per_50 = []
avg_value_loss_per_50 = []
avg_entropy_per_50 = []

with tqdm(total=total_episodes, desc="Training") as pbar:
    for episode in range(1, total_episodes + 1):
        # 初始化Trajectory  buffer
        buffers = [{
            'states': [], 
            'actions': [], 
            'next_states': [], 
            'rewards': [], 
            'dones': [], 
            'action_probs': []
        } for _ in range(team_size)]

        s = env.reset()
        terminal = False
        episode_reward = 0.0
        steps = 0

        while not terminal:
            steps += 1
            # MAPPO中,每个智能体仍独立选择动作,但critic共享
            actions, prob_dists = mappo.take_action(s)
            next_s, r, done, info = env.step(actions)
            # 累计总奖励
            step_reward = sum(r)
            episode_reward += step_reward

            # 存储transition
            for i in range(team_size):
                buffers[i]['states'].append(np.array(s[i]))
                buffers[i]['actions'].append(actions[i])
                buffers[i]['next_states'].append(np.array(next_s[i]))
                buffers[i]['rewards'].append(r[i])
                buffers[i]['dones'].append(float(done[i]))
                buffers[i]['action_probs'].append(prob_dists[i])

            s = next_s
            terminal = all(done)

        

        # 使用MAPPO更新参数
        a_loss, c_loss, ent = mappo.update(buffers, state_dim)

        # 记录指标
        total_rewards_per_episode.append(episode_reward)
        episode_lengths.append(steps)
        policy_losses.append(a_loss)
        value_losses.append(c_loss)
        entropies.append(ent)

        # 保存模型的权重参数
        if episode % 500 == 0:
            mappo.save_model()
            log_message(f"Model saved at episode {episode}")

        # 每50个episode统计一次平均值并记录日志、绘图
        if episode % 50 == 0:
            avg_reward_50 = np.mean(total_rewards_per_episode[-50:])
            avg_length_50 = np.mean(episode_lengths[-50:])
            avg_policy_loss_50 = np.mean(policy_losses[-50:])
            avg_value_loss_50 = np.mean(value_losses[-50:])
            avg_entropy_50 = np.mean(entropies[-50:])

            avg_total_rewards_per_50.append(avg_reward_50)
            avg_episode_length_per_50.append(avg_length_50)
            avg_policy_loss_per_50.append(avg_policy_loss_50)
            avg_value_loss_per_50.append(avg_value_loss_50)
            avg_entropy_per_50.append(avg_entropy_50)

            log_message(f"Episode {episode}: "
                        f"AvgTotalReward(last50)={avg_reward_50:.3f}, "
                        f"AvgEpisodeLength(last50)={avg_length_50:.3f}, "
                        f"AvgPolicyLoss(last50)={avg_policy_loss_50:.3f}, "
                        f"AvgValueLoss(last50)={avg_value_loss_50:.3f}, "
                        f"AvgEntropy(last50)={avg_entropy_50:.3f}")
                
            # 创建指标字典
            metrics_dict = {
                "Average_Total_Reward": avg_total_rewards_per_50,
                "Average_Episode_Length": avg_episode_length_per_50,
                "Average_Policy_Loss": avg_policy_loss_per_50,
                "Average_Value_Loss": avg_value_loss_per_50, 
                "Average_Entropy": avg_entropy_per_50
            }
                
                # 调用新的绘图函数
            plot_all_metrics(metrics_dict, episode)

        pbar.update(1)

最终绘制的结果图如下:

在这里插入图片描述

2.2.3 代码框架理解

后续补充

三、疑问

  1. 暂无

四、总结

  • 后续补充

相关文章:

  • 编译构建 WSO2 产品时的一些注意事项
  • Spring事务同步器在金融系统中的应用:从风控计算到交易投递
  • 车载通信架构 --- DOIP系统机制初入门
  • 五款AI论文工具,助力完成论文写作
  • Konga密码重置
  • Node.js项目开启多进程的2种方案
  • C/C++的数据类型
  • 编程通用-配置文件的选择
  • Django从零搭建卖家中心登陆与注册实战
  • 为了四季度的盈利,李斌的换人还在继续
  • Java Stream深度解析 高阶技巧与性能优化实战
  • 高等数学同步测试卷 同济7版 试卷部分 上 做题记录 第三章微分中值定理与导数的应用同步测试卷 B 卷
  • C++中string库常用函数超详细解析与深度实践
  • Java数组初始化全解析:方式、场景与最佳实践
  • 嵌入式基础(二)ARM基础
  • 瑞幸微RK系列平台的YOLO部署(上篇)
  • Python学习笔记(列表)
  • 使用wpa_cli和wpa_supplicant配置Liunx开发板的wlan0无线网
  • idea的快捷键使用以及相关设置
  • mybatis--多对一处理/一对多处理
  • 美股反弹,纳斯达克中国金龙指数大涨3.69%
  • 民生访谈|电动自行车换新补贴会优化吗?今年汛期情况如何?市应急局回应
  • 文化中国行|1500年水镇枫泾有座丁聪美术馆
  • 老旧高层遭“连环漏水”,如何携手共治解难题?
  • 智慧菜场团标试验:标准化的同时还能保留个性化吗?
  • 同程旅行斥资24.9亿元收购万达酒管:“看好中国酒店管理市场的增长潜力”