fastNLP.core.sampler

sampler 子类实现了 fastNLP 所需的各种采样器。

class fastNLP.core.sampler.Sampler[源代码]

别名 fastNLP.Sampler fastNLP.core.sampler.Sampler

Sampler 类的基类. 规定以何种顺序取出data中的元素

子类必须实现 __call__ 方法. 输入 DataSet 对象, 返回其中元素的下标序列

__init__()

Initialize self. See help(type(self)) for accurate signature.

class fastNLP.core.sampler.BucketSampler(num_buckets=10, batch_size=None, seq_len_field_name='seq_len')[源代码]

基类 fastNLP.Sampler

别名 fastNLP.BucketSampler fastNLP.core.sampler.BucketSampler

带Bucket的 Random Sampler. 可以随机地取出长度相似的元素

__init__(num_buckets=10, batch_size=None, seq_len_field_name='seq_len')[源代码]
参数
  • num_buckets (int) -- bucket的数量

  • batch_size (int) -- batch的大小. 默认为None,Trainer/Tester在调用BucketSampler时,会将该值正确设置,如果是非 Trainer/Tester场景使用,需要显示传递该值

  • seq_len_field_name (str) -- 对应序列长度的 field 的名字

set_batch_size(batch_size)[源代码]
参数

batch_size (int) -- 每个batch的大小

返回

class fastNLP.core.sampler.SequentialSampler[源代码]

基类 fastNLP.Sampler

别名 fastNLP.SequentialSampler fastNLP.core.sampler.SequentialSampler

顺序取出元素的 Sampler

__init__()

Initialize self. See help(type(self)) for accurate signature.

class fastNLP.core.sampler.RandomSampler[源代码]

基类 fastNLP.Sampler

别名 fastNLP.RandomSampler fastNLP.core.sampler.RandomSampler

随机化取元素的 Sampler

__init__()

Initialize self. See help(type(self)) for accurate signature.

class fastNLP.core.sampler.SortedSampler(seq_len_field_name='seq_len', descending=True)[源代码]

基类 fastNLP.Sampler

别名 fastNLP.SortedSampler fastNLP.core.sampler.SortedSampler

按照sample的长度进行排序,主要在测试的时候使用,可以加速测试(因为减少了padding)

__init__(seq_len_field_name='seq_len', descending=True)[源代码]
参数
  • seq_len_field_name (str) -- 按哪个field进行排序。如果传入的field是数字,则直接按照该数字大小排序;如果传入的field不是 数字,则使用该field的长度进行排序

  • descending (bool) -- 是否降序排列

class fastNLP.core.sampler.ConstantTokenNumSampler(seq_len, max_token=4096, max_sentence=- 1, need_be_multiple_of=1, num_bucket=- 1)[源代码]

别名 fastNLP.ConstantTokenNumSampler fastNLP.core.sampler.ConstantTokenNumSampler

尽量保证每个batch的输入token数量是接近的。

使用示例 >>> # 假设已经有了tr_data并有一个field叫做seq_len保存了每个instance的token数量 >>> from fastNLP import DataSetIter, Trainer >>> sampler = ConstantTokenNumSampler(tr_data.get_field('seq_len').content, max_token=4096) >>> tr_iter = DataSetIter(tr_data, >>> batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, >>> drop_last=False, timeout=0, worker_init_fn=None, >>> batch_sampler=sampler) >>> >>> # 直接将tr_iter传入Trainer中,此时batch_size参数的值会被忽略 >>> trainer = Trainer(tr_iter, model, optimizer=optimizer, loss=TranslationLoss(), >>> batch_size=1, sampler=None, drop_last=False, update_every=1)

__init__(seq_len, max_token=4096, max_sentence=- 1, need_be_multiple_of=1, num_bucket=- 1)[源代码]
参数
  • seq_len (List[int]) -- list[int], 是每个sample的长度。一般可以通过dataset.get_field('seq_len').content传入

  • max_token (int) -- 每个batch的最大的token数量

  • max_sentence (int) -- 每个batch最多多少个instance, -1表示根据max_token决定

  • need_be_multiple_of (int) -- 生成的batch的instance的数量需要是几的倍数,在DataParallel场景下会用到

  • num_bucket (int) -- 将数据按长度拆分为num_bucket个bucket,batch中的sample尽量在bucket之中进行组合,这样可以减少padding。