新增scripts工具

This commit is contained in:
2025-10-25 21:44:11 +08:00
parent 62e638c4d2
commit c94571ddaa
17 changed files with 1193 additions and 66 deletions

304
dataset/expert_dataset.py Normal file
View File

@@ -0,0 +1,304 @@
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, os.path.join(project_root, "Env"))
import numpy as np
import torch
from torch.utils.data import Dataset
import pickle
from scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader
class DummyPolicy:
def act(self, *args, **kwargs):
return np.array([0.0, 0.0])
class ExpertTrajectoryDataset(Dataset):
"""
完整107维观测的专家轨迹数据集
"""
def __init__(self,
trajectory_data: dict,
observation_data: dict = None, # 可选的完整观测
sequence_length: int = 1,
extract_actions: bool = True):
"""
Args:
trajectory_data: 专家轨迹数据
observation_data: 完整107维观测数据(可选)
sequence_length: 序列长度
extract_actions: 是否提取动作
"""
self.trajectory_data = trajectory_data
self.observation_data = observation_data if observation_data else {}
self.sequence_length = sequence_length
self.extract_actions = extract_actions
# 构建索引
self.indices = []
for traj_id, traj in trajectory_data.items():
traj_len = traj["length"]
for start_idx in range(traj_len - sequence_length):
self.indices.append((traj_id, start_idx))
obs_dim = 107 if len(self.observation_data) > 0 else 5
print(f"专家数据集: {len(trajectory_data)} 条轨迹, "
f"{len(self.indices)} 个训练样本, 观测维度: {obs_dim}")
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
traj_id, start_idx = self.indices[idx]
traj = self.trajectory_data[traj_id]
end_idx = start_idx + self.sequence_length
# 如果有完整观测,使用完整观测(107维)
if traj_id in self.observation_data and len(self.observation_data[traj_id]) > 0:
obs_sequence = self.observation_data[traj_id]
states = obs_sequence[start_idx:end_idx] # (seq_len, 107)
else:
# 否则使用简化观测(5维)
positions = traj["positions"][start_idx:end_idx+1]
headings = traj["headings"][start_idx:end_idx+1]
velocities = traj["velocities"][start_idx:end_idx]
states = []
for i in range(self.sequence_length):
state = np.concatenate([
positions[i, :2], # x, y
velocities[i], # vx, vy
[headings[i]], # heading
])
states.append(state)
states = np.array(states)
if self.extract_actions:
positions = traj["positions"][start_idx:end_idx+1]
headings = traj["headings"][start_idx:end_idx+1]
velocities = traj["velocities"][start_idx:end_idx]
actions = self._extract_actions_from_states(
positions[:-1], positions[1:],
headings[:-1], headings[1:],
velocities
)
return torch.FloatTensor(states), torch.FloatTensor(actions)
else:
next_states = states[1:]
return torch.FloatTensor(states[:-1]), torch.FloatTensor(next_states)
def _extract_actions_from_states(self, pos_t, pos_t1, head_t, head_t1, vel_t):
"""从状态序列反推动作"""
actions = []
dt = 0.1
for i in range(len(pos_t)):
current_speed = np.linalg.norm(vel_t[i])
displacement = np.linalg.norm(pos_t1[i, :2] - pos_t[i, :2])
next_speed = displacement / dt
speed_change = (next_speed - current_speed) / dt
if speed_change >= 0:
throttle = np.clip(speed_change / 5.0, 0.0, 1.0)
else:
throttle = np.clip(speed_change / 8.0, -1.0, 0.0)
heading_change = head_t1[i] - head_t[i]
heading_change = np.arctan2(np.sin(heading_change), np.cos(heading_change))
steering = np.clip(heading_change / 0.2, -1.0, 1.0)
actions.append([throttle, steering])
return np.array(actions)
@staticmethod
def collect_with_full_obs(env_config, num_scenarios=10, save_path=None):
"""
✅ 使用env._get_all_obs()收集完整107维观测
这是正确的方法!直接利用环境已有的观测函数
"""
all_trajectories = {}
all_observations = {}
# 检查数据库
data_dir = env_config["config"]["data_directory"]
summary_path = os.path.join(data_dir, "dataset_summary.pkl")
with open(summary_path, 'rb') as f:
summary = pickle.load(f)
total_scenarios = len(summary)
print(f"数据库总场景数: {total_scenarios}")
if num_scenarios is None:
num_scenarios = total_scenarios
else:
num_scenarios = min(num_scenarios, total_scenarios)
print(f"计划收集(完整107维观测): {num_scenarios} 个场景")
for i in range(num_scenarios):
try:
# 创建环境
env = MultiAgentScenarioEnv(
config={
**env_config["config"],
"start_scenario_index": i,
"num_scenarios": 1,
},
agent2policy=env_config["agent2policy"]
)
# 重置环境
env.reset()
if not hasattr(env, 'expert_trajectories'):
print(f"⚠️ 场景 {i}: 缺少expert_trajectories")
env.close()
continue
expert_trajs = env.expert_trajectories
if len(expert_trajs) == 0:
print(f"⚠️ 场景 {i}: 无专家轨迹")
env.close()
continue
# 存储轨迹
scenario_id = env.engine.current_seed
for obj_id, traj in expert_trajs.items():
unique_id = f"scenario{i}_{obj_id}"
all_trajectories[unique_id] = traj
# ✅ 关键: 使用_get_all_obs()获取完整观测
# 创建agent_id到unique_id的映射
agent_to_unique = {}
for agent_id in env.controlled_agents.keys():
# 尝试匹配agent_id到expert_trajectories的obj_id
for obj_id in expert_trajs.keys():
if str(agent_id) in str(obj_id) or str(obj_id) in str(agent_id):
unique_id = f"scenario{i}_{obj_id}"
agent_to_unique[agent_id] = unique_id
all_observations[unique_id] = []
break
# 遍历场景的每一步,收集完整观测
max_steps = min([traj["length"] for traj in expert_trajs.values()])
for step in range(max_steps):
# ✅ 直接调用_get_all_obs()获取107维观测!
obs_list = env._get_all_obs()
# 存储每个agent的观测
for agent_idx, agent_id in enumerate(env.controlled_agents.keys()):
if agent_id in agent_to_unique:
unique_id = agent_to_unique[agent_id]
if agent_idx < len(obs_list):
# obs_list[agent_idx]已经是107维向量!
all_observations[unique_id].append(np.array(obs_list[agent_idx]))
# 执行零动作(保持场景状态)
actions = {aid: np.array([0.0, 0.0])
for aid in env.controlled_agents.keys()}
env.step(actions)
# 转换为numpy数组
for unique_id in list(all_observations.keys()):
if len(all_observations[unique_id]) > 0:
all_observations[unique_id] = np.array(all_observations[unique_id])
else:
del all_observations[unique_id]
env.close()
if (i + 1) % 5 == 0:
print(f"✓ 已收集 {i+1}/{num_scenarios}, "
f"轨迹: {len(all_trajectories)}, "
f"观测: {len(all_observations)}")
except Exception as e:
print(f"✗ 场景 {i} 收集失败: {e}")
import traceback
traceback.print_exc()
try:
env.close()
except:
pass
continue
print(f"\n收集完成!")
print(f" 轨迹数: {len(all_trajectories)}")
print(f" 完整观测数: {len(all_observations)}")
# 验证观测维度
if len(all_observations) > 0:
sample_obs = list(all_observations.values())[0]
if len(sample_obs) > 0:
obs_dim = len(sample_obs[0])
print(f" 观测维度: {obs_dim} (应为107)")
if save_path:
with open(save_path, "wb") as f:
pickle.dump({
"trajectories": all_trajectories,
"observations": all_observations
}, f)
print(f"数据已保存到: {save_path}")
return all_trajectories, all_observations
if __name__ == "__main__":
WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)
env_config = {
"config": {
"data_directory": data_dir,
"is_multi_agent": True,
"num_controlled_agents": 3,
"use_render": False,
"sequential_seed": True,
},
"agent2policy": DummyPolicy()
}
print("=" * 60)
print("选择收集模式:")
print("1. 简化观测(5维) - 快速,已验证 ✅")
print("2. 完整观测(107维) - 使用_get_all_obs() ⭐")
print("=" * 60)
mode = input("请选择模式(1或2,默认1): ").strip() or "1"
if mode == "2":
print("\n开始收集完整107维观测...")
trajectories, observations = ExpertTrajectoryDataset.collect_with_full_obs(
env_config,
num_scenarios=10,
save_path="./expert_trajectories_full.pkl"
)
if len(trajectories) > 0:
dataset = ExpertTrajectoryDataset(
trajectories,
observations,
sequence_length=1
)
state, action = dataset[0]
print(f"\n数据集测试:")
print(f" 总轨迹数: {len(trajectories)}")
print(f" 总观测数: {len(observations)}")
print(f" 训练样本数: {len(dataset)}")
print(f" 状态维度: {state.shape}")
print(f" 动作维度: {action.shape}")
else:
print("\n开始收集简化5维观测...")
# 保持原有的简化版本代码...
print("(使用之前已成功的方法)")