WandbLogger logging step incorrectly if project name not passed. #20383

golmschenk opened this issue Nov 1, 2024 · 1 comment

bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.3.x


golmschenk commented Nov 1, 2024

Bug description

When a WandbLogger is created without a project name (project left as None), the logger logs the global step differently from when a project name is passed.

What version are you seeing the problem on?

v2.3

How to reproduce the bug

In the minimal example below, project is not passed to WandbLogger, and the incorrect step is logged:

import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()  # required before assigning submodules such as self.l1
        self.l1 = nn.Linear(28 * 28, 10)
        self.loss_metric = CrossEntropyLoss()
        self.train_loss_total = torch.tensor(0, dtype=torch.float32)
        self.validation_loss_total = torch.tensor(0, dtype=torch.float32)
        self.train_steps = torch.tensor(0, dtype=torch.int64)
        self.validation_steps = torch.tensor(0, dtype=torch.int64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.train_loss_total += loss.detach()  # accumulate the value only, not the autograd graph
        self.train_steps += 1
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.validation_loss_total += loss
        self.validation_steps += 1
        return loss

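    def on_train_epoch_end(self):
        self.log('loss', self.train_loss_total / self.train_steps, on_step=False, on_epoch=True)
        # Reset the accumulators so each logged value is a per-epoch mean.
        self.train_loss_total.zero_()
        self.train_steps.zero_()

    def on_validation_epoch_end(self):
        self.log('val_loss', self.validation_loss_total / self.validation_steps, on_step=False, on_epoch=True)
        # Reset the accumulators so each logged value is a per-epoch mean.
        self.validation_loss_total.zero_()
        self.validation_steps.zero_()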

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

train_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor()))
val_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor()))
logger = WandbLogger()  # project left as None: the case that shows the problem
trainer = pl.Trainer(logger=logger)
model = LitModel()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

This results in the step being incremented only at the end of each epoch, once for training and once for validation, so the two metrics have different step numbers, both smaller than the global training step.


Changing only WandbLogger() to WandbLogger(project='a') changes how the step value is logged: the step is then logged as the global training step and is consistent between the training and validation metrics.
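The two behaviors described above can be illustrated with a small toy model. This is a simplified sketch under stated assumptions, not code from W&B or Lightning: it assumes each epoch produces exactly one log call for validation metrics followed by one for training metrics, that W&B's own step counter advances by one per log call, and that Lightning attaches the number of optimizer steps taken so far as trainer/global_step. The names LoggedRow and simulate_epoch_end_logging are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class LoggedRow:
    metric: str
    wandb_step: int   # W&B's own counter: +1 per log call (simplified assumption)
    global_step: int  # value attached as "trainer/global_step"


def simulate_epoch_end_logging(num_epochs: int, batches_per_epoch: int) -> list[LoggedRow]:
    """Toy model of epoch-level logging (on_step=False, on_epoch=True).

    Hypothetical assumptions: validation metrics are logged before training
    metrics in each epoch, and no other log calls happen.
    """
    rows = []
    wandb_step = 0
    for epoch in range(num_epochs):
        # Optimizer steps taken by the end of this epoch.
        global_step = (epoch + 1) * batches_per_epoch
        for metric in ("val_loss", "loss"):
            rows.append(LoggedRow(metric, wandb_step, global_step))
            wandb_step += 1  # each log call advances W&B's own counter
    return rows


for row in simulate_epoch_end_logging(num_epochs=2, batches_per_epoch=100):
    print(row.metric, row.wandb_step, row.global_step)
```

With 2 epochs of 100 batches, the model gives W&B steps 0, 1, 2, 3 for val_loss, loss, val_loss, loss, while trainer/global_step is 100, 100, 200, 200. Plotting against the first counter reproduces the mismatch in the report; plotting against the second gives the aligned values.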



Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
  • Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alabaster: 0.7.16
    - anyio: 4.4.0
    - appnope: 0.1.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - astropy: 6.1.0
    - astropy-iers-data: 0.2024.
    - astroquery: 0.4.7
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - atpublic: 4.1.0
    - attrs: 23.2.0
    - autograd: 1.6.2
    - babel: 2.14.0
    - backcall: 0.2.0
    - backports-strenum: 1.2.8
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - bokeh: 3.5.1
    - brotli: 1.1.0
    - cached-property: 1.5.2
    - certifi: 2024.6.2
    - cffi: 1.17.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.5
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - docutils: 0.21.2
    - entrypoints: 0.4
    - exceptiongroup: 1.2.2
    - executing: 2.0.1
    - fastjsonschema: 2.19.1
    - fbpca: 1.0
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2024.6.0
    - furo: 2024.5.6
    - future: 1.0.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - h11: 0.14.0
    - h2: 4.1.0
    - hatch: 1.12.0
    - hatchling: 1.24.2
    - hpack: 4.0.0
    - html5lib: 1.1
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - humanize: 4.11.0
    - hyperframe: 6.0.1
    - hyperlink: 21.0.0
    - idna: 3.7
    - imagesize: 1.4.1
    - importlib-metadata: 7.1.0
    - importlib-resources: 6.4.0
    - iniconfig: 2.0.0
    - ipykernel: 6.29.5
    - ipython: 8.12.3
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - jaraco.classes: 3.4.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonpointer: 3.0.0
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.2
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.4
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.3
    - jupyterlab-widgets: 3.0.11
    - keyring: 25.2.1
    - kiwisolver: 1.4.5
    - lightkurve: 2.4.2
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - lxml: 5.2.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.1
    - mdurl: 0.1.2
    - memoization: 0.4.0
    - mistune: 3.0.2
    - more-itertools: 10.2.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - myst-parser: 3.0.1
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - notebook: 7.2.1
    - notebook-shim: 0.2.4
    - numpy: 1.26.4
    - oktopus: 0.1.2
    - overrides: 7.7.0
    - packaging: 24.0
    - pandas: 2.2.2
    - pandocfilters: 1.5.0
    - parso: 0.8.4
    - pathspec: 0.12.1
    - patsy: 0.5.6
    - peewee: 3.17.5
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - pipreqs: 0.5.0
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 0.20.31
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.46
    - protobuf: 5.27.1
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 16.1.0
    - pycparser: 2.22
    - pyerfa:
    - pygments: 2.18.0
    - pyobjc-core: 10.3.1
    - pyobjc-framework-cocoa: 10.3.1
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - pytest: 7.4.4
    - pytest-asyncio: 0.23.7
    - pytest-pycharm: 0.7.0
    - python-dateutil: 2.9.0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyvo: 1.5.2
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - qusi: 1.0.3
    - qusi-evaluation: 0.0.1
    - referencing: 0.35.1
    - requests: 2.32.3
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.18.1
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - send2trash: 1.8.3
    - sentry-sdk: 2.5.1
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - shellingham: 1.5.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snowballstemmer: 2.2.0
    - soupsieve: 2.5
    - sphinx: 7.3.7
    - sphinx-basic-ng: 1.0.0b2
    - sphinxcontrib-applehelp: 1.0.8
    - sphinxcontrib-devhelp: 1.0.6
    - sphinxcontrib-htmlhelp: 2.0.5
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.7
    - sphinxcontrib-serializinghtml: 1.1.10
    - stack-data: 0.6.2
    - stringcase: 1.2.0
    - sympy: 1.12.1
    - tenacity: 8.3.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tinycss2: 1.3.0
    - tomli: 2.0.1
    - tomli-w: 1.0.0
    - tomlkit: 0.12.5
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - trove-classifiers: 2024.5.22
    - types-python-dateutil:
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uncertainties: 3.2.1
    - uri-template: 1.3.0
    - urllib3: 2.2.1
    - userpath: 1.9.2
    - uv: 0.2.11
    - uvloop: 0.19.0
    - virtualenv: 20.26.2
    - wandb: 0.17.1
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - wget: 3.2
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - xyzservices: 2024.6.0
    - yarg: 0.1.9
    - yarl: 1.9.4
    - zenodo-get: 1.6.1
    - zipp: 3.19.2
    - zstandard: 0.22.0
  • System:
    - OS: Darwin
    - architecture:
      - 64bit
    - processor: arm
    - python: 3.11.9
    - release: 24.0.0
    - version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000
golmschenk added the bug and needs triage labels Nov 1, 2024
It appears that an earlier run that used Step as its x-axis had overridden trainer/global_step as the default x-axis value. After clearing the default Wandb project (lightning_logs) and running this code again, the results were correct.
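The comment above implies that WandbLogger() and WandbLogger(project='a') differ only in which W&B project, and therefore which workspace x-axis setting, the run lands in. As a minimal sketch, assuming the logger falls back first to the WANDB_PROJECT environment variable and then to lightning_logs (the default project named in the comment), the project resolution looks like this. resolve_wandb_project is a hypothetical helper, not part of Lightning's API.

```python
import os
from typing import Mapping, Optional


def resolve_wandb_project(project: Optional[str], env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical sketch of the project fallback when WandbLogger's project is None.

    Assumed order: explicit project -> $WANDB_PROJECT -> "lightning_logs".
    """
    env = os.environ if env is None else env
    return project or env.get("WANDB_PROJECT", "lightning_logs")


print(resolve_wandb_project(None, {}))   # no project, no env var
print(resolve_wandb_project("a", {}))    # explicit project wins
```

Under this reading, passing an explicit project that has no earlier runs, or clearing lightning_logs as the commenter did, avoids inheriting an x-axis setting from a previous run.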

