

Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.
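To make "minimal padding" concrete, here is a small sketch that measures how much of a padded batch is wasted. The helper names `padding_fraction` and `chunk` and the sentence lengths are made up for illustration; this is not torchtext code.

```python
def padding_fraction(batches):
    """Share of token slots that are padding when every batch
    is padded to the length of its longest sequence."""
    real = sum(sum(b) for b in batches)
    padded = sum(max(b) * len(b) for b in batches)
    return 1 - real / padded


def chunk(seq, size):
    """Split a list into consecutive pieces of at most `size` items."""
    return [seq[i:i + size] for i in range(0, len(seq), size)]


# Hypothetical sentence lengths, short and long sentences interleaved.
lengths = [5, 40, 6, 38, 7, 41, 4, 39]

unsorted_batches = chunk(lengths, 4)        # batches in arrival order
sorted_batches = chunk(sorted(lengths), 4)  # batches of similar lengths

print(round(padding_fraction(unsorted_batches), 4))  # 0.4444
print(round(padding_fraction(sorted_batches), 4))    # 0.0625
```

Grouping sentences of similar length cuts the padding here from about 44% to about 6%, which is why the iterator below sorts within a window of examples before batching.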






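Suppose your RAM can handle 1500 tokens per iteration and batch_size = 20. The compute is then fully used only when every sequence in the batch has length 1500 / 20 = 75.

In practice the sequence length clearly varies from batch to batch, so to use as much of the available compute as possible we need to adjust the current batch_size dynamically.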
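One way to picture a dynamic batch size is a batcher driven by a token budget. The sketch below is only an illustration of that idea, not the `batch_size_fn` used with torchtext; the function name `batch_by_token_budget` and its parameters are assumptions.

```python
def batch_by_token_budget(lengths, max_tokens):
    """Group sequence lengths into batches whose padded size
    (longest sequence * number of sequences) stays within max_tokens.
    A single sequence longer than max_tokens still gets its own batch."""
    batches, current, longest = [], [], 0
    for n in sorted(lengths):
        new_longest = max(longest, n)
        # Would adding this sequence push the padded batch over budget?
        if current and new_longest * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, new_longest = [], n
        current.append(n)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


# 20 sequences of length 75 and 20 sequences of length 150, budget 1500.
batches = batch_by_token_budget([75] * 20 + [150] * 20, max_tokens=1500)
sizes = [len(b) for b in batches]
print(sizes)  # [20, 10, 10]
```

With the same budget of 1500 tokens, sequences of length 75 fit 20 per batch, while sequences of length 150 fit only 10, so the number of examples per batch changes with the sequence length.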


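    from torchtext import data  # assumes the legacy torchtext API (data.Iterator, data.batch)


    class MyIterator(data.Iterator):
        def create_batches(self):
            if self.train:
                def pool(d, random_shuffler):
                    # Take chunks of 100 * batch_size examples at a time.
                    for p in data.batch(d, self.batch_size * 100):
                        # Sort each chunk so similar lengths share a batch.
                        p_batch = data.batch(
                            sorted(p, key=self.sort_key),
                            self.batch_size, self.batch_size_fn)
                        # Shuffle the batches of this chunk, then emit them.
                        for b in random_shuffler(list(p_batch)):
                            yield b
                # self.data() is the Iterator method that returns the examples.
                self.batches = pool(self.data(), self.random_shuffler)
            else:
                self.batches = []
                for b in data.batch(self.data(), self.batch_size,
                                    self.batch_size_fn):
                    self.batches.append(sorted(b, key=self.sort_key))


    def rebatch(pad_idx, batch):
        "Fix order in torchtext to match ours"
        # Batch is assumed to be the tutorial's own batch class, defined elsewhere.
        src, trg = batch.src.transpose(0, 1), batch.trg.transpose(0, 1)
        return Batch(src, trg, pad_idx)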


The pool function here works much like the pool function used by class BucketIterator(Iterator) in torchtext.

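1. Split the original data into chunks of 100 * batch_size examples => (each p in the loop above is one chunk)

2. Within each chunk, sort the examples by sort_key and split the sorted chunk into batches. This gives about 100 batches per chunk when batch_size counts examples; with a token-based batch_size_fn the number can differ =>

( p_batch = data.batch( sorted(p, key=self.sort_key), self.batch_size, self.batch_size_fn) )

3. Shuffle the batches within each chunk. The examples inside a batch stay together, so each batch still holds sequences of similar length => (random_shuffler(list(p_batch)))

4. Iterate over the shuffled batches of the chunk => (each b above is one batch)

5. The generator yields one batch at a time => (yield b)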
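The steps above can be sketched in plain Python without torchtext. This is a simplified illustration under stated assumptions: it counts examples rather than tokens (there is no batch_size_fn), and the names `chunked`, `chunk_factor` and `shuffle` are hypothetical.

```python
import random


def chunked(items, size):
    """Yield consecutive slices of at most `size` items."""
    for i in range(0, len(items), size):
        yield items[i:i + size]


def pool(examples, batch_size, sort_key=len,
         shuffle=random.shuffle, chunk_factor=100):
    # Step 1: walk over chunks of chunk_factor * batch_size examples.
    for chunk in chunked(examples, batch_size * chunk_factor):
        # Step 2: sort the chunk, then cut it into batches.
        batches = list(chunked(sorted(chunk, key=sort_key), batch_size))
        # Step 3: shuffle the order of the batches (in place).
        shuffle(batches)
        # Steps 4 and 5: emit the batches one at a time.
        for b in batches:
            yield b


# Toy "sentences" whose lengths are 9, 1, 8, 2, 7, 3, 6, 4.
sentences = [["w"] * n for n in [9, 1, 8, 2, 7, 3, 6, 4]]
batches = list(pool(sentences, batch_size=2, chunk_factor=2))
batch_lengths = [[len(s) for s in b] for b in batches]
print(batch_lengths)  # order of batches varies from run to run
```

Sorting happens only inside a chunk, so the batches are tight without requiring a global sort of the dataset, and shuffling the batches keeps the training order random.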

