Update env (#7)

* add capture script

* gymnasium API

* training with gymnasium API
Quanyi Li
2023-06-23 19:50:40 +01:00
committed by GitHub
parent 88b4faa00f
commit 5f5a5b9531
14 changed files with 151 additions and 89 deletions
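
At a glance, the change migrates every script from the old Gym-style MetaDrive API to the Gymnasium API. A minimal sketch of the two call patterns, taken from the hunks below (names as they appear in the diffs; exact MetaDrive return values may differ):

    # Before (old Gym-style API):
    #   o = env.reset(force_seed=index)      # MetaDrive-specific seeding kwarg
    #   o, r, d, info = env.step(action)     # 4-tuple with a single done flag
    # After (Gymnasium API):
    #   env.reset(seed=index)                # seeding moves into reset(seed=...)
    #   o, r, d, _, info = env.step(action)  # 5-tuple: terminated + truncated

For RLlib, which still expects the old interface, the scripts wrap the env in GymEnvWrapper (see the train scripts and train_utils below).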

View File

@@ -4,7 +4,7 @@
 [**Webpage**](https://metadriverse.github.io/scenarionet/) |
 [**Code**](https://github.com/metadriverse/scenarionet) |
-[**Video**](https://metadriverse.github.io/scenarionet/) |
+[**Video**](https://youtu.be/3bOqswXP6OA) |
 [**Paper**](http://arxiv.org/abs/2306.12241) |
 ScenarioNet allows users to load scenarios from real-world dataset like Waymo, nuPlan, nuScenes, l5 and synthetic

View File

@@ -39,14 +39,14 @@ if __name__ == '__main__':
         }
     )
     for index in range(num_scenario if args.scenario_index is not None else 1000000):
-        env.reset(force_seed=index if args.scenario_index is None else args.scenario_index)
+        env.reset(seed=index if args.scenario_index is None else args.scenario_index)
         for t in range(10000):
-            o, r, d, info = env.step([0, 0])
+            env.step([0, 0])
             if env.config["use_render"]:
                 env.render(text={
                     "scenario index": env.engine.global_seed + env.config["start_scenario_index"],
                 })
-            if d and info["arrive_dest"]:
+            if env.episode_step >= env.engine.data_manager.current_scenario_length:
                 print("scenario:{}, success".format(env.engine.global_random_seed))
                 break
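
Since the step return is now discarded in these replay scripts, the end of a scenario is detected from the step count rather than the done flag. The pattern, repeated in the capture scripts below (comments added for illustration):

    env.step([0, 0])  # return values unused during replay
    if env.episode_step >= env.engine.data_manager.current_scenario_length:
        break         # the recorded scenario has been fully replayed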

View File

@@ -51,9 +51,9 @@ if __name__ == "__main__":
     # env.engine.accept("c", capture)
     # for seed in [1001, 1002, 1005, 1011]:
-    env.reset(force_seed=1020)
+    env.reset(seed=1020)
     for t in range(10000):
         capture()
-        o, r, d, info = env.step([1, 0.88])
+        env.step([1, 0.88])
         if env.episode_step >= env.engine.data_manager.current_scenario_length:
             break

View File

@@ -28,10 +28,10 @@ if __name__ == "__main__":
         )
     )
-    # o = env.reset(force_seed=0)
+    # o = env.reset(seed=0)
     # env.engine.accept("c", capture)
     for s in range(6, 1000):
-        env.reset(force_seed=16)
+        env.reset(seed=16)
         for t in range(10000):
             capture()
             o, r, d, info = env.step([0, 0])

View File

@@ -56,7 +56,7 @@ if __name__ == "__main__":
     # 0,1,3,4,5,6
     for seed in range(10):
-        env.reset(force_seed=seed)
+        env.reset(seed=seed)
         for t in range(10000):
             env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
             ret = env.render(
@@ -65,8 +65,7 @@ if __name__ == "__main__":
             pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, t))
             # env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "depth_{}.jpg".format(t))
             # env.vehicle.get_camera("rgb_camera").save_image(env.vehicle, "rgb_{}.jpg".format(t))
-            o, r, d, info = env.step([1, 0.88])
-            assert env.observation_space.contains(o)
+            env.step([1, 0.88])
             # if d:
             if env.episode_step >= env.engine.data_manager.current_scenario_length:
                 break

View File

@@ -0,0 +1,72 @@
+import pygame
+from metadrive.engine.asset_loader import AssetLoader
+from metadrive.envs.real_data_envs.nuscenes_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
+from scenarionet_training.train_utils.utils import initialize_ray, get_function
+from scenarionet_training.scripts.train_nuplan import config
+
+if __name__ == "__main__":
+    initialize_ray(test_mode=False, num_gpus=1)
+    env = GymEnvWrapper(
+        dict(
+            inner_class=ScenarioEnv,
+            inner_config={
+                # "data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False),
+                "data_directory": "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw",
+                "use_render": True,
+                # "agent_policy": ReplayEgoCarPolicy,
+                "show_interface": False,
+                "image_observation": False,
+                "show_logo": False,
+                "no_traffic": False,
+                "no_static_vehicles": False,
+                "drivable_region_extension": 15,
+                "sequential_seed": True,
+                "reactive_traffic": True,
+                "show_fps": False,
+                "render_pipeline": True,
+                "daytime": "07:10",
+                "max_lateral_dist": 2,
+                "window_size": (1200, 800),
+                "camera_dist": 9,
+                "start_scenario_index": 5,
+                "num_scenarios": 4000,
+                "horizon": 1000,
+                "store_map": False,
+                "vehicle_config": dict(
+                    show_navi_mark=True,
+                    # no_wheel_friction=True,
+                    use_special_color=False,
+                    image_source="depth_camera",
+                    lidar=dict(num_lasers=120, distance=50),
+                    lane_line_detector=dict(num_lasers=0, distance=50),
+                    side_detector=dict(num_lasers=0, distance=50)
+                ),
+            }
+        )
+    )
+    # env.reset()
+    #
+    #
+    ckpt = "C:\\Users\\x1\\Desktop\\neurips_2023\\exp\\nuplan\\MultiWorkerPPO_ScenarioEnv_2f75c_00003_3_seed=300_2023-06-04_02-14-18\\checkpoint_430\\checkpoint-430"
+    policy = get_function(ckpt, True, config)
+
+    def capture():
+        env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
+        # ret = env.render(
+        #     mode="topdown", screen_size=(1600, 900), film_size=(10000, 10000), target_vehicle_heading_up=True
+        # )
+        # pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, env.episode_step))
+
+    #
+    #
+    # env.engine.accept("c", capture)
+    for i in range(10000):
+        o = env.reset()
+        for t in range(10000):
+            # capture()
+            o, r, d, info = env.step(policy(o)["default_policy"])
+            if d and info["arrive_dest"]:
+                break

View File

@@ -40,11 +40,11 @@ if __name__ == '__main__':
         }
     )
     success = []
-    env.reset(force_seed=91)
+    env.reset(seed=91)
     while True:
-        env.reset(force_seed=91)
+        env.reset(seed=91)
         for t in range(10000):
-            o, r, d, info = env.step([0, 0])
+            o, r, d, _, info = env.step([0, 0])
             assert env.observation_space.contains(o)
             c_lane = env.vehicle.lane
             long, lat, = c_lane.local_coordinates(env.vehicle.position)

View File

@@ -119,12 +119,12 @@ def loading_into_metadrive(
     desc = "Scenarios: {}-{}".format(start_scenario_index, start_scenario_index + num_scenario)
     for scenario_index in tqdm.tqdm(range(start_scenario_index, start_scenario_index + num_scenario), desc=desc):
         try:
-            env.reset(force_seed=scenario_index)
+            env.reset(seed=scenario_index)
             arrive = False
             if random_drop and np.random.rand() < 0.5:
                 raise ValueError("Random Drop")
             for _ in range(steps_to_run):
-                o, r, d, info = env.step([0, 0])
+                o, r, d, _, info = env.step([0, 0])
                 if d and info["arrive_dest"]:
                     arrive = True
             assert arrive, "Can not arrive destination"

View File

@@ -1,6 +1,7 @@
 import os.path
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
@@ -13,8 +14,8 @@ if __name__ == '__main__':
     stop = int(100_000_000)
     config = dict(
-        env=env,
-        env_config=dict(
+        env=GymEnvWrapper,
+        env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
             # scenario
             start_scenario_index=0,
             num_scenarios=32,
@@ -33,8 +34,7 @@ if __name__ == '__main__':
             # training
             horizon=None,
             use_lateral_reward=True,
-        ),
+        )),
         # # ===== Evaluation =====
         evaluation_interval=2,
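
The train scripts now register GymEnvWrapper as the RLlib env class and move the real env class into the config. A minimal sketch of what the trainer ends up constructing, assuming the wrapper preserves the old Gym signature (which the 4-tuple unpack in the new eval script above suggests); the inner_config keys here are illustrative:

    from metadrive.envs.scenario_env import ScenarioEnv
    from metadrive.envs.gym_wrapper import GymEnvWrapper

    # RLlib calls env_class(env_config), so the trainer effectively builds:
    env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=dict(num_scenarios=32)))
    o = env.reset()                       # old Gym reset: observation only
    o, r, d, info = env.step([0.0, 0.0])  # old Gym step: 4-tuple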

View File

@@ -1,13 +1,13 @@
 import os.path
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -42,7 +42,7 @@ config = dict(
         vehicle_config=dict(side_detector=dict(num_lasers=0))
-    ),
+    )),
     # ===== Evaluation =====
     evaluation_interval=15,

View File

@@ -1,14 +1,14 @@
 import os.path
 from ray.tune import grid_search
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(
         vehicle_config=dict(side_detector=dict(num_lasers=0))
-    ),
+    )),
     # ===== Evaluation =====
     evaluation_interval=15,

View File

@@ -1,13 +1,13 @@
 import os.path
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(
         vehicle_config=dict(side_detector=dict(num_lasers=0))
-    ),
+    )),
     # ===== Evaluation =====
     evaluation_interval=15,
@@ -51,7 +51,8 @@ config = dict(
             num_scenarios=1000,
             sequential_seed=True,
             curriculum_level=1,  # turn off
-            data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
+            data_directory=os.path.join(SCENARIONET_DATASET_PATH,
+                                        "waymo_test"))),
     evaluation_num_workers=10,
     metrics_smoothing_episodes=10,

View File

@@ -1,11 +0,0 @@
-from ray.rllib.utils import check_env
-from metadrive.envs.scenario_env import ScenarioEnv
-from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
-from gym import Env
-
-if __name__ == '__main__':
-    env = GymnasiumEnvWrapper.build(ScenarioEnv)()
-    print(isinstance(ScenarioEnv, Env))
-    print(isinstance(env, Env))
-    print(env.observation_space)
-    check_env(env)
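
The deleted sanity check exercised GymnasiumEnvWrapper (a gym env exposed as a Gymnasium env). A hypothetical equivalent against the wrapper this commit adopts, using only calls that appear elsewhere in the diff:

    from ray.rllib.utils import check_env
    from metadrive.envs.scenario_env import ScenarioEnv
    from metadrive.envs.gym_wrapper import GymEnvWrapper

    env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=dict(num_scenarios=1)))
    print(env.observation_space)
    check_env(env)  # RLlib's env contract check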

View File

@@ -9,6 +9,7 @@ import numpy as np
 import tqdm
 from metadrive.constants import TerminationState
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from ray import tune
 from ray.tune import CLIReporter
@@ -291,7 +292,7 @@ def eval_ckpt(config,
             episodes_to_evaluate_curriculum=num_scenarios,
             data_directory=scenario_data_path,
             use_render=render))
-    env = ScenarioEnv(env_config)
+    env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=env_config))
     super_data = defaultdict(list)
     EPISODE_NUM = env.config["num_scenarios"]