修改了算法代码,并建立了一个简单的训练脚本。修改 BERT 以处理二维输入,并移除 PPO 的 permute 参数。

This commit is contained in:
2025-10-22 16:56:12 +08:00
parent b626702cbb
commit 3f7e183c4b
101 changed files with 3837 additions and 39 deletions

View File

@@ -86,6 +86,9 @@ class MultiAgentScenarioEnv(ScenarioEnv):
if self.engine is None:
raise ValueError("Broken MetaDrive instance.")
# 在engine.reset()之前清理对象
self.before_reset()
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
_obj_to_clean_this_frame = []
self.car_birth_info_list = []
@@ -165,10 +168,10 @@ class MultiAgentScenarioEnv(ScenarioEnv):
self.episode_rewards = defaultdict(float)
self.episode_lengths = defaultdict(int)
self.controlled_agents.clear()
self.controlled_agent_ids.clear()
# 调用父类reset会清理场景
super().reset(seed) # 初始化场景
# 重新生成车辆
self._spawn_controlled_agents()
return self._get_all_obs()
@@ -298,6 +301,26 @@ class MultiAgentScenarioEnv(ScenarioEnv):
# ✅ 关键:注册到引擎的 active_agents才能参与物理更新
self.engine.agent_manager.active_agents[agent_id] = vehicle
def before_reset(self):
    """Remove every controlled vehicle before the environment is reset.

    Best-effort cleanup: the controlled agents are handed to the engine's
    ``clear_objects`` (when the engine provides it), dropped from
    ``engine.agent_manager.active_agents`` (when the engine has an agent
    manager), and finally the local ``controlled_agents`` /
    ``controlled_agent_ids`` containers are emptied.

    A failure inside ``clear_objects`` is logged and does not abort the
    rest of the cleanup. Attributes that do not exist yet (e.g. when this
    runs before the first spawn) are skipped instead of raising.
    """
    # Function-scope import: the file's import block is not visible here.
    import logging

    agents = getattr(self, 'controlled_agents', None)
    engine = getattr(self, 'engine', None)

    if agents is not None and engine is not None:
        # Snapshot the ids once; both cleanup steps work on the same list.
        agent_ids = list(agents.keys())

        if hasattr(engine, 'clear_objects'):
            # NOTE(review): the keys passed here are agent ids; confirm they
            # match the object ids the engine expects, otherwise this call
            # fails every time and only the log line below reveals it.
            try:
                engine.clear_objects(agent_ids)
            except Exception:
                # Stay best-effort, but leave a trace instead of hiding the
                # error (and let KeyboardInterrupt/SystemExit propagate).
                logging.getLogger(__name__).warning(
                    "before_reset: engine.clear_objects failed for %s",
                    agent_ids,
                    exc_info=True,
                )

        if hasattr(engine, 'agent_manager'):
            active_agents = engine.agent_manager.active_agents
            for agent_id in agent_ids:
                # pop with a default: ids that are not registered are ignored.
                active_agents.pop(agent_id, None)

    if agents is not None:
        agents.clear()
    tracked_ids = getattr(self, 'controlled_agent_ids', None)
    if tracked_ids is not None:
        tracked_ids.clear()
def _get_traffic_light_state(self, vehicle):
"""