Add AlchemBERT to Matbench Discovery leaderboard #187

Open · wants to merge 2 commits into main
Binary file not shown (the gzipped predictions file referenced in alchembert.yml below).
60 changes: 60 additions & 0 deletions models/alchembert/alchembert.yml
@@ -0,0 +1,60 @@
model_name: AlchemBERT
model_key: alchembert
model_version: 1.0.0
matbench_discovery_version: 1.0.0
date_added: "2024-12-25"
date_published: "2024-12-11"
authors:
  - name: Xiaotong Liu
    affiliation: Beijing Information Science and Technology University
    email: [email protected]
  - name: Yuhang Wang
    affiliation: Beijing Information Science and Technology University
    email: [email protected]
repo: https://gitee.com/liuxiaotong15/alchemBERT
doi: https://doi.org/10.26434/chemrxiv-2024-r4dnl
paper: https://chemrxiv.org/engage/chemrxiv/article-details/67540a28085116a133a62b85
pr_url: https://github.com/janosh/matbench-discovery/pull/186
targets: E
train_task: RS2RE
test_task: IS2RE
trained_for_benchmark: true

requirements:
  torch: 2.5.1
  lightning: 2.4.0
  transformers: 4.46.3
  pymatgen: 2024.11.13

model_type: Transformer
model_params: 110_000_000
n_estimators: 1

training_set: [MP 2022]

notes:
  description: We tested BERT on the IS2RE task of Matbench Discovery.

metrics:
  discovery:
    pred_file: /models/alchembert/2024-12-25-alchembert-wbm-IS2RE.csv.gz
    pred_col: e_form_per_atom_alchembert
    full_test_set:
      MAE: 0.1190
      RMSE: 0.1767
      R2: 0.0416
      DAF: 1.7768
      Precision: 0.2961
      Recall: 0.6679
      Accuracy: 0.6800
      F1: 0.4103
      TPR: 0.6679
      FPR: 0.3174
      TNR: 0.6825
      FNR: 0.3320
      missing_percent: 0.00%
      missing_preds: 0
      TP: 28604
      FP: 67988
      TN: 146150
      FN: 14221
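
As a sanity check, the classification metrics in the table above can be recomputed from the raw confusion-matrix counts. A short illustrative sketch (assuming DAF, the discovery acceleration factor, is precision divided by the prevalence of stable structures in the test set, which reproduces the tabulated 1.7768):

# Recompute classification metrics from the confusion-matrix counts above.
TP, FP, TN, FN = 28604, 67988, 146150, 14221
total = TP + FP + TN + FN  # 256,963 WBM test structures
precision = TP / (TP + FP)  # 0.2961
recall = TP / (TP + FN)  # 0.6679, equal to TPR
accuracy = (TP + TN) / total  # 0.6800
f1 = 2 * precision * recall / (precision + recall)  # 0.4103
prevalence = (TP + FN) / total  # fraction of truly stable structures
daf = precision / prevalence  # 1.7768 (assumed definition of DAF)
print(f"{precision=:.4f} {recall=:.4f} {f1=:.4f} {daf=:.4f}")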
15 changes: 15 additions & 0 deletions models/alchembert/readme.md
@@ -0,0 +1,15 @@
We tested the IS2RE task of Matbench Discovery using BERT. The training-to-validation ratio was 9:1. With a learning rate of 10^-5, training stopped at approximately 500 epochs (early stopping on validation MAE).

val_MAE = 0.0674
test_MAE = 0.1190
TP, TN, FP, FN: 28604, 146150, 67988, 14221
F1 = 0.4103

If you would like to train the model yourself, the code is available at:
https://gitee.com/liuxiaotong15/alchemBERT
https://github.com/WangYuHang-WYH/AlchemBERT
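
For context, the train_nl_pad_cased_inputs.json / test_nl_pad_cased_inputs.json files read by the scripts in this PR hold BERT-tokenized, padded encodings of text descriptions of each structure. Below is a minimal sketch of how such a file could be produced; the example description strings are placeholders (the actual natural-language rendering of structures lives in the repos above):

import pandas as pd
from transformers import BertTokenizer

# Placeholder descriptions; the real ones are generated from pymatgen structures.
texts = [
    "The crystal has formula Na Cl. Na is bonded to six Cl atoms.",
    "The crystal has formula Mg O. Mg is bonded to six O atoms.",
]
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
enc = tokenizer(texts, padding="max_length", truncation=True, max_length=512)
df = pd.DataFrame(
    {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
)
df.to_json("train_nl_pad_cased_inputs.json")  # format expected by train_AlchemBERT.py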

The preprint is available at:
https://chemrxiv.org/engage/chemrxiv/article-details/67540a28085116a133a62b85

If you have any feedback or suggestions on our work, please feel free to contact us. Thank you.
61 changes: 61 additions & 0 deletions models/alchembert/test_AlchemBERT.py
@@ -0,0 +1,61 @@
import lightning as l
import pandas as pd
import torch
from fire import Fire
from torch.utils.data import DataLoader
from train_AlchemBERT import MatBert, MyDataset, task

torch.manual_seed(42)


def get_test_data() -> pd.Series:
    """WBM formation energies (MP2020-corrected) used as test targets."""
    target_col = "e_form_per_atom_mp2020_corrected"
    df_wbm = pd.read_csv("2022-10-19-wbm-summary.csv")
    return pd.Series(df_wbm[target_col])


bert_path = "bert-base-cased"
predictions_path = "2024-12-25-alchembert-wbm-IS2RE.csv.gz"
test_pad_cased_path = f"test_{task}_pad_cased_inputs.json"


# %%
def main(best_epoch: int, val_mae: float) -> None:
    ckpt_name = f"epoch={best_epoch}_val_MAE={val_mae}_best_model.ckpt"
    best_model_path = f"checkpoints/model_epoch5000_{task}/{ckpt_name}"
    test_inputs = pd.read_json(test_pad_cased_path)
    test_outputs = get_test_data()

    checkpoint = torch.load(best_model_path, weights_only=True)
    model = MatBert(bert_path)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # %% test
    trainer = l.Trainer(accelerator="gpu", devices=[0])

    # JSON columns hold per-row lists of token ids, so convert via .tolist()
    test_input_ids = torch.tensor(test_inputs["input_ids"].tolist())
    test_attention_mask = torch.tensor(test_inputs["attention_mask"].tolist())
    test_targets = torch.tensor(test_outputs.values)
    test_dataset = MyDataset(test_input_ids, test_attention_mask, test_targets)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    predictions = trainer.predict(model, test_loader)

    preds = [tensor.cpu().item() for tensor in predictions]
    df_preds = pd.DataFrame({"e_form_per_atom_alchembert": preds})
    print(df_preds)
    df_preds.to_csv(predictions_path, index=False, compression="gzip")
    print(predictions_path)


# %%
if __name__ == "__main__":
    Fire(main)
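
Hypothetical invocation via Fire: the two arguments must match the epoch and val_MAE that ModelCheckpoint baked into the saved checkpoint filename, e.g. for a run matching the readme (499 is illustrative; use the epoch of your own best checkpoint):

python test_AlchemBERT.py --best_epoch=499 --val_mae=0.0674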
161 changes: 161 additions & 0 deletions models/alchembert/train_AlchemBERT.py
@@ -0,0 +1,161 @@
import os
import sys
import warnings

import lightning as l
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as f
from fire import Fire
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import BertConfig, BertModel

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"

max_length = 512
train_batch_size = 32
val_batch_size = 32
epoch = 5000
patience = 200
log_every_n_steps = 50
save_top_k = 1
l_r = 1e-5

task = "nl"
bert_path = "bert-base-cased"

train_pad_cased_path = "train_nl_pad_cased_inputs.json"
test_pad_cased_path = "test_nl_pad_cased_inputs.json"


def get_train_data() -> pd.Series:
    """MP formation energies used as training targets.

    Defined here rather than in test_AlchemBERT.py so that the two
    scripts don't import from each other circularly.
    """
    target_col = "formation_energy_per_atom"
    id_col = "material_id"
    df_eng = pd.read_csv("2023-01-10-mp-energies.csv").set_index(id_col)
    return pd.Series(df_eng[target_col], index=df_eng.index)


# %%
class MyDataset(Dataset):
    def __init__(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(
        self, index: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.input_ids[index], self.attention_mask[index], self.labels[index]


# %%
class MatBert(l.LightningModule):
    def __init__(self, b_path: str) -> None:
        super().__init__()
        self.bert = BertModel.from_pretrained(b_path, output_hidden_states=True)
        self.config = BertConfig.from_pretrained(b_path)
        self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # regression head on the [CLS] token representation
        cls_representation = outputs.last_hidden_state[:, 0, :]
        return self.linear(cls_representation).squeeze(-1)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        # Lightning already moves the batch to the active device
        input_ids, attention_mask, y = batch
        y_hat = self(input_ids, attention_mask)
        loss = f.mse_loss(y_hat.float(), y.float())
        self.log("train_mse_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        input_ids, attention_mask, y = batch
        y_hat = self(input_ids, attention_mask)
        loss = f.mse_loss(y_hat.float(), y.float())
        mae = torch.mean(torch.absolute(y_hat - y))
        self.log("val_MAE", mae, on_epoch=True, sync_dist=True)
        return {"val_loss": loss, "val_MAE": mae}

    def predict_step(
        self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        input_ids, attention_mask, _y = batch
        return self(input_ids, attention_mask)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=l_r)


# %% data
def main() -> None:
    if os.path.exists(train_pad_cased_path):
        print(f"file {train_pad_cased_path} exists")
        train_inputs = pd.read_json(train_pad_cased_path)
        train_outputs = get_train_data()

        # JSON columns hold per-row lists of token ids, so convert via .tolist()
        input_ids = torch.tensor(train_inputs["input_ids"].tolist())
        attention_mask = torch.tensor(train_inputs["attention_mask"].tolist())
        train_targets = torch.tensor(train_outputs.values)
    else:
        warnings.warn(
            f"file {train_pad_cased_path} doesn't exist", UserWarning, stacklevel=2
        )
        sys.exit(1)

    dataset = MyDataset(input_ids, attention_mask, train_targets)
    train_set, val_set = random_split(dataset, [0.9, 0.1])

    train_loader = DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True, num_workers=2
    )
    val_loader = DataLoader(
        val_set, batch_size=val_batch_size, shuffle=False, num_workers=2
    )

    # %% train
    model = MatBert(bert_path)
    early_stopping = EarlyStopping(
        monitor="val_MAE", patience=patience, verbose=True, mode="min"
    )
    check_point = ModelCheckpoint(
        monitor="val_MAE",
        save_top_k=save_top_k,
        dirpath=f"checkpoints/model_epoch{epoch}_{task}",
        filename="{epoch}_{val_MAE:.4f}_best_model",
        mode="min",
    )
    trainer = l.Trainer(
        max_epochs=epoch,
        accelerator="gpu",
        callbacks=[check_point, early_stopping],
        log_every_n_steps=log_every_n_steps,
        devices=-1,
        strategy="ddp_find_unused_parameters_true",
    )
    # Trainer handles device placement and train/eval mode itself
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)


# %%
if __name__ == "__main__":
    Fire(main)
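
Training takes no CLI arguments; the script expects train_nl_pad_cased_inputs.json and 2023-01-10-mp-energies.csv in the working directory and trains on all visible GPUs via DDP:

python train_AlchemBERT.py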