add mps support and fix #20 #21

Open · wants to merge 1 commit into base: main
84 changes: 51 additions & 33 deletions nllb_serve/app.py
@@ -19,18 +19,27 @@

from . import log, DEF_MODEL_ID

-device = torch.device(torch.cuda.is_available() and 'cuda' or 'cpu')
+# Check if CUDA is available
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+# Check if MPS is available (for Apple Silicon)
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+# If neither CUDA nor MPS is available, use CPU
+else:
+    device = torch.device('cpu')
log.info(f'torch device={device}')

-#DEF_MODEL_ID = "facebook/nllb-200-distilled-600M"
+# DEF_MODEL_ID = "facebook/nllb-200-distilled-600M"
DEF_SRC_LNG = 'eng_Latn'
DEF_TGT_LNG = 'kan_Knda'
FLOAT_POINTS = 4
exp = None
app = Flask(__name__)
app.json.ensure_ascii = False

-bp = Blueprint('nmt', __name__, template_folder='templates', static_folder='static')
+bp = Blueprint('nmt', __name__, template_folder='templates',
+               static_folder='static')


sys_info = {
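Review note (not part of this diff): torch.backends.mps was only added in torch 1.12, so the new elif in this hunk would raise AttributeError on older builds. A minimal defensive variant of the same device-selection logic, assuming stock PyTorch and nothing else:

    import torch

    def pick_device() -> torch.device:
        # Prefer CUDA, then Apple Silicon MPS, then fall back to CPU.
        if torch.cuda.is_available():
            return torch.device('cuda')
        mps = getattr(torch.backends, 'mps', None)  # absent before torch 1.12
        if mps is not None and mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')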
@@ -42,22 +42,25 @@
    'GPU': '[unavailable]',
}
try:
-    sys_info['torch']: torch.__version__
+    sys_info['torch'] = torch.__version__
    if torch.cuda.is_available():
        sys_info['GPU'] = str(torch.cuda.get_device_properties('cuda'))
        sys_info['Cuda Version'] = torch.version.cuda
+    elif torch.backends.mps.is_available():
+        sys_info['GPU'] = 'Apple MPS'
+        sys_info['MPS Version'] = torch.version.__version__
    else:
-        log.warning("CUDA unavailable")
+        log.warning("CUDA/MPS unavailable")
except:
-    log.exception("Error while checking if cuda is available")
+    pass


def render_template(*args, **kwargs):
    return flask.render_template(*args, environ=os.environ, **kwargs)


def jsonify(obj):

    if obj is None or isinstance(obj, (int, bool, str)):
        return obj
    elif isinstance(obj, float):
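Review note: replacing sys_info['torch']: torch.__version__ (a no-op annotation) with a real assignment is a genuine bug fix, but the new bare except: pass hides failures that the old log.exception used to surface. A sketch that keeps the diagnostic, assuming this module's torch import and log object:

    try:
        sys_info['torch'] = torch.__version__
    except Exception:
        # Keep the traceback instead of swallowing the error silently.
        log.exception("Error while checking CUDA/MPS availability")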
@@ -66,7 +78,7 @@ def jsonify(obj):
        return {key: jsonify(val) for key, val in obj.items()}
    elif isinstance(obj, list):
        return [jsonify(it) for it in obj]
-    #elif isinstance(ob, np.ndarray):
+    # elif isinstance(ob, np.ndarray):
    #    return _jsonify(ob.tolist())
    else:
        log.warning(f"Type {type(obj)} maybe not be json serializable")
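Review note: for context, jsonify() recurses through dicts and lists, and the float branch collapsed above appears to round to FLOAT_POINTS (4) digits. A hypothetical round trip under that assumption:

    print(jsonify({'scores': [0.123456, 1.0], 'ok': True}))
    # expected: {'scores': [0.1235, 1.0], 'ok': True}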
@@ -77,9 +89,10 @@ def jsonify(obj):
def favicon():
    return send_from_directory(os.path.join(bp.root_path, 'static', 'favicon'), 'favicon.ico')


def attach_translate_route(
-    model_id=DEF_MODEL_ID, def_src_lang=DEF_SRC_LNG,
-    def_tgt_lang=DEF_TGT_LNG, **kwargs):
+        model_id=DEF_MODEL_ID, def_src_lang=DEF_SRC_LNG,
+        def_tgt_lang=DEF_TGT_LNG, **kwargs):
    sys_info['model_id'] = model_id
    torch.set_grad_enabled(False)

@@ -93,7 +106,7 @@ def attach_translate_route(
    @lru_cache(maxsize=256)
    def get_tokenizer(src_lang=def_src_lang):
        log.info(f"Loading tokenizer for {model_id}; src_lang={src_lang} ...")
-        #tokenizer = AutoTokenizer.from_pretrained(model_id)
+        # tokenizer = AutoTokenizer.from_pretrained(model_id)
        return AutoTokenizer.from_pretrained(model_id, src_lang=src_lang)

    @bp.route('/')
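Review note: get_tokenizer() is wrapped in lru_cache, so each src_lang pays the load cost once per process and later requests reuse the cached object. A small illustration of that caching behavior:

    tok_a = get_tokenizer('eng_Latn')  # first call: loads from the hub
    tok_b = get_tokenizer('eng_Latn')  # second call: cache hit
    assert tok_a is tok_b              # same object, no reload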
@@ -102,7 +115,6 @@ def index():
                    def_src_lang=def_src_lang, def_tgt_lang=def_tgt_lang)
        return render_template('index.html', **args)

    @bp.route("/translate", methods=["POST", "GET"])
    def translate():
        st = time.time()
@@ -116,7 +128,7 @@ def translate():
        else:
            args = request.form

-        if hasattr(args, 'getlist') :
+        if hasattr(args, 'getlist'):
            sources = args.getlist("source")
        else:
            sources = args.get("source")
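Review note: since the handler prefers getlist('source'), one request can carry several source fields and get a batch back. A hypothetical client call, assuming the server listens on localhost:6060:

    import requests

    resp = requests.post('http://localhost:6060/translate', data={
        'source': ['Hello world', 'How are you?'],  # repeated form field
        'src_lang': 'eng_Latn',
        'tgt_lang': 'kan_Knda',
    })
    print(resp.json()['translation'])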
@@ -126,37 +138,38 @@ def translate():
        src_lang = args.get('src_lang') or def_src_lang
        tgt_lang = args.get('tgt_lang') or def_tgt_lang
        sen_split = args.get('sen_split')

        tokenizer = get_tokenizer(src_lang=src_lang)

        if not sources:
            return "Please submit 'source' parameter", 400

        if sen_split:
            if not ssplit_lang(src_lang):
                return "Sentence splitter for this language is not available", 400
            sources, index = sentence_splitter(src_lang, sources)

        max_length = 80
        inputs = tokenizer(sources, return_tensors="pt", padding=True)
-        inputs = {k:v.to(device) for k, v in inputs.items()}
+        inputs = {k: v.to(device) for k, v in inputs.items()}

        translated_tokens = model.generate(
-            **inputs, forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
-            max_length = max_length)
-        output = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
+            **inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
+            max_length=max_length)
+        output = tokenizer.batch_decode(
+            translated_tokens, skip_special_tokens=True)

        if sen_split:
            results = []
            for i in range(1, len(index)):
                batch = output[index[i-1]:index[i]]
                results.append(" ".join(batch))
        else:
            results = output

        res = dict(source=sources, translation=results,
-                   src_lang = src_lang, tgt_lang=tgt_lang,
-                   time_taken = round(time.time() - st, 3), time_units='s')
+                   src_lang=src_lang, tgt_lang=tgt_lang,
+                   time_taken=round(time.time() - st, 3), time_units='s')

        return flask.jsonify(jsonify(res))
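Review note: the switch from tokenizer.lang_code_to_id[tgt_lang] to tokenizer.convert_tokens_to_ids(tgt_lang) is the key fix in this hunk; recent transformers releases dropped lang_code_to_id from the NLLB tokenizer, while language codes remain ordinary vocabulary tokens. If both old and new transformers must be supported, a version-agnostic helper could look like this (sketch, not part of this PR):

    def lang_token_id(tokenizer, lang_code: str) -> int:
        # Older transformers exposed lang_code_to_id; newer ones expect the
        # language code to be looked up like any other vocabulary token.
        mapping = getattr(tokenizer, 'lang_code_to_id', None)
        if mapping is not None:
            return mapping[lang_code]
        return tokenizer.convert_tokens_to_ids(lang_code)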
@@ -169,12 +182,17 @@ def parse_args():
    parser = ArgumentParser(
        prog="nllb-serve",
        description="Deploy NLLB model to a RESTful server",
        epilog=f'Loaded from {__file__}. Source code: https://github.com/thammegowda/nllb-serve',
        formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument("-d", "--debug", action="store_true", help="Run Flask server in debug mode")
-    parser.add_argument("-p", "--port", type=int, help="port to run server on", default=6060)
-    parser.add_argument("-ho", "--host", help="Host address to bind.", default='0.0.0.0')
-    parser.add_argument("-b", "--base", help="Base prefix path for all the URLs. E.g., /v1")
+    parser.add_argument("-d", "--debug", action="store_true",
+                        help="Run Flask server in debug mode")
+    parser.add_argument("-p", "--port", type=int,
+                        help="port to run server on", default=6060)
+    parser.add_argument(
+        "-ho", "--host", help="Host address to bind.", default='0.0.0.0')
+    parser.add_argument(
+        "-b", "--base", help="Base prefix path for all the URLs. E.g., /v1")
    parser.add_argument("-mi", "--model_id", type=str, default=DEF_MODEL_ID,
                        help="model ID; see https://huggingface.co/models?other=nllb")
    parser.add_argument("-msl", "--max-src-len", type=int, default=250,
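Review note: --base is the URL prefix under which all routes are mounted. How main() applies it is collapsed above, but a typical Flask pattern would be (hypothetical sketch):

    # Mounting the blueprint under a prefix such as /v1 would expose
    # /v1/translate instead of /translate.
    app.register_blueprint(bp, url_prefix='/v1')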
@@ -207,4 +225,4 @@ def main():


if __name__ == "__main__":
    main()