fastNLP.core.sampler¶
sampler 子类实现了 fastNLP 所需的各种采样器。
-
class
fastNLP.core.sampler.
Sampler
[源代码]¶ 别名
fastNLP.Sampler
fastNLP.core.sampler.Sampler
Sampler 类的基类. 规定以何种顺序取出data中的元素
子类必须实现
__call__
方法. 输入 DataSet 对象, 返回其中元素的下标序列-
__init__
()¶ Initialize self. See help(type(self)) for accurate signature.
-
-
class
fastNLP.core.sampler.
BucketSampler
(num_buckets=10, batch_size=None, seq_len_field_name='seq_len')[源代码]¶ -
别名
fastNLP.BucketSampler
fastNLP.core.sampler.BucketSampler
带Bucket的 Random Sampler. 可以随机地取出长度相似的元素
-
class
fastNLP.core.sampler.
SequentialSampler
[源代码]¶ -
别名
fastNLP.SequentialSampler
fastNLP.core.sampler.SequentialSampler
顺序取出元素的 Sampler
-
__init__
()¶ Initialize self. See help(type(self)) for accurate signature.
-
-
class
fastNLP.core.sampler.
RandomSampler
[源代码]¶ -
别名
fastNLP.RandomSampler
fastNLP.core.sampler.RandomSampler
随机化取元素的 Sampler
-
__init__
()¶ Initialize self. See help(type(self)) for accurate signature.
-
-
class
fastNLP.core.sampler.
SortedSampler
(seq_len_field_name='seq_len', descending=True)[源代码]¶ -
别名
fastNLP.SortedSampler
fastNLP.core.sampler.SortedSampler
按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding)
-
class
fastNLP.core.sampler.
ConstantTokenNumSampler
(seq_len, max_token=4096, max_sentence=- 1, need_be_multiple_of=1, num_bucket=- 1)[源代码]¶ 别名
fastNLP.ConstantTokenNumSampler
fastNLP.core.sampler.ConstantTokenNumSampler
尽量保证每个batch的输入token数量是接近的。
使用示例 >>> # 假设已经有了tr_data并有一个field叫做seq_len保存了每个instance的token数量 >>> from fastNLP import DataSetIter, Trainer >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096) >>> tr_iter = DataSetIter(tr_data, >>> batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, >>> drop_last=False, timeout=0, worker_init_fn=None, >>> batch_sampler=sampler) >>> >>> # 直接将tr_iter传入Trainer中,此时batch_size参数的值会被忽略 >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(), >>> batch_size=1, sampler=None, drop_last=False, update_every=1)
-
__init__
(seq_len, max_token=4096, max_sentence=- 1, need_be_multiple_of=1, num_bucket=- 1)[源代码]¶ - 参数
seq_len (List[int]) -- list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入
max_token (int) -- 每个batch的最大的token数量
max_sentence (int) -- 每个batch最多多少个instance, -1表示根据max_token决定
need_be_multiple_of (int) -- 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到
num_bucket (int) -- 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。
-