Update env (#7)
* add capture script
* gymnasium API
* training with gymnasium API
@@ -4,7 +4,7 @@

 [**Webpage**](https://metadriverse.github.io/scenarionet/) |
 [**Code**](https://github.com/metadriverse/scenarionet) |
-[**Video**](https://metadriverse.github.io/scenarionet/) |
+[**Video**](https://youtu.be/3bOqswXP6OA) |
 [**Paper**](http://arxiv.org/abs/2306.12241) |

 ScenarioNet allows users to load scenarios from real-world dataset like Waymo, nuPlan, nuScenes, l5 and synthetic
@@ -39,14 +39,14 @@ if __name__ == '__main__':
     }
 )
 for index in range(num_scenario if args.scenario_index is not None else 1000000):
-    env.reset(force_seed=index if args.scenario_index is None else args.scenario_index)
+    env.reset(seed=index if args.scenario_index is None else args.scenario_index)
     for t in range(10000):
-        o, r, d, info = env.step([0, 0])
+        env.step([0, 0])
         if env.config["use_render"]:
             env.render(text={
                 "scenario index": env.engine.global_seed + env.config["start_scenario_index"],
             })

-        if d and info["arrive_dest"]:
+        if env.episode_step >= env.engine.data_manager.current_scenario_length:
            print("scenario:{}, success".format(env.engine.global_random_seed))
            break
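The pattern in this hunk recurs throughout the commit: MetaDrive's custom `reset(force_seed=...)` argument becomes the standard Gymnasium `reset(seed=...)`, and callers stop relying on the old 4-tuple returned by `step()`. For reference, a minimal sketch of the Gymnasium calling convention being adopted, shown on a built-in toy env rather than ScenarioNet:

```python
# Minimal sketch of the Gymnasium API this commit migrates to; ScenarioEnv
# follows the same reset/step signatures.
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)  # reset takes `seed` and returns (obs, info)
terminated = truncated = False
while not (terminated or truncated):
    # step returns five values instead of the old (obs, reward, done, info)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```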
@@ -51,9 +51,9 @@ if __name__ == "__main__":
 # env.engine.accept("c", capture)

 # for seed in [1001, 1002, 1005, 1011]:
-env.reset(force_seed=1020)
+env.reset(seed=1020)
 for t in range(10000):
     capture()
-    o, r, d, info = env.step([1, 0.88])
+    env.step([1, 0.88])
     if env.episode_step >= env.engine.data_manager.current_scenario_length:
         break
@@ -28,10 +28,10 @@ if __name__ == "__main__":
     )
 )

-# o = env.reset(force_seed=0)
+# o = env.reset(seed=0)
 # env.engine.accept("c", capture)
 for s in range(6, 1000):
-    env.reset(force_seed=16)
+    env.reset(seed=16)
     for t in range(10000):
         capture()
         o, r, d, info = env.step([0, 0])
@@ -56,7 +56,7 @@ if __name__ == "__main__":

 # 0,1,3,4,5,6
 for seed in range(10):
-    env.reset(force_seed=seed)
+    env.reset(seed=seed)
     for t in range(10000):
         env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
         ret = env.render(
@@ -65,8 +65,7 @@ if __name__ == "__main__":
         pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, t))
         # env.vehicle.get_camera("depth_camera").save_image(env.vehicle, "depth_{}.jpg".format(t))
         # env.vehicle.get_camera("rgb_camera").save_image(env.vehicle, "rgb_{}.jpg".format(t))
-        o, r, d, info = env.step([1, 0.88])
-        assert env.observation_space.contains(o)
+        env.step([1, 0.88])
         # if d:
         if env.episode_step >= env.engine.data_manager.current_scenario_length:
             break
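The capture loop above saves two artifacts per step: a sensor capture via `env.capture(...)` and a top-down frame via `env.render(...)` plus `pygame.image.save`. Pulled together as one helper (a sketch only; it assumes `env` is the ScenarioEnv built earlier in the script, and the render keyword arguments are copied from the commented-out call in run_ckpt.py below):

```python
# Sketch of the per-step frame dump used by these scripts. Assumes `env`
# is the (possibly wrapped) ScenarioEnv constructed above; the render
# kwargs mirror the commented-out call in run_ckpt.py.
import pygame


def save_frames(env, t):
    env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
    ret = env.render(
        mode="topdown", screen_size=(1600, 900), film_size=(10000, 10000),
        target_vehicle_heading_up=True
    )
    pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, t))
```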
scenarionet/tests/script/run_ckpt.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import pygame
+from metadrive.engine.asset_loader import AssetLoader
+from metadrive.envs.real_data_envs.nuscenes_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
+from scenarionet_training.train_utils.utils import initialize_ray, get_function
+from scenarionet_training.scripts.train_nuplan import config
+
+if __name__ == "__main__":
+    initialize_ray(test_mode=False, num_gpus=1)
+    env = GymEnvWrapper(
+        dict(
+            inner_class=ScenarioEnv,
+            inner_config={
+                # "data_directory": AssetLoader.file_path("nuscenes", return_raw_style=False),
+                "data_directory": "D:\\scenarionet_testset\\nuplan_test\\nuplan_test_w_raw",
+                "use_render": True,
+                # "agent_policy": ReplayEgoCarPolicy,
+                "show_interface": False,
+                "image_observation": False,
+                "show_logo": False,
+                "no_traffic": False,
+                "no_static_vehicles": False,
+                "drivable_region_extension": 15,
+                "sequential_seed": True,
+                "reactive_traffic": True,
+                "show_fps": False,
+                "render_pipeline": True,
+                "daytime": "07:10",
+                "max_lateral_dist": 2,
+                "window_size": (1200, 800),
+                "camera_dist": 9,
+                "start_scenario_index": 5,
+                "num_scenarios": 4000,
+                "horizon": 1000,
+                "store_map": False,
+                "vehicle_config": dict(
+                    show_navi_mark=True,
+                    # no_wheel_friction=True,
+                    use_special_color=False,
+                    image_source="depth_camera",
+                    lidar=dict(num_lasers=120, distance=50),
+                    lane_line_detector=dict(num_lasers=0, distance=50),
+                    side_detector=dict(num_lasers=0, distance=50)
+                ),
+            }
+        )
+    )
+
+    # env.reset()
+    #
+    #
+    ckpt = "C:\\Users\\x1\\Desktop\\neurips_2023\\exp\\nuplan\\MultiWorkerPPO_ScenarioEnv_2f75c_00003_3_seed=300_2023-06-04_02-14-18\\checkpoint_430\\checkpoint-430"
+    policy = get_function(ckpt, True, config)
+
+    def capture():
+        env.capture("rgb_deluxe_{}_{}.jpg".format(env.current_seed, t))
+        # ret = env.render(
+        #     mode="topdown", screen_size=(1600, 900), film_size=(10000, 10000), target_vehicle_heading_up=True
+        # )
+        # pygame.image.save(ret, "top_down_{}_{}.png".format(env.current_seed, env.episode_step))
+
+    #
+    #
+    # env.engine.accept("c", capture)
+
+    for i in range(10000):
+        o = env.reset()
+        for t in range(10000):
+            # capture()
+            o, r, d, info = env.step(policy(o)["default_policy"])
+            if d and info["arrive_dest"]:
+                break
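run_ckpt.py exercises the new `GymEnvWrapper`: it takes a single config dict with `inner_class` and `inner_config`, and then presents the legacy gym-style `reset()` and 4-tuple `step()` that the RLlib checkpoint expects (note `o = env.reset()` and `o, r, d, info = env.step(...)` above). MetaDrive's actual wrapper source is not part of this diff; a hypothetical minimal sketch of the pattern, for orientation only:

```python
# NOT MetaDrive's actual GymEnvWrapper source; a hypothetical illustration
# of an inner_class/inner_config wrapper that converts a Gymnasium-API env
# back to the legacy gym interface.
import gym  # legacy interface presented to the consumer


class LegacyGymWrapper(gym.Env):
    def __init__(self, config: dict):
        self._env = config["inner_class"](config["inner_config"])
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

    def reset(self):
        obs, _info = self._env.reset()  # Gymnasium reset returns (obs, info)
        return obs

    def step(self, action):
        obs, reward, terminated, truncated, info = self._env.step(action)
        return obs, reward, terminated or truncated, info  # merge the flags
```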
@@ -40,11 +40,11 @@ if __name__ == '__main__':
     }
 )
 success = []
-env.reset(force_seed=91)
+env.reset(seed=91)
 while True:
-    env.reset(force_seed=91)
+    env.reset(seed=91)
     for t in range(10000):
-        o, r, d, info = env.step([0, 0])
+        o, r, d, _, info = env.step([0, 0])
         assert env.observation_space.contains(o)
         c_lane = env.vehicle.lane
         long, lat, = c_lane.local_coordinates(env.vehicle.position)
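Note the unpacking change here: Gymnasium's `step()` returns `(obs, reward, terminated, truncated, info)`, so the script binds the fifth slot and discards `truncated` with `_`. Where old 4-tuple consumers must keep working, the two flags can be merged; a small compatibility shim (hypothetical helper, standard Gymnasium semantics):

```python
# Hypothetical shim: adapt Gymnasium's 5-tuple step to the legacy 4-tuple
# by merging terminated/truncated into a single done flag.
from typing import Any, Dict, Tuple

import gymnasium as gym


def step_compat(env: gym.Env, action) -> Tuple[Any, float, bool, Dict]:
    obs, reward, terminated, truncated, info = env.step(action)
    return obs, reward, terminated or truncated, info
```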
@@ -119,12 +119,12 @@ def loading_into_metadrive(
     desc = "Scenarios: {}-{}".format(start_scenario_index, start_scenario_index + num_scenario)
     for scenario_index in tqdm.tqdm(range(start_scenario_index, start_scenario_index + num_scenario), desc=desc):
         try:
-            env.reset(force_seed=scenario_index)
+            env.reset(seed=scenario_index)
             arrive = False
             if random_drop and np.random.rand() < 0.5:
                 raise ValueError("Random Drop")
             for _ in range(steps_to_run):
-                o, r, d, info = env.step([0, 0])
+                o, r, d, _, info = env.step([0, 0])
                 if d and info["arrive_dest"]:
                     arrive = True
             assert arrive, "Can not arrive destination"
@@ -1,6 +1,7 @@
 import os.path

 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper

 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
@@ -13,8 +14,8 @@ if __name__ == '__main__':
     stop = int(100_000_000)

     config = dict(
-        env=env,
-        env_config=dict(
+        env=GymEnvWrapper,
+        env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
             # scenario
             start_scenario_index=0,
             num_scenarios=32,
@@ -33,8 +34,7 @@ if __name__ == '__main__':

             # training
             horizon=None,
-            use_lateral_reward=True,
-        ),
+        )),

         # # ===== Evaluation =====
         evaluation_interval=2,
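Every RLlib training config in this commit changes the same way: the registered `env` becomes the wrapper class, and the real environment class plus its options move under `inner_class`/`inner_config`. A condensed before/after sketch (only the affected keys are shown; the values are copied from the training scripts in this diff):

```python
# Condensed before/after of the RLlib config change made in each training
# script below; unrelated keys omitted.
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gym_wrapper import GymEnvWrapper

# before: the env class is registered directly
old_config = dict(
    env=ScenarioEnv,
    env_config=dict(start_scenario_index=0, num_scenarios=40000),
)

# after: the gym-API wrapper is registered and the real class is nested
new_config = dict(
    env=GymEnvWrapper,
    env_config=dict(
        inner_class=ScenarioEnv,
        inner_config=dict(start_scenario_index=0, num_scenarios=40000),
    ),
)
```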
@@ -1,13 +1,13 @@
 import os.path
-
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name

 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -42,7 +42,7 @@ config = dict(

         vehicle_config=dict(side_detector=dict(num_lasers=0))

-    ),
+    )),

     # ===== Evaluation =====
     evaluation_interval=15,
@@ -1,14 +1,14 @@
 import os.path
 from ray.tune import grid_search
 from metadrive.envs.scenario_env import ScenarioEnv
-
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name

 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(

         vehicle_config=dict(side_detector=dict(num_lasers=0))

-    ),
+    )),

     # ===== Evaluation =====
     evaluation_interval=15,
@@ -1,13 +1,13 @@
 import os.path
-
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from metadrive.envs.scenario_env import ScenarioEnv
 from scenarionet import SCENARIONET_REPO_PATH, SCENARIONET_DATASET_PATH
 from scenarionet_training.train_utils.multi_worker_PPO import MultiWorkerPPO
 from scenarionet_training.train_utils.utils import train, get_train_parser, get_exp_name

 config = dict(
-    env=ScenarioEnv,
-    env_config=dict(
+    env=GymEnvWrapper,
+    env_config=dict(inner_class=ScenarioEnv, inner_config=dict(
         # scenario
         start_scenario_index=0,
         num_scenarios=40000,
@@ -41,7 +41,7 @@ config = dict(

         vehicle_config=dict(side_detector=dict(num_lasers=0))

-    ),
+    )),

     # ===== Evaluation =====
     evaluation_interval=15,
@@ -51,7 +51,8 @@ config = dict(
         num_scenarios=1000,
         sequential_seed=True,
         curriculum_level=1,  # turn off
-        data_directory=os.path.join(SCENARIONET_DATASET_PATH, "waymo_test"))),
+        data_directory=os.path.join(SCENARIONET_DATASET_PATH,
+                                    "waymo_test"))),
     evaluation_num_workers=10,
     metrics_smoothing_episodes=10,

@@ -68,16 +69,16 @@ config = dict(
     num_cpus_for_driver=1,
     num_workers=20,
     framework="tf"
 )

 if __name__ == '__main__':
     # PG data is generated with seeds 10,000 to 60,000
     args = get_train_parser().parse_args()
     exp_name = get_exp_name(args)
     stop = int(100_000_000)
     config["num_gpus"] = 0.5 if args.num_gpus != 0 else 0

     train(
         MultiWorkerPPO,
         exp_name=exp_name,
         save_dir=os.path.join(SCENARIONET_REPO_PATH, "experiment"),
@@ -92,4 +93,4 @@ if __name__ == '__main__':
         # TODO remove this when we release our code
         # wandb_key_file="~/wandb_api_key_file.txt",
         wandb_project="scenarionet",
     )
@@ -1,11 +0,0 @@
-from ray.rllib.utils import check_env
-from metadrive.envs.scenario_env import ScenarioEnv
-from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper
-from gym import Env
-
-if __name__ == '__main__':
-    env = GymnasiumEnvWrapper.build(ScenarioEnv)()
-    print(isinstance(ScenarioEnv, Env))
-    print(isinstance(env, Env))
-    print(env.observation_space)
-    check_env(env)
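A side note on the deleted script: `isinstance(ScenarioEnv, Env)` passes the class object rather than an instance, so it always prints False regardless of the wrapper. A corrected sketch of the same sanity check (imports as in the deleted file; whether the built class actually subclasses legacy `gym.Env` depends on MetaDrive's wrapper, which is not shown in this diff):

```python
# Corrected form of the deleted sanity check: class-level relations use
# issubclass, instance-level ones use isinstance.
from ray.rllib.utils import check_env
from gym import Env
from metadrive.envs.scenario_env import ScenarioEnv
from metadrive.envs.gymnasium_wrapper import GymnasiumEnvWrapper

wrapped_cls = GymnasiumEnvWrapper.build(ScenarioEnv)
env = wrapped_cls()
print(issubclass(wrapped_cls, Env))  # class-level check (was isinstance on a class)
print(isinstance(env, Env))          # instance-level check
print(env.observation_space)
check_env(env)  # RLlib's structural checks on spaces and reset/step
```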
@@ -9,6 +9,7 @@ import numpy as np
 import tqdm
 from metadrive.constants import TerminationState
 from metadrive.envs.scenario_env import ScenarioEnv
+from metadrive.envs.gym_wrapper import GymEnvWrapper
 from ray import tune
 from ray.tune import CLIReporter

@@ -291,7 +292,7 @@ def eval_ckpt(config,
         episodes_to_evaluate_curriculum=num_scenarios,
         data_directory=scenario_data_path,
         use_render=render))
-    env = ScenarioEnv(env_config)
+    env = GymEnvWrapper(dict(inner_class=ScenarioEnv, inner_config=env_config))

     super_data = defaultdict(list)
     EPISODE_NUM = env.config["num_scenarios"]