-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
NoDuplicatesDataLoader.py
46 lines (36 loc) · 1.52 KB
/
NoDuplicatesDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from __future__ import annotations
import math
import random
class NoDuplicatesDataLoader:
    """
    A special data loader to be used with MultipleNegativesRankingLoss.

    The data loader ensures that there are no duplicate sentences within the
    same batch. Texts are compared after stripping surrounding whitespace and
    lower-casing, so ``"Hello "`` and ``"hello"`` count as duplicates.
    """

    def __init__(self, train_examples, batch_size):
        """
        Set up the loader and shuffle the examples.

        :param train_examples: Mutable sequence of examples; each example must
            expose a ``texts`` iterable of strings. NOTE: the sequence is
            shuffled *in place* (here and again on every wrap-around), so the
            caller's list order is modified.
        :param batch_size: Number of examples per batch.
        """
        self.batch_size = batch_size
        # Position of the next example to look at; persists across batches
        # and across iterations over the loader.
        self.data_pointer = 0
        # Optional callable applied to every batch before it is yielded.
        self.collate_fn = None
        self.train_examples = train_examples
        random.shuffle(self.train_examples)

    def __iter__(self):
        """
        Yield ``len(self)`` batches of ``batch_size`` examples each.

        Examples whose normalised texts clash with a text already in the
        current batch are skipped for that batch. When the end of the data is
        reached, the pointer wraps to the start and the data is reshuffled.

        :raises ValueError: if a batch cannot be completed because every
            remaining example clashes with the texts already in the batch.
        """
        for _ in range(self.__len__()):
            batch = []
            texts_in_batch = set()
            rejected_in_a_row = 0

            while len(batch) < self.batch_size:
                example = self.train_examples[self.data_pointer]
                # Normalise once and reuse for both the check and the insert.
                normalized_texts = [text.strip().lower() for text in example.texts]

                if texts_in_batch.isdisjoint(normalized_texts):
                    batch.append(example)
                    texts_in_batch.update(normalized_texts)
                    rejected_in_a_row = 0
                else:
                    rejected_in_a_row += 1
                    # 2 * N consecutive rejections always include one complete
                    # pass over the (reshuffled) data, so no example can extend
                    # this batch any more. Without this guard the loop would
                    # spin forever.
                    if rejected_in_a_row >= 2 * len(self.train_examples):
                        raise ValueError(
                            f"Unable to build a batch of size {self.batch_size} without "
                            f"duplicate texts: only {len(batch)} mutually distinct "
                            f"example(s) found among {len(self.train_examples)} example(s)."
                        )

                self.data_pointer += 1
                if self.data_pointer >= len(self.train_examples):
                    # Wrap around and reshuffle so the next pass sees a new order.
                    self.data_pointer = 0
                    random.shuffle(self.train_examples)

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        """Return the number of full batches per iteration (remainder is dropped)."""
        return math.floor(len(self.train_examples) / self.batch_size)