import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
sys.path.insert(0, os.path.join(project_root, "Env"))

import numpy as np
import torch
from torch.utils.data import Dataset
import pickle

from scenario_env import MultiAgentScenarioEnv
from metadrive.engine.asset_loader import AssetLoader


class DummyPolicy:
    """Placeholder policy that always outputs the zero action."""

    def act(self, *args, **kwargs):
        return np.array([0.0, 0.0])


class ExpertTrajectoryDataset(Dataset):
    """
    Expert trajectory dataset with full 107-dim observations.
    """

    def __init__(self,
                 trajectory_data: dict,
                 observation_data: dict = None,  # optional full observations
                 sequence_length: int = 1,
                 extract_actions: bool = True):
        """
        Args:
            trajectory_data: expert trajectory data
            observation_data: full 107-dim observation data (optional)
            sequence_length: sequence length
            extract_actions: whether to extract actions
        """
        self.trajectory_data = trajectory_data
        self.observation_data = observation_data if observation_data else {}
        self.sequence_length = sequence_length
        self.extract_actions = extract_actions

        # Build the (trajectory id, start index) sample index.
        self.indices = []
        for traj_id, traj in trajectory_data.items():
            traj_len = traj["length"]
            for start_idx in range(traj_len - sequence_length):
                self.indices.append((traj_id, start_idx))

        obs_dim = 107 if len(self.observation_data) > 0 else 5
        print(f"Expert dataset: {len(trajectory_data)} trajectories, "
              f"{len(self.indices)} training samples, observation dim: {obs_dim}")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        traj_id, start_idx = self.indices[idx]
        traj = self.trajectory_data[traj_id]
        end_idx = start_idx + self.sequence_length

        # Use the full 107-dim observations when they are available.
        if traj_id in self.observation_data and len(self.observation_data[traj_id]) > 0:
            obs_sequence = self.observation_data[traj_id]
            states = obs_sequence[start_idx:end_idx]  # (seq_len, 107)
        else:
            # Otherwise fall back to the simplified 5-dim observation.
            positions = traj["positions"][start_idx:end_idx + 1]
            headings = traj["headings"][start_idx:end_idx + 1]
            velocities = traj["velocities"][start_idx:end_idx]
            states = []
            for i in range(self.sequence_length):
                state = np.concatenate([
                    positions[i, :2],   # x, y
                    velocities[i],      # vx, vy
                    [headings[i]],      # heading
                ])
                states.append(state)
            states = np.array(states)

        if self.extract_actions:
            positions = traj["positions"][start_idx:end_idx + 1]
            headings = traj["headings"][start_idx:end_idx + 1]
            velocities = traj["velocities"][start_idx:end_idx]
            actions = self._extract_actions_from_states(
                positions[:-1], positions[1:],
                headings[:-1], headings[1:],
                velocities
            )
            return torch.FloatTensor(states), torch.FloatTensor(actions)
        else:
            next_states = states[1:]
            return torch.FloatTensor(states[:-1]), torch.FloatTensor(next_states)

    def _extract_actions_from_states(self, pos_t, pos_t1, head_t, head_t1, vel_t):
        """Recover (throttle, steering) actions from consecutive states via inverse dynamics."""
        actions = []
        dt = 0.1  # simulation step in seconds

        for i in range(len(pos_t)):
            current_speed = np.linalg.norm(vel_t[i])
            displacement = np.linalg.norm(pos_t1[i, :2] - pos_t[i, :2])
            next_speed = displacement / dt
            speed_change = (next_speed - current_speed) / dt

            # Map acceleration to throttle and deceleration to brake, with
            # separate scalings (5 m/s^2 full throttle, 8 m/s^2 full brake).
            if speed_change >= 0:
                throttle = np.clip(speed_change / 5.0, 0.0, 1.0)
            else:
                throttle = np.clip(speed_change / 8.0, -1.0, 0.0)

            # Wrap the heading delta to [-pi, pi] before scaling to steering.
            heading_change = head_t1[i] - head_t[i]
            heading_change = np.arctan2(np.sin(heading_change), np.cos(heading_change))
            steering = np.clip(heading_change / 0.2, -1.0, 1.0)

            actions.append([throttle, steering])

        return np.array(actions)
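
    # Worked example for the inverse dynamics above (a hedged sketch, not part
    # of the pipeline): the speed and displacement values below are assumed
    # purely to exercise the 5 m/s^2 throttle scaling.
    @staticmethod
    def _inverse_dynamics_example():
        """With dt = 0.1 s, a vehicle at 5.0 m/s that covers 0.55 m in one
        step implies a next speed of 5.5 m/s, i.e. an acceleration of
        (5.5 - 5.0) / 0.1 = 5 m/s^2, which maps to full throttle under the
        5 m/s^2 scaling used in _extract_actions_from_states."""
        dt = 0.1
        current_speed = 5.0                                # m/s (assumed)
        displacement = 0.55                                # m per step (assumed)
        next_speed = displacement / dt                     # 5.5 m/s
        speed_change = (next_speed - current_speed) / dt   # 5.0 m/s^2
        throttle = np.clip(speed_change / 5.0, 0.0, 1.0)   # -> 1.0, full throttle
        assert abs(throttle - 1.0) < 1e-9
        return throttle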
    @staticmethod
    def collect_with_full_obs(env_config, num_scenarios=10, save_path=None):
        """
        ✅ Collect full 107-dim observations via env._get_all_obs().

        This reuses the environment's existing observation function directly
        instead of reimplementing it.
        """
        all_trajectories = {}
        all_observations = {}

        # Inspect the dataset summary to find out how many scenarios exist.
        data_dir = env_config["config"]["data_directory"]
        summary_path = os.path.join(data_dir, "dataset_summary.pkl")
        with open(summary_path, 'rb') as f:
            summary = pickle.load(f)
        total_scenarios = len(summary)
        print(f"Total scenarios in database: {total_scenarios}")

        if num_scenarios is None:
            num_scenarios = total_scenarios
        else:
            num_scenarios = min(num_scenarios, total_scenarios)
        print(f"Planning to collect (full 107-dim observations): {num_scenarios} scenarios")

        for i in range(num_scenarios):
            env = None
            try:
                # Create a single-scenario environment for scenario i.
                env = MultiAgentScenarioEnv(
                    config={
                        **env_config["config"],
                        "start_scenario_index": i,
                        "num_scenarios": 1,
                    },
                    agent2policy=env_config["agent2policy"]
                )
                env.reset()

                if not hasattr(env, 'expert_trajectories'):
                    print(f"⚠️ Scenario {i}: missing expert_trajectories")
                    env.close()
                    continue

                expert_trajs = env.expert_trajectories
                if len(expert_trajs) == 0:
                    print(f"⚠️ Scenario {i}: no expert trajectories")
                    env.close()
                    continue

                # Store trajectories under scenario-scoped unique ids.
                for obj_id, traj in expert_trajs.items():
                    unique_id = f"scenario{i}_{obj_id}"
                    all_trajectories[unique_id] = traj

                # ✅ Key step: use _get_all_obs() for the full observations.
                # First map each controlled agent_id to the matching unique_id
                # so observations attach to the right trajectory.
                agent_to_unique = {}
                for agent_id in env.controlled_agents.keys():
                    # Try to match agent_id against the expert_trajectories obj_ids.
                    for obj_id in expert_trajs.keys():
                        if str(agent_id) in str(obj_id) or str(obj_id) in str(agent_id):
                            unique_id = f"scenario{i}_{obj_id}"
                            agent_to_unique[agent_id] = unique_id
                            all_observations[unique_id] = []
                            break

                # Step through the scenario and collect full observations.
                max_steps = min([traj["length"] for traj in expert_trajs.values()])
                for step in range(max_steps):
                    # ✅ Call _get_all_obs() directly; each entry is already
                    # a 107-dim vector.
                    obs_list = env._get_all_obs()

                    # Store each agent's observation, assuming obs_list follows
                    # the iteration order of controlled_agents.
                    for agent_idx, agent_id in enumerate(env.controlled_agents.keys()):
                        if agent_id in agent_to_unique and agent_idx < len(obs_list):
                            unique_id = agent_to_unique[agent_id]
                            all_observations[unique_id].append(np.array(obs_list[agent_idx]))

                    # Step with zero actions to keep the scenario advancing.
                    actions = {aid: np.array([0.0, 0.0]) for aid in env.controlled_agents.keys()}
                    env.step(actions)

                # Convert observation lists to numpy arrays; drop empty ones.
                for unique_id in list(all_observations.keys()):
                    if len(all_observations[unique_id]) > 0:
                        all_observations[unique_id] = np.array(all_observations[unique_id])
                    else:
                        del all_observations[unique_id]

                env.close()

                if (i + 1) % 5 == 0:
                    print(f"✓ Collected {i + 1}/{num_scenarios}, "
                          f"trajectories: {len(all_trajectories)}, "
                          f"observations: {len(all_observations)}")

            except Exception as e:
                print(f"✗ Scenario {i} collection failed: {e}")
                import traceback
                traceback.print_exc()
                try:
                    if env is not None:
                        env.close()
                except Exception:
                    pass
                continue

        print("\nCollection finished!")
        print(f"  Trajectories: {len(all_trajectories)}")
        print(f"  Full observations: {len(all_observations)}")

        # Verify the observation dimension.
        if len(all_observations) > 0:
            sample_obs = list(all_observations.values())[0]
            if len(sample_obs) > 0:
                obs_dim = len(sample_obs[0])
                print(f"  Observation dim: {obs_dim} (expected 107)")

        if save_path:
            with open(save_path, "wb") as f:
                pickle.dump({
                    "trajectories": all_trajectories,
                    "observations": all_observations
                }, f)
            print(f"Data saved to: {save_path}")

        return all_trajectories, all_observations
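
# Minimal behavior-cloning sketch showing how this dataset is meant to be
# consumed (illustrative only, not called anywhere in this file). The
# `policy_net` argument is an assumed torch.nn.Module that maps a
# (batch, seq_len, obs_dim) state tensor to (batch, seq_len, 2) actions;
# batch size and learning rate are arbitrary choices.
def _behavior_cloning_sketch(dataset, policy_net, epochs=1):
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-4)
    for _ in range(epochs):
        for states, actions in loader:
            pred = policy_net(states)  # hypothetical policy network
            loss = torch.nn.functional.mse_loss(pred, actions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()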

if __name__ == "__main__":
    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
    data_dir = AssetLoader.file_path(WAYMO_DATA_DIR, "exp_filtered", unix_style=False)

    env_config = {
        "config": {
            "data_directory": data_dir,
            "is_multi_agent": True,
            "num_controlled_agents": 3,
            "use_render": False,
            "sequential_seed": True,
        },
        "agent2policy": DummyPolicy()
    }

    print("=" * 60)
    print("Select a collection mode:")
    print("1. Simplified observations (5-dim) - fast, verified ✅")
    print("2. Full observations (107-dim) - uses _get_all_obs() ⭐")
    print("=" * 60)

    mode = input("Choose a mode (1 or 2, default 1): ").strip() or "1"

    if mode == "2":
        print("\nCollecting full 107-dim observations...")
        trajectories, observations = ExpertTrajectoryDataset.collect_with_full_obs(
            env_config,
            num_scenarios=10,
            save_path="./expert_trajectories_full.pkl"
        )

        if len(trajectories) > 0:
            dataset = ExpertTrajectoryDataset(
                trajectories,
                observations,
                sequence_length=1
            )
            state, action = dataset[0]
            print("\nDataset test:")
            print(f"  Trajectories: {len(trajectories)}")
            print(f"  Observations: {len(observations)}")
            print(f"  Training samples: {len(dataset)}")
            print(f"  State shape: {state.shape}")
            print(f"  Action shape: {action.shape}")
    else:
        print("\nCollecting simplified 5-dim observations...")
        # Keep the original simplified-observation collection code here...
        print("(use the previously verified method)")
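
# Reloading a saved collection later (sketch, assuming the save_path used
# above and the {"trajectories", "observations"} keys written by
# collect_with_full_obs):
#
#     with open("./expert_trajectories_full.pkl", "rb") as f:
#         data = pickle.load(f)
#     dataset = ExpertTrajectoryDataset(data["trajectories"], data["observations"])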