Skip to content

Commit

Permalink
Add _prob methods for getting at the raw probabilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
cottrell committed Oct 2, 2024
1 parent c4f3cd7 commit 5e1432d
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 1 deletion.
140 changes: 139 additions & 1 deletion easyocr/easyocr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

from .recognition import get_recognizer, get_text
from .recognition import get_recognizer, get_text, get_text_prob
from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\
download_and_unzip, printProgressBar, diff, reformat_input,\
make_rotated_img_list, set_result_with_confidence,\
Expand Down Expand Up @@ -350,6 +350,93 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\

return horizontal_list_agg, free_list_agg

def recognize_prob(self, img_cv_grey, horizontal_list=None, free_list=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None,paragraph = False,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'):

if reformat:
img, img_cv_grey = reformat_input(img_cv_grey)

if allowlist:
ignore_char = ''.join(set(self.character)-set(allowlist))
elif blocklist:
ignore_char = ''.join(set(blocklist))
else:
ignore_char = ''.join(set(self.character)-set(self.lang_char))

if self.model_lang in ['chinese_tra','chinese_sim']: decoder = 'greedy'

if (horizontal_list==None) and (free_list==None):
y_max, x_max = img_cv_grey.shape
horizontal_list = [[0, x_max, 0, y_max]]
free_list = []

# without gpu/parallelization, it is faster to process image one by one
if ((batch_size == 1) or (self.device == 'cpu')) and not rotation_info:
result = []
for bbox in horizontal_list:
h_list = [bbox]
f_list = []
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
result += result0
for bbox in free_list:
h_list = []
f_list = [bbox]
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
result += result0
# default mode will try to process multiple boxes at the same time
else:
image_list, max_width = get_image_list(horizontal_list, free_list, img_cv_grey, model_height = imgH)
image_len = len(image_list)
if rotation_info and image_list:
image_list = make_rotated_img_list(rotation_info, image_list)
max_width = max(max_width, imgH)

result = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)

if rotation_info and (horizontal_list+free_list):
# Reshape result to be a list of lists, each row being for
# one of the rotations (first row being no rotation)
result = set_result_with_confidence(
[result[image_len*i:image_len*(i+1)] for i in range(len(rotation_info) + 1)])

if self.model_lang == 'arabic':
direction_mode = 'rtl'
result = [list(item) for item in result]
for item in result:
item[1] = get_display(item[1])
else:
direction_mode = 'ltr'

if paragraph:
result = get_paragraph(result, x_ths=x_ths, y_ths=y_ths, mode = direction_mode)

if detail == 0:
return [item[1] for item in result]
elif output_format == 'dict':
if paragraph:
return [ {'boxes':item[0],'text':item[1]} for item in result]
return [ {'boxes':item[0],'text':item[1],'confident':item[2]} for item in result]
elif output_format == 'json':
if paragraph:
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1]}, ensure_ascii=False) for item in result]
return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1],'confident':item[2]}, ensure_ascii=False) for item in result]
elif output_format == 'free_merge':
return merge_to_free(result, free_list)
else:
return result

def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
Expand Down Expand Up @@ -472,6 +559,42 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
filter_ths, y_ths, x_ths, False, output_format)

return result

def readtext_prob(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None, paragraph = False, min_size = 20,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
canvas_size = 2560, mag_ratio = 1.,\
slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1,
threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0,
output_format='standard'):
'''
Parameters:
image: file path or numpy-array or a byte stream object
'''
img, img_cv_grey = reformat_input(image)

horizontal_list, free_list = self.detect(img,
min_size = min_size, text_threshold = text_threshold,\
low_text = low_text, link_threshold = link_threshold,\
canvas_size = canvas_size, mag_ratio = mag_ratio,\
slope_ths = slope_ths, ycenter_ths = ycenter_ths,\
height_ths = height_ths, width_ths= width_ths,\
add_margin = add_margin, reformat = False,\
threshold = threshold, bbox_min_score = bbox_min_score,\
bbox_min_size = bbox_min_size, max_candidates = max_candidates
)
# get the 1st result from hor & free list as self.detect returns a list of depth 3
horizontal_list, free_list = horizontal_list[0], free_list[0]
result = self.recognize_prob(img_cv_grey, horizontal_list, free_list,\
decoder, beamWidth, batch_size,\
workers, allowlist, blocklist, detail, rotation_info,\
paragraph, contrast_ths, adjust_contrast,\
filter_ths, y_ths, x_ths, False, output_format)

