Logistic regression model #15

Open
river-afk opened this issue Jan 21, 2020 · 8 comments

@river-afk

Hello @YangZhang4065 ,
Thanks for the nice work. I'm trying to train a model to estimate the global label distribution. Could you give more details about the architecture and training scheme?
Thanks!

@YangZhang4065
Owner

Hello @river-afk
Thank you for your interest. We used a simple linear Keras MLP to estimate the global label distribution, i.e. input > linear layer > softmax > output, though adding nonlinearity might improve model performance.
The input to the model is an image's VGG features; the model's output is that image's label distribution. A label distribution is a vector that sums to 1 (that is why there is a softmax in the model) and whose length equals the number of unique segmentation classes. It literally describes the proportions of segmentation labels in one image.
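For concreteness, the ground-truth label distribution of one image can be counted directly from its segmentation mask; a minimal sketch (the function name and `n_classes` are placeholders, not code from this repo):

```python
import numpy as np

def label_distribution(seg_mask, n_classes):
    # Fraction of pixels belonging to each segmentation class;
    # the result is non-negative and sums to 1.
    counts = np.bincount(seg_mask.ravel(), minlength=n_classes)
    return counts / counts.sum()
```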
The model is trained on the source dataset with an Adam optimizer and a KL divergence loss. Note that we measure the Chi-square distance between the model's predictions and the ground-truth label distributions on the validation dataset for early stopping.
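For reference, a minimal Keras sketch of that setup (the feature dimension and class count below are placeholders, not the paper's values):

```python
from tensorflow import keras

n_features = 4096  # placeholder: dimensionality of the precomputed image features
n_classes = 19     # placeholder: number of unique segmentation classes

# input > linear layer > softmax > output
model = keras.Sequential([
    keras.layers.Input(shape=(n_features,)),
    keras.layers.Dense(n_classes, activation="softmax"),
])

# Adam optimizer with KL-divergence loss, trained on the source dataset
model.compile(optimizer="adam", loss=keras.losses.KLDivergence())
```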

@river-afk
Author

Thank you @YangZhang4065, it is much clearer now. By the way, in the paper you said that the backbone was Inception-ResNet-v2, not VGG? Also, how many epochs did you use to train the regression model?

@river-afk
Author

Another question: in the TPAMI paper, you use the cross-entropy loss to train the logistic regression: "... we thus train it by replacing the one-hot vectors in the cross-entropy loss with the ground-truth label distribution ps, which is counted by using eq. (1) from the human labels of the source domain" (Sec 3.3.1). So is it KL or CE?

@YangZhang4065
Owner

@river-afk Sorry for the late reply.
I apologize for the misinformation: we used the Inception features, not the VGG features, to train the logistic regression. I did not use a constant epoch number. Instead, training stops if the validation Chi-square distance no longer improves after N epochs; I remember setting N to 100. I also keep the model with the best validation Chi-square distance.
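In Keras terms, that stopping rule might look like the sketch below; the `val_chi_square` metric name is hypothetical and assumes a custom Chi-square distance metric has been registered at compile time:

```python
from tensorflow import keras

early_stop = keras.callbacks.EarlyStopping(
    monitor="val_chi_square",   # hypothetical custom validation metric
    mode="min",                 # smaller distance is better
    patience=100,               # N = 100, as described above
    restore_best_weights=True,  # keep the model with the best validation score
)
```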
CE = KL + (a constant term) in our context, so they are equivalent in terms of optimizing our problem. Theoretically you should use KL divergence; in practice I used cross-entropy for better numerical stability.
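The identity behind this is CE(p, q) = KL(p ‖ q) + H(p), where H(p) is the entropy of the ground-truth distribution and does not depend on the model parameters, so both losses have the same minimizer. A quick numerical check:

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])  # ground-truth label distribution
q = np.array([0.5, 0.3, 0.2])  # model prediction

cross_entropy = -np.sum(p * np.log(q))
kl_divergence = np.sum(p * np.log(p / q))
entropy = -np.sum(p * np.log(p))

# CE(p, q) = KL(p || q) + H(p); H(p) is constant w.r.t. the model
assert np.isclose(cross_entropy, kl_divergence + entropy)
```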

@chccgiven

I'm trying to implement your global label distribution algorithm. I did as you described above, except that I use the images' VGG features as input, but during validation I always get a Chi-square value of 0. Could you please give me some advice about this situation?

Thank you!

@YangZhang4065
Owner

@chccgiven Well, a Chi-square distance of 0 means that your validation predictions are identical to the GT. This either means your predictor is perfect (not possible) or there is a bug somewhere.
I would suggest you manually check whether the label distribution predictor's predictions and the GTs are really the same. If they are not, then you are not calculating the Chi-square distance correctly. If they are, then something is wrong with your data or your model (you might be calculating the distance between the GT and the GT instead of the GT vs. the prediction, or your training data might wrongly contain validation data).
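A quick way to run that check (the array names `preds` and `gts` are assumptions here, each shaped (n_samples, n_classes)):

```python
import numpy as np

# A learned model virtually never matches the GT exactly, so if this prints
# True you are probably comparing the GT against itself somewhere.
print("predictions identical to GT:", np.allclose(preds, gts))
print("max |pred - gt|:", np.abs(preds - gts).max())
```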

@chccgiven

Thank you for your kind reply!
Could you please tell me whether you used scipy's chisquare function, i.e. chisq, p = chisquare(f_obs, f_exp, axis=0), to calculate the probability distribution distance between the predictor's predictions and the GT? If not, can you tell me which function you used?

@YangZhang4065
Owner

Hi @chccgiven,

I did not use the scipy chisquare, though I believe it should do the job pretty well. I was using a MATLAB implementation back at that time.
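For anyone reading along: note that scipy.stats.chisquare returns a test statistic, sum((f_obs - f_exp)**2 / f_exp), plus a p-value, whereas one common symmetric Chi-square histogram distance can be written directly. A sketch (the exact definition used in the paper may differ):

```python
import numpy as np

def chi_square_distance(p, q, eps=1e-12):
    # Symmetric Chi-square distance between two discrete distributions
    # (one common variant; eps guards against empty bins).
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    return 0.5 * np.sum((p - q) ** 2 / (p + q + eps))
```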
