WandbLogger logging step incorrectly if project name not passed. #20383

Closed
golmschenk opened this issue Nov 1, 2024 · 1 comment
Labels
bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.3.x

golmschenk commented Nov 1, 2024

Bug description

If the project name is not passed (left as None) when creating a WandbLogger, the logger logs the global step differently.

What version are you seeing the problem on?

v2.3

How to reproduce the bug

In the minimal example below, project is not passed to WandbLogger and the incorrect step is logged:

import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)
        self.loss_metric = CrossEntropyLoss()
        self.train_loss_total = torch.tensor(0, dtype=torch.float32)
        self.validation_loss_total = torch.tensor(0, dtype=torch.float32)
        self.train_steps = torch.tensor(0, dtype=torch.int64)
        self.validation_steps = torch.tensor(0, dtype=torch.int64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.train_loss_total += loss
        self.train_steps += 1
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.validation_loss_total += loss
        self.validation_steps += 1
        return loss

    def on_train_epoch_end(self):
        self.log('loss', self.train_loss_total / self.train_steps, on_step=False, on_epoch=True)
        self.train_loss_total.zero_()
        self.train_steps.zero_()

    def on_validation_epoch_end(self):
        self.log('val_loss', self.validation_loss_total / self.validation_steps, on_step=False, on_epoch=True)
        self.validation_loss_total.zero_()
        self.validation_steps.zero_()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
)
val_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
)
logger = WandbLogger()
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=20,
    limit_val_batches=10,
    log_every_n_steps=0,
    logger=logger,
    accelerator='cpu',
)
model = LitModel()

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

This results in the step being incremented only at the end of each epoch, once for training and once for validation, so the two have different step numbers (each of which is less than the global training step number).

screenshot_2024_11_01_13_45_30

Changing only WandbLogger() to WandbLogger(project='a') changes how the step value is logged. With this change, the step value is logged as the global training step value and is consistent between the train and validation logging.

screenshot_2024_11_01_13_45_50

Environment

Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
  • Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alabaster: 0.7.16
    - anyio: 4.4.0
    - appnope: 0.1.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - astropy: 6.1.0
    - astropy-iers-data: 0.2024.6.3.0.31.14
    - astroquery: 0.4.7
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - atpublic: 4.1.0
    - attrs: 23.2.0
    - autograd: 1.6.2
    - babel: 2.14.0
    - backcall: 0.2.0
    - backports-strenum: 1.2.8
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - bokeh: 3.5.1
    - brotli: 1.1.0
    - cached-property: 1.5.2
    - certifi: 2024.6.2
    - cffi: 1.17.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.5
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - docutils: 0.21.2
    - entrypoints: 0.4
    - exceptiongroup: 1.2.2
    - executing: 2.0.1
    - fastjsonschema: 2.19.1
    - fbpca: 1.0
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2024.6.0
    - furo: 2024.5.6
    - future: 1.0.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - h11: 0.14.0
    - h2: 4.1.0
    - hatch: 1.12.0
    - hatchling: 1.24.2
    - hpack: 4.0.0
    - html5lib: 1.1
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - humanize: 4.11.0
    - hyperframe: 6.0.1
    - hyperlink: 21.0.0
    - idna: 3.7
    - imagesize: 1.4.1
    - importlib-metadata: 7.1.0
    - importlib-resources: 6.4.0
    - iniconfig: 2.0.0
    - ipykernel: 6.29.5
    - ipython: 8.12.3
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - jaraco.classes: 3.4.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonpointer: 3.0.0
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.2
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.4
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.3
    - jupyterlab-widgets: 3.0.11
    - keyring: 25.2.1
    - kiwisolver: 1.4.5
    - lightkurve: 2.4.2
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - lxml: 5.2.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.1
    - mdurl: 0.1.2
    - memoization: 0.4.0
    - mistune: 3.0.2
    - more-itertools: 10.2.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - myst-parser: 3.0.1
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - notebook: 7.2.1
    - notebook-shim: 0.2.4
    - numpy: 1.26.4
    - oktopus: 0.1.2
    - overrides: 7.7.0
    - packaging: 24.0
    - pandas: 2.2.2
    - pandocfilters: 1.5.0
    - parso: 0.8.4
    - pathspec: 0.12.1
    - patsy: 0.5.6
    - peewee: 3.17.5
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - pipreqs: 0.5.0
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 0.20.31
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.46
    - protobuf: 5.27.1
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 16.1.0
    - pycparser: 2.22
    - pyerfa: 2.0.1.4
    - pygments: 2.18.0
    - pyobjc-core: 10.3.1
    - pyobjc-framework-cocoa: 10.3.1
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - pytest: 7.4.4
    - pytest-asyncio: 0.23.7
    - pytest-pycharm: 0.7.0
    - python-dateutil: 2.9.0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyvo: 1.5.2
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - qusi: 1.0.3
    - qusi-evaluation: 0.0.1
    - referencing: 0.35.1
    - requests: 2.32.3
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.18.1
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - send2trash: 1.8.3
    - sentry-sdk: 2.5.1
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - shellingham: 1.5.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snowballstemmer: 2.2.0
    - soupsieve: 2.5
    - sphinx: 7.3.7
    - sphinx-basic-ng: 1.0.0b2
    - sphinxcontrib-applehelp: 1.0.8
    - sphinxcontrib-devhelp: 1.0.6
    - sphinxcontrib-htmlhelp: 2.0.5
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.7
    - sphinxcontrib-serializinghtml: 1.1.10
    - stack-data: 0.6.2
    - stringcase: 1.2.0
    - sympy: 1.12.1
    - tenacity: 8.3.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tinycss2: 1.3.0
    - tomli: 2.0.1
    - tomli-w: 1.0.0
    - tomlkit: 0.12.5
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - trove-classifiers: 2024.5.22
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uncertainties: 3.2.1
    - uri-template: 1.3.0
    - urllib3: 2.2.1
    - userpath: 1.9.2
    - uv: 0.2.11
    - uvloop: 0.19.0
    - virtualenv: 20.26.2
    - wandb: 0.17.1
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - wget: 3.2
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - xyzservices: 2024.6.0
    - yarg: 0.1.9
    - yarl: 1.9.4
    - zenodo-get: 1.6.1
    - zipp: 3.19.2
    - zstandard: 0.22.0
  • System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.11.9
    - release: 24.0.0
    - version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000
golmschenk added the bug and needs triage labels Nov 1, 2024

golmschenk (Author) commented

It appears that an earlier run using Step had overridden trainer/global_step as the default x-axis value. After I cleared the default Wandb project (lightning_logs), running this code again produced the correct results.
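
To make the difference between the two x-axes concrete, here is a small pure-Python toy model. It is only a sketch and is not wandb's or Lightning's actual code. It assumes that the built-in Step counter advances once per log call, that trainer/global_step advances once per training batch, and that validation metrics are logged before training metrics at the end of each epoch. The function name simulate_epoch_end_logging and the record layout are made up for illustration, and the exact values Lightning reports may differ by one.

```python
# Toy model of the two candidate x-axes; this is NOT wandb's or Lightning's code.
EPOCHS = 10
TRAIN_BATCHES_PER_EPOCH = 20  # mirrors limit_train_batches in the example above


def simulate_epoch_end_logging(epochs, batches_per_epoch):
    """Return one record per simulated epoch-end log call (hypothetical helper)."""
    records = []
    log_calls = 0    # stands in for the built-in "Step" counter
    global_step = 0  # stands in for trainer/global_step
    for _ in range(epochs):
        # Assumption: one optimizer step per training batch.
        global_step += batches_per_epoch
        # Assumption: validation metrics are logged before training metrics.
        for name in ("val_loss", "loss"):
            records.append(
                {"metric": name, "Step": log_calls, "trainer/global_step": global_step}
            )
            log_calls += 1  # the built-in counter advances once per log call
    return records


records = simulate_epoch_end_logging(EPOCHS, TRAIN_BATCHES_PER_EPOCH)
for record in records[:4]:
    print(record)
```

Under these assumptions, val_loss and loss from the same epoch get different Step values but share the same trainer/global_step, which is consistent with the mismatch in the first screenshot when Step is the x-axis.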
