suppres tf warning

This commit is contained in:
QuanyiLi
2023-05-06 23:32:16 +01:00
parent 300a29da09
commit 2d5eb33e04
2 changed files with 6 additions and 5 deletions

View File

@@ -9,11 +9,14 @@ logger = logging.getLogger(__name__)
import numpy as np import numpy as np
try: try:
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)
import tensorflow as tf import tensorflow as tf
except ImportError as e: except ImportError as e:
logger.info(e) logger.info(e)
try: try:
from waymo_open_dataset.protos.scenario_pb2 import Scenario from waymo_open_dataset.protos import scenario_pb2
except ImportError as e: except ImportError as e:
logger.warning(e, "\n Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0==1.5.0") logger.warning(e, "\n Please install waymo_open_dataset package: pip install waymo-open-dataset-tf-2-11-0==1.5.0")
@@ -429,7 +432,7 @@ def get_waymo_scenarios(waymo_data_direction):
if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)): if ("tfrecord" not in file_path) or (not os.path.isfile(file_path)):
continue continue
for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator(): for data in tf.data.TFRecordDataset(file_path, compression_type="").as_numpy_iterator():
scenario = Scenario() scenario = scenario_pb2.Scenario()
scenario.ParseFromString(data) scenario.ParseFromString(data)
# a trick for loging file name # a trick for loging file name
scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file scenario.scenario_id = scenario.scenario_id + SPLIT_KEY + file

View File

@@ -1,13 +1,11 @@
import logging import logging
import os import os
from scenarionet import SCENARIONET_DATASET_PATH
from scenarionet.converter.utils import write_to_directory from scenarionet.converter.utils import write_to_directory
from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios from scenarionet.converter.waymo.utils import convert_waymo_scenario, get_waymo_scenarios
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from scenarionet import SCENARIONET_DATASET_PATH
if __name__ == '__main__': if __name__ == '__main__':
force_overwrite = True force_overwrite = True
dataset_name = "waymo" dataset_name = "waymo"