Skip to content

Commit

Permalink
update amp: simplify GradScaler device selection on torch < 2.4 (drops the `isinstance(device, str)` guard) in fabric and pytorch precision plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Dec 20, 2024
1 parent 5cdd9e7 commit 06a3303
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
if _TORCH_GREATER_EQUAL_2_4
else getattr(
torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler()
)
if scaler is not None and self.precision == "bf16-mixed":
Expand Down

0 comments on commit 06a3303

Please sign in to comment.