H100 adaptability
FortuneBush committed Nov 28, 2024
1 parent 26d2e1e commit 10bad92
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlora/model/llm/model_llama.py
@@ -217,6 +217,10 @@ def create_device_map() -> str | Dict[str, str]:
     "torch_dtype": torch.float32,
 }

+# If it is an H100 device, set precision to nf4
+if torch.cuda.is_available() and "H100" in torch.cuda.get_device_name(torch.cuda.current_device()):
+    precision = "nf4"
+
 logging.info(f"Loading model with precision - {precision}")

 if precision in load_type_dict:
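The check added in this commit mixes a hardware query with the precision decision, which makes it hard to exercise without an H100. A minimal sketch of the same rule, split so the decision can be tested with plain strings, might look like the following. The helper names `resolve_precision` and `current_cuda_device_name` are hypothetical and are not part of mLoRA; the only PyTorch calls used are the ones that already appear in the diff.

```python
from typing import Optional


def resolve_precision(requested: str, device_name: Optional[str]) -> str:
    """Hypothetical helper: mirror the commit's rule as a pure function.

    Returns "nf4" whenever the device name contains "H100", replacing
    whatever precision the caller requested; otherwise returns the
    requested precision unchanged.
    """
    if device_name is not None and "H100" in device_name:
        return "nf4"
    return requested


def current_cuda_device_name() -> Optional[str]:
    """Hypothetical helper: name of the active CUDA device, or None.

    None is returned when PyTorch is not installed or CUDA is unavailable.
    """
    try:
        import torch  # imported lazily so resolve_precision works without it
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_name(torch.cuda.current_device())


if __name__ == "__main__":
    # "fp32" is only an example of a requested precision.
    print(resolve_precision("fp32", current_cuda_device_name()))
```

As in the diff, the override is unconditional on H100: a caller who asked for a different precision still gets `nf4`. Keeping the rule in one small function makes that behaviour easy to see and to change.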
