diff --git a/tzrec/main.py b/tzrec/main.py index 95b8af4..2cb86b3 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -309,10 +309,16 @@ def _log_train( loss_strs.append(f"{k}:{v:.5f}") for i, g in enumerate(param_groups): lr_strs.append(f"lr_g{i}:{g['lr']:.5f}") - plogger.log(step, f"{' '.join(lr_strs)} {' '.join(loss_strs)}") + total_loss = sum(losses.values()) + plogger.log( + step, + f"{' '.join(lr_strs)} {' '.join(loss_strs)} total_loss: {total_loss:.5f}", + ) if summary_writer is not None: + total_loss = sum(losses.values()) for k, v in losses.items(): summary_writer.add_scalar(f"loss/{k}", v, step) + summary_writer.add_scalar("loss/total", total_loss, step) for i, g in enumerate(param_groups): summary_writer.add_scalar(f"lr/g{i}", g["lr"], step)