Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Merge pull request #9 from younader/docker
Browse files Browse the repository at this point in the history
Docker + JPG training
  • Loading branch information
schillij95 authored Aug 5, 2024
2 parents 58991e7 + 6c674ab commit 93b4b6f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 45 deletions.
54 changes: 35 additions & 19 deletions prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import cv2
import numpy as np
def read_image_mask(fragment_id,start_idx=15,end_idx=45):

fragment_id_ = fragment_id.split("_")[0]
images = []

# idxs = range(65)
Expand All @@ -11,9 +11,12 @@ def read_image_mask(fragment_id,start_idx=15,end_idx=45):
idxs = range(start_idx, end_idx)

for i in idxs:

image = cv2.imread( f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)

if os.path.exists(f"train_scrolls/{fragment_id}/layers/{i:02}.tif"):
image = cv2.imread(f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)
print(np.max(image))
else:
image = cv2.imread( f"train_scrolls/{fragment_id}/layers/{i:02}.jpg", 0)
print(np.max(image))
pad0 = (256 - image.shape[0] % 256)
pad1 = (256 - image.shape[1] % 256)

Expand All @@ -24,40 +27,53 @@ def read_image_mask(fragment_id,start_idx=15,end_idx=45):
image=np.clip(image,0,200)
images.append(image)
images = np.stack(images, axis=2)
if fragment_id in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:
if fragment_id_ in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:
images=images[:,:,::-1]
if fragment_id in ['20231022170901','20231022170900']:
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.tiff", 0)
if fragment_id_ in ['20231022170901','20231022170900']:
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.tiff", 0)
else:
mask = cv2.imread(f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.png", 0)
mask = cv2.imread(f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.png", 0)

fragment_mask=cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id}_mask.png", 0)
fragment_mask=cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id_}_mask.png", 0)
fragment_mask = np.pad(fragment_mask, [(0, pad0), (0, pad1)], constant_values=0)
mask = mask.astype('float32')
mask/=255
return images, mask,fragment_mask

def run_sanity_checks():
for fragment_id in ['20231210121321','20231106155350','20231005123336','20230820203112','20230620230619','20230826170124','20230702185753','20230522215721','20230531193658','20230520175435','20230903193206','20230902141231','20231007101615','20230929220924','recto','verso','20231016151000','20231012184423','20231031143850']:
fragment_id_ = "_".join(fragment_id.split("_")[:min(1, len(fragment_id)-1)])
print(fragment_id)
assert os.path.exists(f'train_scrolls/{fragment_id}/layers/00.tif')
assert os.path.exists(f'train_scrolls/{fragment_id}/{fragment_id}_inklabels.png')
assert os.path.exists(f'train_scrolls/{fragment_id}/{fragment_id}_mask.png')
if not os.path.exists(f'train_scrolls/{fragment_id_}'):
fragment_id_ += "_superseded"
assert os.path.exists(f'train_scrolls/{fragment_id_}/layers/00.tif') or os.path.exists(f'train_scrolls/{fragment_id_}/layers/00.jpg'), f"Fragment id {fragment_id_} has no surface volume"
assert os.path.exists(f'train_scrolls/{fragment_id_}/{fragment_id}_inklabels.png')
assert os.path.exists(f'train_scrolls/{fragment_id_}/{fragment_id}_mask.png')
assert os.path.exists(f'train_scrolls/20231022170901/layers/00.tif')
assert os.path.exists(f'train_scrolls/20231022170901/20231022170901_inklabels.tiff')
assert os.path.exists(f'train_scrolls/20231022170901/20231022170901_mask.png')
def prepare_data():
for l in os.listdir('all_labels/'):
if '.png' in l:
f_id=l[:-14]
if os.path.exists(f'train_scrolls/{f_id}'):
f_id = l[:-14]
f_id_ = f_id
if not os.path.exists(f'train_scrolls/{f_id}'):
f_id_ = f_id + "_superseded"
if os.path.exists(f'train_scrolls/{f_id_}'):
img=cv2.imread(f'all_labels/{f_id}_inklabels.png', 0)
cv2.imwrite(f"train_scrolls/{f_id}/{f_id}_inklabels.png", img)
cv2.imwrite(f"train_scrolls/{f_id_}/{f_id}_inklabels.png", img)
else:
print(f"couldnt find {f_id_}")
if '.tiff' in l:
f_id=l[:-15]
if os.path.exists(f'train_scrolls/{f_id}'):
f_id = l[:-15]
f_id_ = f_id
if not os.path.exists(f'train_scrolls/{f_id}'):
f_id_ = f_id + "_superseded"
if os.path.exists(f'train_scrolls/{f_id_}'):
img=cv2.imread(f'all_labels/{f_id}_inklabels.tiff', 0)
cv2.imwrite(f"train_scrolls/{f_id}/{f_id}_inklabels.tiff", img)
cv2.imwrite(f"train_scrolls/{f_id_}/{f_id}_inklabels.tiff", img)
else:
print(f"couldnt find {f_id_}")
if __name__ == "__main__":
prepare_data()
run_sanity_checks()
run_sanity_checks()
16 changes: 13 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,24 @@ EASY: build the docker image:

