magail4autodrive: first commit
This commit is contained in:
BIN
Env/exp_converted/dataset_mapping.pkl
Normal file
BIN
Env/exp_converted/dataset_mapping.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/dataset_summary.pkl
Normal file
BIN
Env/exp_converted/dataset_summary.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/exp_converted_0/dataset_mapping.pkl
Normal file
BIN
Env/exp_converted/exp_converted_0/dataset_mapping.pkl
Normal file
Binary file not shown.
BIN
Env/exp_converted/exp_converted_0/dataset_summary.pkl
Normal file
BIN
Env/exp_converted/exp_converted_0/dataset_summary.pkl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
41
Env/run_multiagent_env.py
Normal file
41
Env/run_multiagent_env.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from scenario_env import MultiAgentScenarioEnv
|
||||
from Env.simple_idm_policy import ConstantVelocityPolicy
|
||||
from metadrive.engine.asset_loader import AssetLoader
|
||||
|
||||
WAYMO_DATA_DIR = r"/home/zhy/桌面/MAGAIL_TR/Env"
|
||||
|
||||
def main():
|
||||
env = MultiAgentScenarioEnv(
|
||||
config={
|
||||
# "data_directory": AssetLoader.file_path(AssetLoader.asset_path, "waymo", unix_style=False),
|
||||
"data_directory": AssetLoader.file_path(WAYMO_DATA_DIR, "exp_converted", unix_style=False),
|
||||
"is_multi_agent": True,
|
||||
"num_controlled_agents": 3,
|
||||
"horizon": 300,
|
||||
"use_render": True,
|
||||
"sequential_seed": True,
|
||||
"reactive_traffic": True,
|
||||
"manual_control": True,
|
||||
},
|
||||
agent2policy=ConstantVelocityPolicy(target_speed=50)
|
||||
)
|
||||
|
||||
obs = env.reset(0
|
||||
)
|
||||
for step in range(10000):
|
||||
actions = {
|
||||
aid: env.controlled_agents[aid].policy.act()
|
||||
for aid in env.controlled_agents
|
||||
}
|
||||
|
||||
obs, rewards, dones, infos = env.step(actions)
|
||||
env.render(mode="topdown")
|
||||
|
||||
if dones["__all__"]:
|
||||
break
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
204
Env/scenario_env.py
Normal file
204
Env/scenario_env.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import numpy as np
|
||||
from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
|
||||
from metadrive.envs.scenario_env import ScenarioEnv
|
||||
from metadrive.component.vehicle.vehicle_type import DefaultVehicle, vehicle_class_to_type
|
||||
import math
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Union, Dict, AnyStr
|
||||
from metadrive.engine.logger import get_logger, set_log_level
|
||||
from metadrive.type import MetaDriveType
|
||||
|
||||
|
||||
class PolicyVehicle(DefaultVehicle):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.policy = None
|
||||
self.destination = None
|
||||
|
||||
def set_policy(self, policy):
|
||||
self.policy = policy
|
||||
|
||||
def set_destination(self, des):
|
||||
self.destination = des
|
||||
|
||||
def act(self, observation, policy=None):
|
||||
if self.policy is not None:
|
||||
return self.policy.act(observation)
|
||||
else:
|
||||
return self.action_space.sample()
|
||||
|
||||
def before_step(self, action):
|
||||
self.last_position = self.position # 2D vector
|
||||
self.last_velocity = self.velocity # 2D vector
|
||||
self.last_speed = self.speed # Scalar
|
||||
self.last_heading_dir = self.heading
|
||||
if action is not None:
|
||||
self.last_current_action.append(action)
|
||||
self._set_action(action)
|
||||
|
||||
def is_done(self):
|
||||
# arrive or crash
|
||||
pass
|
||||
|
||||
|
||||
vehicle_class_to_type[PolicyVehicle] = "default"
|
||||
|
||||
|
||||
class MultiAgentScenarioEnv(ScenarioEnv):
|
||||
@classmethod
|
||||
def default_config(cls):
|
||||
config = super().default_config()
|
||||
config.update(dict(
|
||||
data_directory=None,
|
||||
num_controlled_agents=3,
|
||||
horizon=1000,
|
||||
))
|
||||
return config
|
||||
|
||||
def __init__(self, config, agent2policy):
|
||||
self.policy = agent2policy
|
||||
self.controlled_agents = {}
|
||||
self.controlled_agent_ids = []
|
||||
self.obs_list = []
|
||||
self.round = 0
|
||||
super().__init__(config)
|
||||
|
||||
def reset(self, seed: Union[None, int] = None):
|
||||
self.round = 0
|
||||
if self.logger is None:
|
||||
self.logger = get_logger()
|
||||
log_level = self.config.get("log_level", logging.DEBUG if self.config.get("debug", False) else logging.INFO)
|
||||
set_log_level(log_level)
|
||||
|
||||
self.lazy_init()
|
||||
self._reset_global_seed(seed)
|
||||
if self.engine is None:
|
||||
raise ValueError("Broken MetaDrive instance.")
|
||||
|
||||
# 记录专家数据中每辆车的位置,接着全部清除,只保留位置等信息,用于后续生成
|
||||
_obj_to_clean_this_frame = []
|
||||
self.car_birth_info_list = []
|
||||
for scenario_id, track in self.engine.traffic_manager.current_traffic_data.items():
|
||||
if scenario_id == self.engine.traffic_manager.sdc_scenario_id:
|
||||
continue
|
||||
else:
|
||||
if track["type"] == MetaDriveType.VEHICLE:
|
||||
_obj_to_clean_this_frame.append(scenario_id)
|
||||
valid = track['state']['valid']
|
||||
first_show = np.argmax(valid) if valid.any() else -1
|
||||
last_show = len(valid) - 1 - np.argmax(valid[::-1]) if valid.any() else -1
|
||||
# id,出现时间,出生点坐标,出生朝向,目的地
|
||||
self.car_birth_info_list.append({
|
||||
'id': track['metadata']['object_id'],
|
||||
'show_time': first_show,
|
||||
'begin': (track['state']['position'][first_show, 0], track['state']['position'][first_show, 1]),
|
||||
'heading': track['state']['heading'][first_show],
|
||||
'end': (track['state']['position'][last_show, 0], track['state']['position'][last_show, 1])
|
||||
})
|
||||
|
||||
for scenario_id in _obj_to_clean_this_frame:
|
||||
self.engine.traffic_manager.current_traffic_data.pop(scenario_id)
|
||||
|
||||
self.engine.reset()
|
||||
self.reset_sensors()
|
||||
self.engine.taskMgr.step()
|
||||
|
||||
self.lanes = self.engine.map_manager.current_map.road_network.graph
|
||||
|
||||
if self.top_down_renderer is not None:
|
||||
self.top_down_renderer.clear()
|
||||
self.engine.top_down_renderer = None
|
||||
|
||||
self.dones = {}
|
||||
self.episode_rewards = defaultdict(float)
|
||||
self.episode_lengths = defaultdict(int)
|
||||
|
||||
self.controlled_agents.clear()
|
||||
self.controlled_agent_ids.clear()
|
||||
|
||||
super().reset(seed) # 初始化场景
|
||||
self._spawn_controlled_agents()
|
||||
|
||||
return self._get_all_obs()
|
||||
|
||||
def _spawn_controlled_agents(self):
|
||||
# ego_vehicle = self.engine.agent_manager.active_agents.get("default_agent")
|
||||
# ego_position = ego_vehicle.position if ego_vehicle else np.array([0, 0])
|
||||
for car in self.car_birth_info_list:
|
||||
if car['show_time'] == self.round:
|
||||
agent_id = f"controlled_{car['id']}"
|
||||
|
||||
vehicle = self.engine.spawn_object(
|
||||
PolicyVehicle,
|
||||
vehicle_config={},
|
||||
position=car['begin'],
|
||||
heading=car['heading']
|
||||
)
|
||||
vehicle.reset(position=car['begin'], heading=car['heading'])
|
||||
|
||||
vehicle.set_policy(self.policy)
|
||||
vehicle.set_destination(car['end'])
|
||||
|
||||
self.controlled_agents[agent_id] = vehicle
|
||||
self.controlled_agent_ids.append(agent_id)
|
||||
|
||||
# ✅ 关键:注册到引擎的 active_agents,才能参与物理更新
|
||||
self.engine.agent_manager.active_agents[agent_id] = vehicle
|
||||
|
||||
def _get_all_obs(self):
|
||||
# position, velocity, heading, lidar, navigation, TODO: trafficlight -> list
|
||||
self.obs_list = []
|
||||
for agent_id, vehicle in self.controlled_agents.items():
|
||||
state = vehicle.get_state()
|
||||
|
||||
traffic_light = 0
|
||||
for lane in self.lanes.values():
|
||||
if lane.lane.point_on_lane(state['position'][:2]):
|
||||
if self.engine.light_manager.has_traffic_light(lane.lane.index):
|
||||
traffic_light = self.engine.light_manager._lane_index_to_obj[lane.lane.index].status
|
||||
if traffic_light == 'TRAFFIC_LIGHT_GREEN':
|
||||
traffic_light = 1
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_YELLOW':
|
||||
traffic_light = 2
|
||||
elif traffic_light == 'TRAFFIC_LIGHT_RED':
|
||||
traffic_light = 3
|
||||
else:
|
||||
traffic_light = 0
|
||||
break
|
||||
|
||||
lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.dynamic_world)
|
||||
side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
lane_line_lidar = self.engine.get_sensor("lane_line_detector").perceive(num_lasers=10, distance=3,
|
||||
base_vehicle=vehicle,
|
||||
physics_world=self.engine.physics_world.static_world)
|
||||
|
||||
obs = (state['position'][:2] + list(state['velocity']) + [state['heading_theta']]
|
||||
+ lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light]
|
||||
+ list(vehicle.destination))
|
||||
self.obs_list.append(obs)
|
||||
return self.obs_list
|
||||
|
||||
def step(self, action_dict: Dict[AnyStr, Union[list, np.ndarray]]):
|
||||
self.round += 1
|
||||
|
||||
for agent_id, action in action_dict.items():
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].before_step(action)
|
||||
|
||||
self.engine.step()
|
||||
|
||||
for agent_id in action_dict:
|
||||
if agent_id in self.controlled_agents:
|
||||
self.controlled_agents[agent_id].after_step()
|
||||
|
||||
self._spawn_controlled_agents()
|
||||
obs = self._get_all_obs()
|
||||
rewards = {aid: 0.0 for aid in self.controlled_agents}
|
||||
dones = {aid: False for aid in self.controlled_agents}
|
||||
dones["__all__"] = self.episode_step >= self.config["horizon"]
|
||||
infos = {aid: {} for aid in self.controlled_agents}
|
||||
return obs, rewards, dones, infos
|
||||
18
Env/simple_idm_policy.py
Normal file
18
Env/simple_idm_policy.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import numpy as np
|
||||
|
||||
class ConstantVelocityPolicy:
|
||||
def __init__(self, target_speed=50):
|
||||
self.step_num = 0
|
||||
|
||||
def act(self):
|
||||
self.step_num += 1
|
||||
if self.step_num % 30 < 15:
|
||||
throttle = 1.0
|
||||
else:
|
||||
throttle = 1.0
|
||||
|
||||
steering = 0.1
|
||||
|
||||
# return [steering, throttle]
|
||||
|
||||
return [0.0,0.05]
|
||||
14
Env/utils.py
Normal file
14
Env/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
|
||||
def set_seed(seed):
|
||||
if seed == -1:
|
||||
seed = np.random.randint(0, 10000)
|
||||
print('Random seed: {}'.format(seed))
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
Reference in New Issue
Block a user