Merge pull request #1 from realityengines/mnist
Debias Vision
Yash Savani authored Jul 14, 2020
2 parents 8a84428 + cbccb3a commit e0efc81
Showing 11 changed files with 1,276 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@


data/*


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
61 changes: 59 additions & 2 deletions README.md
@@ -25,7 +25,64 @@ Install the requirements using
$ pip install -r requirements.txt
```

## Run a Post-Hoc debiasing experiment
## How to debias your own neural network
### What you need:
* Model architecture (currently supported: PyTorch `Module`)
* Model checkpoint (currently supported file type: `.pt`)
* Validation dataset to debias against
  * Very simple if your model uses one of these 10 datasets: https://pytorch.org/docs/stable/torchvision/datasets.html
  * For a custom dataset, write a quick dataset class and `DataLoader` using this tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
* Protected attribute (e.g., “race”)
* Prediction attribute (e.g., “smiling”)
* (optional) Tunable bias parameter lambda
  * lambda=0.9 means you want the model to be almost fully debiased
  * lambda=0.1 means you care a lot about accuracy and minimally about debiasing
  * We recommend a lambda between 0.5 and 0.75
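
As a rough illustration of how lambda trades bias against accuracy, the sketch below assumes the objective is a weighted sum of the form `lam * |bias| + (1 - lam) * (1 - accuracy)`. That form, the function names, and the example numbers are all assumptions made for illustration; the objective actually used in `post_hoc_lib.py` may differ.

```python
def objective(bias, accuracy, lam=0.75):
    """Hypothetical objective (an assumption, not the library's code); lower is better."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must be between 0 and 1")
    return lam * abs(bias) + (1.0 - lam) * (1.0 - accuracy)


# Two made-up candidates: one accurate but biased, one fairer but less accurate.
candidates = {
    "biased": {"bias": 0.20, "accuracy": 0.90},
    "fairer": {"bias": 0.05, "accuracy": 0.80},
}


def preferred(lam):
    """Name of the candidate with the lowest objective for this lambda."""
    return min(candidates, key=lambda name: objective(lam=lam, **candidates[name]))


for lam in (0.1, 0.5, 0.9):
    print(f"lambda={lam}: prefer the {preferred(lam)} model")
```

Under this assumed objective, a small lambda favors the more accurate candidate and a large lambda favors the less biased one, which matches the intuition in the list above.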

### Steps to debias your model

Follow the example given in `test_post_hoc_lib.py`. You will need to inherit from the `DebiasModel` class in `post_hoc_lib.py` and override the following methods with versions appropriate for your model:

* `def get_valloader(self):`
  * This should return a `torch.utils.data.DataLoader` that iterates over your validation set. Each iteration should return an input batch and a batch of binary label vectors that include the protected attribute and the prediction attribute at some index.
* `def get_testloader(self):`
  * This should return a `torch.utils.data.DataLoader` with the same parameters as the valloader, except that it iterates over the test set. It is used to evaluate the models.
* `def protected_index(self):`
  * This should return the index of your protected attribute in the label vector.
* `def prediction_index(self):`
  * This should return the index of your prediction attribute in the label vector.
* `def get_model(self):`
  * This should return the model architecture with the weights loaded from the checkpoint.
* `def get_last_layer_name(self):`
  * This should return the name of the last layer of the model.
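
The two index methods only locate columns in the label vector that the loaders return. Here is a minimal sketch, assuming the label layout is described by a list of attribute names. The `ATTRIBUTES` list and the `IndexOnlyModel` class are hypothetical stand-ins; a real subclass would inherit from `DebiasModel` and implement the loader and model methods as well.

```python
# Hypothetical label layout: one binary entry per attribute, in this order.
ATTRIBUTES = ["Male", "Smiling", "Young", "Black"]


class IndexOnlyModel:
    """Stand-in showing only the two index methods (not the real DebiasModel)."""

    protected_attr = "Black"
    prediction_attr = "Smiling"

    def protected_index(self):
        # Position of the protected attribute inside each label vector.
        return ATTRIBUTES.index(self.protected_attr)

    def prediction_index(self):
        # Position of the attribute the model predicts.
        return ATTRIBUTES.index(self.prediction_attr)


model = IndexOnlyModel()
labels = [1, 0, 1, 1]  # label vector for a single made-up example
protected = labels[model.protected_index()]
target = labels[model.prediction_index()]
print("protected:", protected, "target:", target)
```

Looking the indices up by name keeps them in sync with the label layout if attributes are added or reordered.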

You can also override the lambda parameter and the bias measure used in the objective in the `__init__` constructor of your class. To change the lambda parameter, set the `self.lam` attribute; to change the bias measure, set the `self.bias_measure` attribute to one of 'spd', 'eod', or 'aod'.
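
For reference, 'spd', 'eod', and 'aod' usually stand for statistical parity difference, equal opportunity difference, and average odds difference. The sketch below uses the common textbook definitions on plain Python lists. The function names, the sign convention (protected group minus the rest), and the toy data are assumptions, and the implementation in `post_hoc_lib.py` may differ in detail.

```python
def _rate(preds, mask):
    """Fraction of positive predictions among the entries selected by mask."""
    selected = [p for p, m in zip(preds, mask) if m]
    return sum(selected) / len(selected) if selected else 0.0


def bias_measures(preds, labels, protected):
    """Return spd, eod and aod as (protected group rate - other group rate)."""
    prot = [bool(a) for a in protected]
    rest = [not a for a in prot]

    def group_rates(group):
        pos = _rate(preds, group)  # P(pred = 1 | group)
        tpr = _rate(preds, [g and y == 1 for g, y in zip(group, labels)])
        fpr = _rate(preds, [g and y == 0 for g, y in zip(group, labels)])
        return pos, tpr, fpr

    pos_p, tpr_p, fpr_p = group_rates(prot)
    pos_r, tpr_r, fpr_r = group_rates(rest)
    return {
        "spd": pos_p - pos_r,
        "eod": tpr_p - tpr_r,
        "aod": 0.5 * ((fpr_p - fpr_r) + (tpr_p - tpr_r)),
    }


# Toy data: the first four examples belong to the protected group.
preds = [1, 0, 0, 0, 1, 1, 1, 0]
labels = [1, 1, 0, 0, 1, 1, 0, 0]
protected = [1, 1, 1, 1, 0, 0, 0, 0]
measures = bias_measures(preds, labels, protected)
print(measures)
```

A value of zero means the two groups are treated alike under that measure; the further from zero, the larger the disparity.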

Once you have overridden the methods and parameters, you can run the following code:

```
# CustomModel is a subclass of DebiasModel.
custom_model = CustomModel()
# This returns a dictionary containing bias statistics for the original model.
# If verbose is True, then it prints out the bias statistics.
orig_data = custom_model.evaluate_original(verbose=True)
# This runs the random debiasing algorithm on the model and returns
# the random debiased model and the random threshold that will minimize the objective.
rand_model, rand_thresh = custom_model.random_debias_model()
# This returns a dictionary containing bias statistics for the random debiased model.
# If verbose is True, then it prints out the bias statistics.
rand_data = custom_model.evaluate_random_debiased(verbose=True)
# This runs the adversarial debiasing algorithm on the model and returns
# the adversarial debiased model and the adversarial threshold that will minimize the objective.
adv_model, adv_thresh = custom_model.adversarial_debias_model()
# This returns a dictionary containing bias statistics for the adversarial debiased model.
# If verbose is True, then it prints out the bias statistics.
adv_data = custom_model.evaluate_adversarial_debiased(verbose=True)
```
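
Both debiasing calls return a threshold alongside the model. As an illustration of what choosing a threshold "that will minimize the objective" could look like, the sketch below sweeps candidate thresholds over validation scores and keeps the one with the lowest objective. The objective form, the use of statistical parity difference, and the names `spd` and `best_threshold` are assumptions for this sketch, not the library's implementation.

```python
def spd(preds, protected):
    """Statistical parity difference: protected positive rate minus the rest."""
    prot = [p for p, a in zip(preds, protected) if a]
    rest = [p for p, a in zip(preds, protected) if not a]
    return sum(prot) / len(prot) - sum(rest) / len(rest)


def best_threshold(scores, labels, protected, lam=0.75, steps=101):
    """Sweep thresholds in [0, 1]; return (threshold, objective) with the lowest objective."""
    best = None
    for k in range(steps):
        thresh = k / (steps - 1)
        preds = [int(s > thresh) for s in scores]
        acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)
        # Assumed objective: weighted sum of absolute bias and error rate.
        obj = lam * abs(spd(preds, protected)) + (1 - lam) * (1 - acc)
        if best is None or obj < best[0]:
            best = (obj, thresh)
    return best[1], best[0]


# Made-up validation scores (sigmoid outputs), labels and group membership.
scores = [0.9, 0.6, 0.4, 0.2, 0.8, 0.7, 0.55, 0.1]
labels = [1, 1, 0, 0, 1, 1, 0, 0]
protected = [1, 1, 1, 1, 0, 0, 0, 0]
thresh, obj = best_threshold(scores, labels, protected)
print("threshold:", thresh, "objective:", obj)
```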

## Run our Post-Hoc debiasing experiments

### Step 1 - Create Configs
Create a config yaml file required to run the experiment by running
@@ -84,4 +141,4 @@ To analyze the results of the experiments and get the plots shown below you can
To clean up the config directories run
```
$ bash cleanup.sh
```
```
194 changes: 194 additions & 0 deletions celebA.ipynb
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualize debiasing experiments on CelebA"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
"from os.path import join\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"\n",
"from post_hoc_celeba import load_celeba, get_resnet_model\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"descriptions = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',\n",
" 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',\n",
" 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',\n",
" 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses',\n",
" 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',\n",
" 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes',\n",
" 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose',\n",
" 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',\n",
" 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',\n",
" 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie',\n",
" 'Young', 'White', 'Black', 'Asian', 'Index']\n",
"\n",
"def sigmoid(x):\n",
" return 1/(1 + np.exp(-x)) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def image_from_index(index, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/', show=False):\n",
"    # given the index of the image, output the image\n",
"    file = str(index).zfill(6)+'.jpg'\n",
"    img = Image.open(join(os.path.expanduser(folder), file))\n",
"    if show:\n",
"        plt.imshow(img)\n",
"        plt.show()\n",
"    return img\n",
"\n",
"def imshow_group(imgs, n):\n",
"    # plot multiple images at once\n",
"    plt.figure(figsize=(20,10))\n",
"    columns = n\n",
"\n",
"    for i in range(n):\n",
"        plt.subplot(1, columns, i + 1)\n",
"        img = imgs[i]\n",
"        #img = img.astype(int)\n",
"        plt.axis('off')\n",
"        plt.imshow(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def output_debiased_imgs(biased_net,\n",
"                         debiased_net,\n",
"                         loader,\n",
"                         protected_attr,\n",
"                         prediction_attr):\n",
"    \"\"\"\n",
"    Display images along with their biased and debiased predictions\n",
"    \"\"\"\n",
"    prediction_index = descriptions.index(prediction_attr)\n",
"    protected_index = descriptions.index(protected_attr)\n",
"    ind = descriptions.index('Index')\n",
"\n",
"    outputs = []\n",
"    total_batches = len(loader)\n",
"    for batch_num, (inputs, labels) in enumerate(loader):\n",
"        inputs, labels = inputs.to(device), labels.to(device)\n",
"        biased_outputs = biased_net(inputs)[:, 0]\n",
"        debiased_outputs = debiased_net(inputs)[:, 0]\n",
"\n",
"        for i in range(len(inputs)):\n",
"            img = image_from_index(labels[i][ind].item())\n",
"            label = labels[i][prediction_index].item()\n",
"            protected = labels[i][protected_index].item()\n",
"            biased_output = sigmoid(biased_outputs[i].item())\n",
"            debiased_output = sigmoid(debiased_outputs[i].item())\n",
"\n",
"            outputs.append([img, label, protected, biased_output, debiased_output])\n",
"\n",
"        if batch_num % 10 == 0:\n",
"            print('At', batch_num, '/', total_batches)\n",
"\n",
"    return outputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load the test set\n",
"_, _, _, _, _, testloader = load_celeba(trainsize=0, \n",
" testsize=100, \n",
" num_workers=0, \n",
" batch_size=32,\n",
" transform_type='tensor')\n",
"\n",
"biased_model_path = 'models/by_random_checkpoint.pt'\n",
"debiased_model_path = 'models/by_checkpoint.pt'\n",
"\n",
"# load the biased and unbiased models\n",
"biased_net = get_resnet_model()\n",
"biased_net.load_state_dict(torch.load(biased_model_path, map_location=device))\n",
"\n",
"debiased_net = get_resnet_model()\n",
"debiased_net.load_state_dict(torch.load(debiased_model_path, map_location=device)['model_state_dict'])\n",
"\n",
"# output images\n",
"outputs = output_debiased_imgs(biased_net=biased_net,\n",
" debiased_net=debiased_net,\n",
" loader=testloader,\n",
" protected_attr = 'Black',\n",
" prediction_attr = 'Smiling')\n",
"imgs = [output[0] for output in outputs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rowsize = 8\n",
"for i in range(min(len(imgs)//rowsize, 5)):\n",
" imshow_group(imgs[rowsize*i:rowsize*(i+1)], rowsize)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
49 changes: 49 additions & 0 deletions celeb_race.py
@@ -0,0 +1,49 @@
import sys
import os
import torch
import torchvision
import numpy as np
from torchvision.datasets import CelebA
from torch.utils.data import Subset

white = np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/white_full.npy'))
black = np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/black_full.npy'))
asian = np.load(os.path.expanduser('~/post_hoc_debiasing/celebrace/asian_full.npy'))


class CelebRace(CelebA):

    def __getitem__(self, index):

        X, target = super().__getitem__(index)
        # Numeric image id parsed from the file name, e.g. '000001.jpg' -> 1
        ind = int(self.filename[index].split('.')[0])

        # Extra labels: three race flags, the image id, and the negation of
        # attribute 20 ('Male' in the CelebA attribute list)
        augment = torch.tensor([white[ind-1] > .501,
                                black[ind-1] > .501,
                                asian[ind-1] > .501,
                                ind,
                                1-target[20]], dtype=torch.long)

        return X, torch.cat((target, augment))


def unambiguous(dataset, split='train', thresh=.7):

    if split == 'train':
        n = 162770
    else:
        n = 19962
    # Keep only examples where some race score exceeds the threshold
    unambiguous_indices = [i for i in range(n) if (white[i] > thresh or black[i] > thresh or asian[i] > thresh)]

    return Subset(dataset, unambiguous_indices)


def split_check(dataset, split='train', thresh=.7):

    if split == 'train':
        n = 162770
    else:
        n = 19962
    # Keep only examples whose 'asian' score exceeds the threshold
    unambiguous_indices = [i for i in range(n) if (asian[i] > thresh)]

    return Subset(dataset, unambiguous_indices)
Binary file added celebrace/asian_full.npy
Binary file not shown.
Binary file added celebrace/black_full.npy
Binary file not shown.
Binary file added celebrace/white_full.npy
Binary file not shown.
23 changes: 23 additions & 0 deletions config_celeba.yml
@@ -0,0 +1,23 @@
---
output: black_young.json
epochs: 100
trainsize: 30000
testsize: 1000
batch_size: 32
num_workers: 2
print_priors: True
protected_attr: Black
prediction_attr: Young
checkpoint: by_checkpoint.pt
retrain: False
models:
  - random
  - adversarial
random:
  checkpoint: by_random_checkpoint.pt
adversarial:
  epochs: 10
  critic_steps: 300
  actor_steps: 100
  lambda: 0.75
  checkpoint: by_adversarial_checkpoint.pt