```bash
docker build -t gp_model .
docker run --gpus all --shm-size=150g -it -v </your-path-to-train-scrolls>:/workspace/train_scrolls gp_model
```

Or use a Docker image such as `pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel` as your development environment. Kaggle/Colab images should work fine as well.
Then to train:

Then to install this project inside the docker image, run:
```bash
python train_timesformer_og.py
```

Or to run inference with the already trained model:

```bash
python inference_timesformer.py --model_path timesformer_weights.ckpt --segment_path train_scrolls --segment_id 20231005123336
```

Important note: to install the ink labels and training data inside the docker image, run:

```bash
pip install -r requirements.txt
#to download the segments from the server
./download.sh
#propagates the inklabels into the respective segment folders for training
Expand Down
34 changes: 20 additions & 14 deletions train_timesformer_deduped.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def cfg_init(cfg, mode='train'):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_image_mask(fragment_id,start_idx=17,end_idx=43):

fragment_id_ = fragment_id.split("_")[0]
images = []

# idxs = range(65)
Expand All @@ -213,8 +213,10 @@ def read_image_mask(fragment_id,start_idx=17,end_idx=43):
idxs = range(start_idx, end_idx)

for i in idxs:

image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)
if os.path.exists(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif"):
image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)
else:
image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.jpg", 0)

pad0 = (CFG.tile_size - image.shape[0] % CFG.tile_size)
pad1 = (CFG.tile_size - image.shape[1] % CFG.tile_size)
Expand All @@ -226,20 +228,20 @@ def read_image_mask(fragment_id,start_idx=17,end_idx=43):
if 'frag' in fragment_id:
image = cv2.resize(image, (image.shape[1]//2,image.shape[0]//2), interpolation = cv2.INTER_AREA)
image=np.clip(image,0,200)
if fragment_id=='20230827161846':
if fragment_id_=='20230827161846':
image=cv2.flip(image,0)
images.append(image)
images = np.stack(images, axis=2)
if fragment_id in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:
if fragment_id_ in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:

images=images[:,:,::-1]
if fragment_id in ['20231022170901','20231022170900']:
mask = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.tiff", 0)
if fragment_id_ in ['20231022170901','20231022170900']:
mask = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.tiff", 0)
else:
mask = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.png", 0)
mask = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.png", 0)
# mask = np.pad(mask, [(0, pad0), (0, pad1)], constant_values=0)
fragment_mask=cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id}_mask.png", 0)
if fragment_id=='20230827161846':
fragment_mask=cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id_}_mask.png", 0)
if fragment_id_=='20230827161846':
fragment_mask=cv2.flip(fragment_mask,0)

fragment_mask = np.pad(fragment_mask, [(0, pad0), (0, pad1)], constant_values=0)
Expand All @@ -264,10 +266,14 @@ def get_train_valid_dataset():
#BIG 6:'20231005123333','20231022170900','20231012173610','20230702185753','20230929220924','20231007101615'
for fragment_id in ['20231210121321','20231106155350','20231005123336','20230820203112','20230620230619','20230826170124','20230702185753','20230522215721','20230531193658','20230520175435','20230903193206','20230902141231','20231007101615','20230929220924','recto','verso','20231016151000','20231012184423','20231031143850']:
#,

# for fragment_id in ['20231210121321','20231106155350']:
if not os.path.exists(f"train_scrolls/{fragment_id}"):
fragment_id = fragment_id + "_superseded"
# for fragment_id in ['20231210121321','20231106155350']:
print('reading ',fragment_id)
image, mask,fragment_mask = read_image_mask(fragment_id)
try:
image, mask,fragment_mask = read_image_mask(fragment_id)
except:
print(f"couldnt load {fragment_id}!")
x1_list = list(range(0, image.shape[1]-CFG.tile_size+1, CFG.stride))
y1_list = list(range(0, image.shape[0]-CFG.tile_size+1, CFG.stride))

Expand Down Expand Up @@ -619,4 +625,4 @@ def scheduler_step(scheduler, avg_val_loss, epoch):
del train_images,train_loader,train_masks,valid_loader,model
gc.collect()
torch.cuda.empty_cache()
wandb.finish()
wandb.finish()
28 changes: 19 additions & 9 deletions train_timesformer_og.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,28 @@ def cfg_init(cfg, mode='train'):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def read_image_mask(fragment_id,start_idx=17,end_idx=43):

fragment_id_ = fragment_id.split("_")[0]
images = []
idxs = range(start_idx, end_idx)

for i in idxs:
image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)
if os.path.exists(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif"):
image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.tif", 0)
else:
image = cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/layers/{i:02}.jpg", 0)
pad0 = (CFG.tile_size - image.shape[0] % CFG.tile_size)
pad1 = (CFG.tile_size - image.shape[1] % CFG.tile_size)
image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)
image=np.clip(image,0,200)
images.append(image)
images = np.stack(images, axis=2)
if fragment_id in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:
if fragment_id_ in ['20230701020044','verso','20230901184804','20230901234823','20230531193658','20231007101615','20231005123333','20231011144857','20230522215721', '20230919113918', '20230625171244','20231022170900','20231012173610','20231016151000']:
images=images[:,:,::-1]
if fragment_id in ['20231022170901','20231022170900']:
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.tiff", 0)
if fragment_id_ in ['20231022170901','20231022170900']:
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.tiff", 0)
else:
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id}_inklabels.png", 0)
fragment_mask=cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id}_mask.png", 0)
mask = cv2.imread( f"train_scrolls/{fragment_id}/{fragment_id_}_inklabels.png", 0)
fragment_mask=cv2.imread(CFG.comp_dataset_path + f"train_scrolls/{fragment_id}/{fragment_id_}_mask.png", 0)
fragment_mask = np.pad(fragment_mask, [(0, pad0), (0, pad1)], constant_values=0)
mask = mask.astype('float32')
mask/=255
Expand All @@ -157,8 +160,15 @@ def get_train_valid_dataset():
valid_xyxys = []

for fragment_id in ['20231210121321','20231022170901','20231106155351','20231005123336','20230820203112','20230826170124','20230702185753','20230522215721','20230531193658','20230903193206','20230902141231','20231007101615','20230929220926','recto','20231016151000','20231012184423','20231031143850']:
# for fragment_id in ['20231210121321', '20231022170901']:
if not os.path.exists(f"train_scrolls/{fragment_id}"):
fragment_id = fragment_id + "_superseded"
print('reading ',fragment_id)
image, mask,fragment_mask = read_image_mask(fragment_id)
try:
image, mask,fragment_mask = read_image_mask(fragment_id)
except Exception as e:
print(f"couldnt load {fragment_id}: {str(e)}!")
continue
x1_list = list(range(0, image.shape[1]-CFG.tile_size+1, CFG.stride))
y1_list = list(range(0, image.shape[0]-CFG.tile_size+1, CFG.stride))
windows_dict={}
Expand Down Expand Up @@ -399,7 +409,7 @@ def scheduler_step(scheduler, avg_val_loss, epoch):
trainer = pl.Trainer(
max_epochs=20,
accelerator="gpu",
devices=4,
devices=-1,
logger=wandb_logger,
default_root_dir="./models",
accumulate_grad_batches=1,
Expand Down

0 comments on commit 93b4b6f

Please sign in to comment.