return result

def readtextlang(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
Expand Down Expand Up @@ -577,3 +700,18 @@ def readtext_batched(self, image, n_width=None, n_height=None,\
filter_ths, y_ths, x_ths, False, output_format))

return result_agg


def convert_prob_to_word(prob, converter):
"""
For use with the readtest_prob outputs.
- prob should be 2d
- convert = reader.converter
"""
assert prob.ndim == 2
preds_index = np.argmax(prob, axis=1)
preds_index = preds_index.flatten()
preds_size = np.array([prob.shape[0]])
preds_str = converter.decode_greedy(preds_index, preds_size)[0]
return preds_str
91 changes: 91 additions & 0 deletions easyocr/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,48 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\

return result

def recognizer_predict_prob(model, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu'):
model.eval()
result = []
with torch.no_grad():
for image_tensors in test_loader:
batch_size = image_tensors.size(0)
image = image_tensors.to(device)
# For max length prediction
length_for_pred = torch.IntTensor([batch_max_length] * batch_size).to(device)
text_for_pred = torch.LongTensor(batch_size, batch_max_length + 1).fill_(0).to(device)

preds = model(image, text_for_pred)

# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)

######## filter ignore_char, rebalance
preds_prob = F.softmax(preds, dim=2)
preds_prob = preds_prob.cpu().detach().numpy()
preds_prob[:,:,ignore_idx] = 0.
pred_norm = preds_prob.sum(axis=2)
preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1)
preds_prob = torch.from_numpy(preds_prob).float().to(device)
preds_prob = preds_prob.cpu().detach().numpy()

values = preds_prob.max(axis=2)
indices = preds_prob.argmax(axis=2)
preds_max_prob = []
for v,i in zip(values, indices):
max_probs = v[i!=0] # this removes blanks
if len(max_probs)>0:
preds_max_prob.append(max_probs)
else:
preds_max_prob.append(np.array([0]))

for pred_max_prob in preds_max_prob:
confidence_score = custom_mean(pred_max_prob)
result.append([preds_prob, confidence_score])

return result

def get_recognizer(recog_network, network_params, character,\
separator_list, dict_list, model_path,\
device = 'cpu', quantize = True):
Expand Down Expand Up @@ -231,3 +273,52 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\
result.append( (box, pred1[0], pred1[1]) )

return result

def get_text_prob(character, imgH, imgW, recognizer, converter, image_list,\
ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\
adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'):
batch_max_length = int(imgW/10)

char_group_idx = {}
ignore_idx = []
for char in ignore_char:
try: ignore_idx.append(character.index(char)+1)
except: pass

coord = [item[0] for item in image_list]
img_list = [item[1] for item in image_list]
AlignCollate_normal = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True)
test_data = ListDataset(img_list)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=batch_size, shuffle=False,
num_workers=int(workers), collate_fn=AlignCollate_normal, pin_memory=True)

# predict first round
result1 = recognizer_predict_prob(recognizer, converter, test_loader,batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)

# predict second round
low_confident_idx = [i for i,item in enumerate(result1) if (item[1] < contrast_ths)]
if len(low_confident_idx) > 0:
img_list2 = [img_list[i] for i in low_confident_idx]
AlignCollate_contrast = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True, adjust_contrast=adjust_contrast)
test_data = ListDataset(img_list2)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=batch_size, shuffle=False,
num_workers=int(workers), collate_fn=AlignCollate_contrast, pin_memory=True)
result2 = recognizer_predict_prob(recognizer, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)

result = []
for i, zipped in enumerate(zip(coord, result1)):
box, pred1 = zipped
if i in low_confident_idx:
pred2 = result2[low_confident_idx.index(i)]
if pred1[1]>pred2[1]:
result.append( (box, pred1[0], pred1[1]) )
else:
result.append( (box, pred2[0], pred2[1]) )
else:
result.append( (box, pred1[0], pred1[1]) )

return result

0 comments on commit 5e1432d

Please sign in to comment.