Update env (#7)
* add capture script
* gymnasium API
* training with gymnasium API
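The diff below migrates every script from MetaDrive's old gym-style calls (`env.reset(force_seed=...)`, four-value `env.step`) to the gymnasium-style ones (`env.reset(seed=...)`, five-value `env.step`). A minimal sketch of the new calling convention, assuming the bundled sample scenarios so it runs without an external dataset (the `num_scenarios=1` config is illustrative, not from this commit):

```python
from metadrive.envs.scenario_env import ScenarioEnv

env = ScenarioEnv(dict(num_scenarios=1))
env.reset(seed=0)  # was env.reset(force_seed=0) before this commit
for _ in range(10):
    # gymnasium step returns five values; the scripts below keep the third
    # as the done flag and discard the fourth (the truncation flag)
    o, r, d, _, info = env.step([0, 0])
    if d and info["arrive_dest"]:
        break
env.close()
```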
@@ -4,7 +4,7 @@
 
 [**Webpage**](https://metadriverse.github.io/scenarionet/) |
 [**Code**](https://github.com/metadriverse/scenarionet) |
-[**Video**](https://metadriverse.github.io/scenarionet/) |
+[**Video**](https://youtu.be/3bOqswXP6OA) |
 [**Paper**](http://arxiv.org/abs/2306.12241) |
 
 ScenarioNet allows users to load scenarios from real-world dataset like Waymo, nuPlan, nuScenes, l5 and synthetic
@@ -39,14 +39,14 @@ if __name__ == '__main__':
         }
     )
     for index in range(num_scenario if args.scenario_index is not None else 1000000):
-        env.reset(force_seed=index if args.scenario_index is None else args.scenario_index)
+        env.reset(seed=index if args.scenario_index is None else args.scenario_index)
         for t in range(10000):
-            o, r, d, info = env.step([0, 0])
+            env.step([0, 0])
             if env.config["use_render"]:
                 env.render(text={
                     "scenario index": env.engine.global_seed + env.config["start_scenario_index"],
                 })
 
-            if d and info["arrive_dest"]:
+            if env.episode_step >= env.engine.data_manager.current_scenario_length:
                 print("scenario:{}, success".format(env.engine.global_random_seed))
                 break
@@ -51,9 +51,9 @@ if __name__ == "__main__":
     # env.engine.accept("c", capture)
 
     # for seed in [1001, 1002, 1005, 1011]:
-    env.reset(force_seed=1020)
+    env.reset(seed=1020)
     for t in range(10000):
         capture()
-        o, r, d, info = env.step([1, 0.88])
+        env.step([1, 0.88])
         if env.episode_step >= env.engine.data_manager.current_scenario_length:
             break
@@ -28,10 +28,10 @@ if __name__ == "__main__":
         )
     )
 
-    # o = env.reset(force_seed=0)
+    # o = env.reset(seed=0)
     # env.engine.accept("c", capture)
     for s in range(6, 1000):
-        env.reset(force_seed=16)
+        env.reset(seed=16)
         for t in range(10000):
             capture()
             o, r, d, info = env.step([0, 0])
@@ -56,7 +56,7 @@ if __name__ == "__main__":
 
     # 0,1,3,4,5,6
     for seed in range(10):
-        env.reset(force_seed=seed)
+        env.reset(seed=seed)
         for t in range(10000):
             env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
             ret = env.render(
@@ -65,8 +65,7 @@ if __name__ == "__main__":
             pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, t))
             # env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "depth_{}.jpg".format(t))
             # env.vehicle.get_camera("rgb_camera").save_image(env.vehicle, "rgb_{}.jpg".format(t))
-            o, r, d, info = env.step([1, 0.88])
-            assert env.observation_space.contains(o)
+            env.step([1, 0.88])
             # if d:
             if env.episode_step >= env.engine.data_manager.current_scenario_length:
                 break
scenarionet/tests/script/run_ckpt.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import pygame
+from metadrive.engine.asset_loader import AssetLoader
+from metadrive.envs.real_data_envs.nuscenes_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
+from scenarionet_training.train_utils.utils import initialize_ray, get_function
+from scenarionet_training.scripts.train_nuplan import config
+
+if __name__ == "__main__":
+    initialize_ray(test_mode=False, num_gpus=1)
+    env = GymEnvWrapper(
+        dict(
+            inner_class=ScenarioEnv,
+            inner_config={
+                # "data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False),
+                "data_directory": "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw",
+                "use_render": True,
+                # "agent_policy": ReplayEgoCarPolicy,
+                "show_interface": False,
+                "image_observation": False,
+                "show_logo": False,
+                "no_traffic": False,
+                "no_static_vehicles": False,
+                "drivable_region_extension": 15,
+                "sequential_seed": True,
+                "reactive_traffic": True,
+                "show_fps": False,
+                "render_pipeline": True,
+                "daytime": "07:10",
+                "max_lateral_dist": 2,
+                "window_size": (1200, 800),
+                "camera_dist": 9,
+                "start_scenario_index": 5,
+                "num_scenarios": 4000,
+                "horizon": 1000,
+                "store_map": False,
+                "vehicle_config": dict(
+                    show_navi_mark=True,
+                    # no_wheel_friction=True,
+                    use_special_color=False,
+                    image_source="depth_camera",
+                    lidar=dict(num_lasers=120, distance=50),
+                    lane_line_detector=dict(num_lasers=0, distance=50),
+                    side_detector=dict(num_lasers=0, distance=50)
+                ),
+            }
+        )
+    )
+
+    # env.reset()
+    #
+    #
+    ckpt = "C:\\Users\\x1\\Desktop\\neurips_2023\\exp\\nuplan\\MultiWorkerPPO_ScenarioEnv_2f75c_00003_3_seed=300_2023-06-04_02-14-18\\checkpoint_430\\checkpoint-430"
+    policy = get_function(ckpt, True, config)
+
+    def capture():
+        env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
+        # ret = env.render(
+        #     mode="topdown", screen_size=(1600, 900), film_size=(10000, 10000), target_vehicle_heading_up=True
+        # )
+        # pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, env.episode_step))
+
+    #
+    #
+    # env.engine.accept("c", capture)
+
+    for i in range(10000):
+        o = env.reset()
+        for t in range(10000):
+            # capture()
+            o, r, d, info = env.step(policy(o)["default_policy"])
+            if d and info["arrive_dest"]:
+                break
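A note on the new script above: `GymEnvWrapper` adapts the now gymnasium-style `ScenarioEnv` back to the legacy gym interface that this RLLib training stack expects, which is why the loop can still use `o = env.reset()` and the four-value `env.step(...)`. A hedged usage sketch with the config trimmed to a minimum (the `num_scenarios=1` setting is illustrative, not from this commit):

```python
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gym_wrapper import GymEnvWrapper

# Wrap the gymnasium-style env so downstream code keeps the legacy gym API.
env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=dict(num_scenarios=1)))
o = env.reset()                       # gym-style: observation only
o, r, d, info = env.step([0.0, 0.0])  # gym-style: four-value return
env.close()
```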
@@ -40,11 +40,11 @@ if __name__ == '__main__':
         }
     )
     success = []
-    env.reset(force_seed=91)
+    env.reset(seed=91)
     while True:
-        env.reset(force_seed=91)
+        env.reset(seed=91)
         for t in range(10000):
-            o, r, d, info = env.step([0, 0])
+            o, r, d, _, info = env.step([0, 0])
             assert env.observation_space.contains(o)
             c_lane = env.vehicle.lane
             long, lat, = c_lane.local_coordinates(env.vehicle.position)
@@ -119,12 +119,12 @@ def loading_into_metadrive(
     desc = "Scenarios: {}-{}".format(start_scenario_index, start_scenario_index + num_scenario)
     for scenario_index in tqdm.tqdm(range(start_scenario_index, start_scenario_index + num_scenario), desc=desc):
         try:
-            env.reset(force_seed=scenario_index)
+            env.reset(seed=scenario_index)
             arrive = False
             if random_drop and np.random.rand() < 0.5:
                 raise ValueError("Random Drop")
             for _ in range(steps_to_run):
-                o, r, d, info = env.step([0, 0])
+                o, r, d, _, info = env.step([0, 0])
                 if d and info["arrive_dest"]:
                     arrive = True
             assert arrive, "Can not arrive destination"
@@ -1,6 +1,7 @@
 import os.path
 
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
@@ -13,8 +14,8 @@ if __name__ == '__main__':
     stop = int(100_000_000)
 
     config = dict(
-        env=env,
-        env_config=dict(
+        env=GymEnvWrapper,
+        env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
             # scenario
             start_scenario_index=0,
             num_scenarios=32,
@@ -33,8 +34,7 @@ if __name__ == '__main__':
 
             # training
             horizon=None,
             use_lateral_reward=True,
-        ),
+        )),
 
     # # ===== Evaluation =====
     evaluation_interval=2,
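The training configs now register `GymEnvWrapper` as the RLLib env class and move the actual `ScenarioEnv` settings under `inner_class`/`inner_config`. Roughly, RLLib ends up constructing the env as sketched below (an equivalent direct construction for illustration, not code from this commit):

```python
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gym_wrapper import GymEnvWrapper

# What env=GymEnvWrapper plus the nested env_config amounts to at runtime:
env = GymEnvWrapper(
    dict(
        inner_class=ScenarioEnv,
        inner_config=dict(
            start_scenario_index=0,
            num_scenarios=32,
            horizon=None,
            use_lateral_reward=True,
        ),
    )
)
```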
@@ -1,13 +1,13 @@
 import os.path
 
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -42,7 +42,7 @@ config = dict(
 
         vehicle_config=dict(side_detector=dict(num_lasers=0))
 
-    ),
+    )),
 
     # ===== Evaluation =====
     evaluation_interval=15,
@@ -1,14 +1,14 @@
 import os.path
 from ray.tune import grid_search
 from metadrive.envs.scenario_env import ScenarioEnv
 
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(
 
         vehicle_config=dict(side_detector=dict(num_lasers=0))
 
-    ),
+    )),
 
     # ===== Evaluation =====
     evaluation_interval=15,
@@ -1,13 +1,13 @@
 import os.path
 
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name
 
 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(
 
         vehicle_config=dict(side_detector=dict(num_lasers=0))
 
-    ),
+    )),
 
     # ===== Evaluation =====
     evaluation_interval=15,
@@ -51,7 +51,8 @@ config = dict(
             num_scenarios=1000,
             sequential_seed=True,
             curriculum_level=1,  # turn off
-            data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
+            data_directory=os.path.join(SCENARIONET_DATASET_PATH,
+                                        "waymo_test"))),
     evaluation_num_workers=10,
     metrics_smoothing_episodes=10,
@@ -68,16 +69,16 @@ config = dict(
     num_cpus_for_driver=1,
     num_workers=20,
     framework="tf"
 )
 
 if __name__ == '__main__':
     # PG data is generated with seeds 10,000 to 60,000
     args = get_train_parser().parse_args()
     exp_name = get_exp_name(args)
     stop = int(100_000_000)
     config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0
 
     train(
         MultiWorkerPPO,
         exp_name=exp_name,
         save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
@@ -92,4 +93,4 @@ if __name__ == '__main__':
         # TODO remove this when we release our code
         # wandb_key_file="~/wandb_api_key_file.txt",
         wandb_project="scenarionet",
     )
@@ -1,11 +0,0 @@
-from ray.rllib.utils import check_env
-from metadrive.envs.scenario_env import ScenarioEnv
-from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
-from gym import Env
-
-if __name__ == '__main__':
-    env = GymnasiumEnvWrapper.build(ScenarioEnv)()
-    print(isinstance(ScenarioEnv, Env))
-    print(isinstance(env, Env))
-    print(env.observation_space)
-    check_env(env)
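The deleted test above exercised the old `GymnasiumEnvWrapper`; with `ScenarioEnv` now exposing the gymnasium API natively, an equivalent check can target the env directly. A hedged replacement sketch (it assumes `ScenarioEnv` subclasses `gymnasium.Env` after this migration, and `num_scenarios=1` is illustrative):

```python
from ray.rllib.utils import check_env
from gymnasium import Env
from metadrive.envs.scenario_env import ScenarioEnv

env = ScenarioEnv(dict(num_scenarios=1))
print(isinstance(env, Env))  # should hold once ScenarioEnv is gymnasium-native
print(env.observation_space)
check_env(env)               # RLLib sanity check on the unwrapped env
```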
@@ -9,6 +9,7 @@ import numpy as np
 import tqdm
 from metadrive.constants import TerminationState
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from ray import tune
 from ray.tune import CLIReporter
 
@@ -291,7 +292,7 @@ def eval_ckpt(config,
             episodes_to_evaluate_curriculum=num_scenarios,
             data_directory=scenario_data_path,
             use_render=render))
-    env = ScenarioEnv(env_config)
+    env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=env_config))
 
     super_data = defaultdict(list)
     EPISODE_NUM = env.config["num_scenarios"]