Skip to content

Commit

Permalink
[feat] add total_loss to the plogger and summary_writer (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng authored Dec 30, 2024
1 parent 011b02b commit fead320
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,16 @@ def _log_train(
loss_strs.append(f"{k}:{v:.5f}")
for i, g in enumerate(param_groups):
lr_strs.append(f"lr_g{i}:{g['lr']:.5f}")
plogger.log(step, f"{' '.join(lr_strs)} {' '.join(loss_strs)}")
total_loss = sum(losses.values())
plogger.log(
step,
f"{' '.join(lr_strs)} {' '.join(loss_strs)} total_loss: {total_loss:.5f}",
)
if summary_writer is not None:
total_loss = sum(losses.values())
for k, v in losses.items():
summary_writer.add_scalar(f"loss/{k}", v, step)
summary_writer.add_scalar("loss/total", total_loss, step)
for i, g in enumerate(param_groups):
summary_writer.add_scalar(f"lr/g{i}", g["lr"], step)

Expand Down

0 comments on commit fead320

Please sign in to comment.