Skip to content

Commit

Permalink
fix bug of undefined flags of easyrec tools run with DeepRec
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong committed Dec 31, 2024
1 parent 9e131bc commit 7f5ddad
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 0 deletions.
3 changes: 3 additions & 0 deletions easy_rec/python/tools/add_boundaries_to_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import json
import logging
import os
import sys

import common_io
import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -61,4 +63,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/add_feature_info_to_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.hive_utils import HiveUtils

if tf.__version__ >= '2.0':
Expand Down Expand Up @@ -139,4 +141,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/faiss_index_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import logging
import os
import sys

import faiss
import numpy as np
import tensorflow as tf
from easy_rec.python.utils import io_util

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')
Expand Down Expand Up @@ -109,4 +111,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
import os
import sys
from collections import OrderedDict

import numpy as np
Expand All @@ -11,6 +12,7 @@
from tensorflow.python.framework.meta_graph import read_meta_graph_file

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -299,6 +301,7 @@ def _visualize_feature_importance(self, feature_importance, group_name):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
if FLAGS.model_type == 'variational_dropout':
fs = VariationalDropoutFS(
FLAGS.config_path,
Expand Down
3 changes: 3 additions & 0 deletions easy_rec/python/tools/hit_rate_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import json
import logging
import os
import sys

import graphlearn as gl
import tensorflow as tf

from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.config_util import process_multi_file_input_path
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
from easy_rec.python.utils.hit_rate_utils import load_graph
Expand Down Expand Up @@ -217,4 +219,5 @@ def main():


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
main()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/hit_rate_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from __future__ import division
from __future__ import print_function

import sys
import tensorflow as tf

from easy_rec.python.utils import io_util
from easy_rec.python.utils.hit_rate_utils import compute_hitrate_batch
from easy_rec.python.utils.hit_rate_utils import load_graph
from easy_rec.python.utils.hit_rate_utils import reduce_hitrate
Expand Down Expand Up @@ -131,4 +133,5 @@ def main():


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
main()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/pre_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.input.input import Input
from easy_rec.python.utils import config_util
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.check_utils import check_env_and_input_path
from easy_rec.python.utils.check_utils import check_sequence

Expand Down Expand Up @@ -114,4 +116,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/split_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import sys

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
Expand All @@ -11,6 +12,7 @@
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
tf = tf.compat.v1
Expand Down Expand Up @@ -282,4 +284,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
3 changes: 3 additions & 0 deletions easy_rec/python/tools/split_pdn_model_pai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import logging
import os
import sys

import tensorflow as tf
from tensorflow.core.framework import graph_pb2
Expand All @@ -12,6 +13,7 @@
from tensorflow.python.saved_model.utils_impl import get_variables_path
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as tf_saver
from easy_rec.python.utils import io_util

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', '')
Expand Down Expand Up @@ -265,4 +267,5 @@ def main(argv):


if __name__ == '__main__':
sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
tf.app.run()
21 changes: 21 additions & 0 deletions easy_rec/python/utils/io_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,24 @@ def read_data_from_json_path(json_path):
else:
logging.info('json_path not exists, return None')
return None

def filter_unknown_args(flags, args):
"""Filter unknown args."""
defined_flags = set(flag.name for flag in flags._flags().values())
logging.info('defined arguments: %s', ', '.join(defined_flags))
logging.info('actual arguments: %s', ', '.join(args[1:]))
known_args = [args[0]]
unknown = False
for arg in args[1:]:
if arg.startswith('--'):
flag_name = arg.split('=')[0][2:]
if flag_name in defined_flags:
known_args.append(arg)
unknown = False
else:
unknown = True
logging.warning('Ignore unknown arg: %s' % arg)
elif not unknown:
known_args.append(arg)
logging.info('keep arguments: %s', ', '.join(known_args[1:]))
return known_args

0 comments on commit 7f5ddad

Please sign in to comment.