diff --git a/dsm/utilities.py b/dsm/utilities.py index 2c5396a..60bf65c 100644 --- a/dsm/utilities.py +++ b/dsm/utilities.py @@ -57,14 +57,10 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, risks=model.risks, optimizer=model.optimizer) premodel.double() - optimizer = get_optimizer(premodel, lr) - oldcost = float('inf') - patience = 0 - costs = [] + oldcost, patience = float('inf'), 0 for _ in tqdm(range(n_iter)): - optimizer.zero_grad() loss = 0 for r in range(model.risks): @@ -75,15 +71,18 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, valid_loss = 0 for r in range(model.risks): valid_loss += unconditional_loss(premodel, t_valid, e_valid, str(r+1)) - valid_loss = valid_loss.detach().cpu().numpy() - costs.append(valid_loss) - #print(valid_loss) - if np.abs(costs[-1] - oldcost) < thres: + valid_loss = valid_loss.item() + + if np.abs(valid_loss - oldcost) < thres: patience += 1 if patience == 3: break - oldcost = costs[-1] + elif oldcost > valid_loss: + patience = 0 + best_weight = deepcopy(premodel.state_dict()) + oldcost = valid_loss + premodel.load_state_dict(best_weight) return premodel def _reshape_tensor_with_nans(data): @@ -180,30 +179,15 @@ def train_dsm(model, elbo=False, risk=str(r+1)) - valid_loss = valid_loss.detach().cpu().numpy() - costs.append(float(valid_loss)) - dics.append(deepcopy(model.state_dict())) - - if costs[-1] >= oldcost: - if patience == 2: - minm = np.argmin(costs) - model.load_state_dict(dics[minm]) - - del dics - gc.collect() - - return model, i - else: - patience += 1 + valid_loss = valid_loss.item() + if valid_loss > oldcost: + patience += 1 + if patience == 3: + break else: patience = 0 + best_weights = deepcopy(model.state_dict()) + oldcost = valid_loss - oldcost = costs[-1] - - minm = np.argmin(costs) - model.load_state_dict(dics[minm]) - - del dics - gc.collect() - + model.load_state_dict(best_weights) return model, i