
Add batch_compute_gradient_portion #312

Open · wants to merge 2 commits into development

Conversation

@AminHP (Contributor) commented Feb 19, 2023

In the train_step method of the Model class, all samples in the batch are used to update the network weights. In many problems this is the correct and common approach. In some situations, however, there are other ways to update the weights: we can compute the loss using all samples in the batch but update the weights using only a subset of that batch. This is useful when we need a large batch but memory limits don't allow us to compute gradients for all samples. The solution is to compute the loss value over the full batch, and then compute the gradients and update the network weights from only a subset of those samples. In most problems this is not meaningful, but in metric learning and similar problems it is quite useful.

For example, consider the MultiSimilarityLoss used in this example. With this loss function, each batch contains K samples from each of C classes. The goal of MultiSimilarityLoss is to compute the loss so that the K samples belonging to a class move close together in the feature space while moving away from the samples of the other classes in the batch. This loss function needs a large batch to perform satisfactorily, but computing the gradients for all samples in a large batch may not fit in memory. A simple solution, as explained above, is to compute the gradients and update the network weights for only m of the C classes in the batch. Although the loss value is calculated from all samples in the batch, the weights are updated from only a portion of that batch.
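
Roughly, the idea looks like the following untested sketch. Only the batch_compute_gradient_portion name comes from this PR; the stop_gradient splice and taking the gradient portion from the front of the batch are just one possible way to realize the idea, and the actual implementation may differ:

```python
import tensorflow as tf

class PortionModel(tf.keras.Model):
    def __init__(self, *args, batch_compute_gradient_portion=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_compute_gradient_portion = batch_compute_gradient_portion

    def train_step(self, data):
        x, y = data
        batch_size = tf.shape(x)[0]
        # Number of samples whose gradients we actually keep.
        k = tf.cast(
            tf.cast(batch_size, tf.float32) * self.batch_compute_gradient_portion,
            tf.int32,
        )
        # Embeddings of the "frozen" tail are computed outside the tape, so no
        # activations are stored for them and no gradients flow back.
        tail = tf.stop_gradient(self(x[k:], training=True))
        with tf.GradientTape() as tape:
            head = self(x[:k], training=True)
            # The loss still sees every embedding in the batch...
            embeddings = tf.concat([head, tail], axis=0)
            loss = self.compiled_loss(y, embeddings)
        # ...but backprop only involves the first k samples.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, embeddings)
        return {m.name: m.result() for m in self.metrics}
```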

I've tested the proposed implementation in an enterprise project, and it's been successful. Besides addressing the memory concerns, it also helps prevent overfitting. I've tested multiple scenarios in which the batch size was kept constant but the batch_compute_gradient_portion value varied from 0.25 to 1. The models trained with batch_compute_gradient_portion < 1 were clearly less prone to overfitting; their ability to generalize to completely new samples from different distributions was significantly higher.

@AminHP (Contributor, Author) commented Feb 20, 2023

Your consideration would be much appreciated, @owenvallis.

@ebursztein (Member)

@AminHP thanks for the contribution - the idea is definitely interesting, and I think it works to limit the backprop. Similarity needs the largest batch possible, so whatever we can do to push that further seems good.

What I am unsure about is whether taking the first N samples is the best/only approach. We could imagine a strategy that, for example, takes the samples with the highest distances. It might not be as trivial to implement, since we need the distance matrix, but technically we have it, so it's worth discussing.

Also, I think we would need to expose the loss over the "full" batch, potentially as an optional metric, to help understand how much deviation is observed.

Last but not least, a small tutorial explaining the reasoning would be good.

@owenvallis (Collaborator)

@AminHP apologies for taking a minute to respond here, and thanks for the PR. This seems like an interesting way to support larger batch sizes, and I like that the proposed changes keep the existing behavior as the default. I agree with Elie as well: we might want to support other subset sampling strategies here, in which case it might be nice to pass in a batch subset sampling strategy instead of hardcoding the split at K. At the very least, I think we would want to shuffle the batch in the train step to try to get a random set of classes, as the current TF Sim samplers tend to place the batch classes in contiguous blocks.

The other interesting note here is that this seems to imply that the majority of the learning comes from a subset of examples in each batch, which I think lends further weight to Elie's suggestion that we might want to take the subset based on distance... although the distances are currently computed inside the loss, which is inside the gradient tape scope, so that would take some work to figure out.

What are your thoughts, @AminHP?

@AminHP (Contributor, Author) commented Mar 15, 2023

Thanks @ebursztein and @owenvallis for reviewing this PR and for the thoughtful comments and opinions. I have been working on the tutorial, but I haven't had enough time to complete it. The idea of a sampling strategy is quite exciting. I also have some questions about your comments:

  1. Why would we need to shuffle the batch inside the train_step when the samples are already shuffled here?

  2. As far as I know, some loss functions, such as MultiSimilarityLoss, perform a subsampling process based on the samples' distances. Wouldn't adding a subsampling strategy based on distances create inconsistencies with these kinds of loss functions?

  3. Regarding the last note, a relatively slow solution might be as follows (a rough sketch follows this list). First, we compute all embeddings before the tape scope and select candidate samples based on their distances. Then, we recompute the embeddings for the candidates inside the tape scope in order to update the weights. However, as I mentioned above, this might introduce inconsistencies with the sample mining process inside the loss functions.
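
To illustrate idea 3, here is a rough, untested sketch. The scoring rule (mean distance to the rest of the batch) is only a placeholder for whatever distance-based criterion we choose, and none of this is taken from the existing code:

```python
import tensorflow as tf

def candidate_train_step(model, optimizer, loss_fn, x, y, k):
    # 1) Score the batch outside the tape, so no gradient memory is used here.
    embeddings = tf.stop_gradient(model(x, training=True))
    sq_norms = tf.reduce_sum(tf.square(embeddings), axis=1, keepdims=True)
    # Pairwise squared Euclidean distances between all embeddings in the batch.
    distances = (
        sq_norms
        - 2.0 * tf.matmul(embeddings, embeddings, transpose_b=True)
        + tf.transpose(sq_norms)
    )
    # Placeholder criterion: mean distance of each sample to the rest of the batch.
    scores = tf.reduce_mean(distances, axis=1)
    idx = tf.math.top_k(scores, k=k).indices

    # 2) Recompute only the candidates inside the tape.
    with tf.GradientTape() as tape:
        cand = model(tf.gather(x, idx), training=True)
        # Splice the recomputed embeddings back into the full batch so the loss
        # still sees every sample, but gradients flow only through `cand`.
        full = tf.tensor_scatter_nd_update(embeddings, tf.expand_dims(idx, 1), cand)
        loss = loss_fn(y, full)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```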

@owenvallis (Collaborator)

Hello @AminHP,

Regarding the shuffling you mentioned: it shuffles the entire dataset before we start sampling batches, but each batch still places the samples for each class in contiguous blocks (see here). This means the current solution may update the weights from only a subset of the classes in each batch, which is not ideal. Ideally, we should take a random sample of the batch so that we get some loss signal from as many classes as possible.

As for the multisim loss, my understanding is that it does not explicitly subsample the examples by distance but rather applies a weighted sum based on distance. This could result in a very small weight being applied to the easiest examples, but it should only affect the easiest pos/neg examples. I think the weighting that @ebursztein mentioned would have to be something separate, such as the distance to the hardest neighbor in the row. Alternatively, we could keep the top K loss values in the output, but I am not clear on how we can do that. Perhaps we can compute the loss twice: once before the gradient and then again on the subset within the gradient tape scope.
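
To make that concrete, something like the following (completely untested) might work. It assumes the loss can be constructed with reduction=NONE so that it returns per-example values, which would need to be checked for each TF Sim loss:

```python
import tensorflow as tf

def top_k_loss_train_step(model, optimizer, loss_fn, per_example_loss_fn, x, y, k):
    # Pass 1 (outside the tape): per-example loss values, used only for ranking.
    # `per_example_loss_fn` is assumed to be the same loss constructed with
    # reduction=tf.keras.losses.Reduction.NONE.
    embeddings = tf.stop_gradient(model(x, training=True))
    per_example = per_example_loss_fn(y, embeddings)
    idx = tf.math.top_k(per_example, k=k).indices

    # Pass 2 (inside the tape): compute the loss again on the selected subset
    # and backprop only through it.
    with tf.GradientTape() as tape:
        subset_embeddings = model(tf.gather(x, idx), training=True)
        loss = loss_fn(tf.gather(y, idx), subset_embeddings)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```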

I also like your third suggestion, and it seems like it would give us the most control. I am in favor of testing it out and confirming that it works across all losses, and that it gives the expected performance boost. On that note, I have been working on some benchmarking tools. They are not fully complete yet, but they are available in the development branch. If you are interested in setting up some experiments, please let me know. They are currently set up to use the cars196 and CUB datasets. The dataset generation and hyper tuning are in a good state, but I need to make a few updates to the main benchmark function.

@AminHP (Contributor, Author) commented Mar 21, 2023

Hi @owenvallis,

I understand your point about the need for shuffling, but I was actually thinking the opposite. I thought that taking all samples from a subset of classes into account (the current implementation) could give the model a better understanding: it would learn to bring all samples from the same class closer together while pushing samples from different classes farther apart. Your point of view sounds intriguing, though. I should test it on my project and investigate which approach performs better. That will probably take 1 to 4 weeks.

The ideas on subsampling strategies are getting more and more interesting. So far, we have three ideas:

  • Randomly selecting K samples from the batch for updating the weights.
  • Selecting the top K loss values from the batch.
  • Selecting the K hardest neighbors (or some similar method).

I sorted the ideas based on their complexity to design and implement. I'm very keen on comparing and testing the ideas using the benchmarking tools you mentioned.

Overall, before going any further and developing more subsampling strategies, I think it would be more practical to set up some experiments and start by inspecting the first idea, which is already implemented in this PR. Then we can refactor the code and design a generalized subsampling-strategy interface.

What's your opinion?

@AminHP (Contributor, Author) commented Mar 28, 2023

I've conducted several experiments on the shuffling subject (to be more precise, I'll call it batch permutation in what follows). It turned out that batch permutation did not improve the training process or the loss reduction in my project. However, I think this depends on the problem and dataset.

In metric learning, some datasets consist of well-defined classes where the definition of each class is clear (like CIFAR and MNIST). Usually, in these kinds of datasets, the number of classes is small and the number of samples per class is large. For such datasets, the proper choice might be to apply batch permutation, since a subsample of each class can represent it adequately (though I haven't tested this yet).

On the other hand, some datasets contain thousands of classes, but the definition of each class is not clear. Each class differs from the others by relatively unknown features and metrics, and the metrics that separate these classes rely on the connections between all samples of each class. Therefore, selecting a larger number of samples per class leads to better results in these problems, and the training process performs better without batch permutation. Note that my project belongs to this type.

In any case, in order to generalize the code, I added a parameter named batch_random_permutation, which applies the batch permutation in the train_step method. Its default value is False.
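
For reference, here is a minimal, untested sketch of what the permutation step could look like; only the batch_random_permutation name comes from this PR, the rest is illustrative:

```python
import tensorflow as tf

def maybe_permute_batch(x, y, batch_random_permutation=False):
    """Randomly permute the samples in a batch so that a subsequent
    "first k samples" split is not biased toward contiguous class blocks."""
    if batch_random_permutation:
        idx = tf.random.shuffle(tf.range(tf.shape(x)[0]))
        x = tf.gather(x, idx)
        y = tf.gather(y, idx)
    return x, y
```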
