Source code for fastNLP.io.pipe.qa

r"""
本文件中的Pipe主要用于处理问答任务的数据。

"""


from copy import deepcopy

from .pipe import Pipe
from .. import DataBundle
from ..loader.qa import CMRC2018Loader
from .utils import get_tokenizer
from ...core import DataSet
from ...core import Vocabulary

__all__ = ['CMRC2018BertPipe']


def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'):
    r"""
    处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。

    会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start
    与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。

    :param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ]
    :return:
    """
    tokenizer = get_tokenizer('cn-char', lang='cn')  # character-level Chinese tokenizer
    for name in list(data_bundle.datasets.keys()):
        ds = data_bundle.get_dataset(name)
        data_bundle.delete_dataset(name)
        new_ds = DataSet()
        for ins in ds:
            new_ins = deepcopy(ins)
            context = ins['context']
            question = ins['question']

            cnt_lst = tokenizer(context)
            q_lst = tokenizer(question)

            answer_start = -1  # sentinel: no gold answer for this instance

            if len(cnt_lst) + len(q_lst) + 3 > max_len:  # reserve room for the leading [CLS], the middle [SEP] and the trailing [SEP]
                if 'answer_starts' in ins and 'answers' in ins:
                    answer_start = int(ins['answer_starts'][0])
                    answer = ins['answers'][0]
                    answer_end = answer_start + len(answer)
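                    # Clip the context to a window that still contains the answer: if
                    # keeping only the first (max_len - 3 - len(q_lst)) characters would
                    # cut the answer off, slide the window right so it ends at answer_end.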
                    if answer_end > max_len - 3 - len(q_lst):
                        span_start = answer_end + 3 + len(q_lst) - max_len
                        span_end = answer_end
                    else:
                        span_start = 0
                        span_end = max_len - 3 - len(q_lst)
                    cnt_lst = cnt_lst[span_start:span_end]
                    answer_start -= span_start
                    answer_end = answer_start + len(answer)
                else:
                    cnt_lst = cnt_lst[:max_len - len(q_lst) - 3]
            else:
                if 'answer_starts' in ins and 'answers' in ins:
                    answer_start = int(ins['answer_starts'][0])
                    answer_end = answer_start + len(ins['answers'][0])

            # Join context and question with a single [SEP]; [CLS] and the final [SEP]
            # (reserved in the length check above) are expected to be added downstream.
            tokens = cnt_lst + ['[SEP]'] + q_lst
            new_ins['context_len'] = len(cnt_lst)
            new_ins[concat_field_name] = tokens

            if answer_start != -1:
                new_ins['target_start'] = answer_start
                new_ins['target_end'] = answer_end - 1

            new_ds.append(new_ins)
        data_bundle.set_dataset(new_ds, name)

    return data_bundle
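
# A minimal illustration of what ``_concat_clip`` produces for one instance. The toy
# data below is made up; the sketch assumes fastNLP's public DataSet / Instance /
# DataBundle APIs:
#
#     from fastNLP import DataSet, Instance
#     from fastNLP.io import DataBundle
#
#     ds = DataSet()
#     ds.append(Instance(context='小明在北京上学', question='小明在哪里上学',
#                        answers=['北京'], answer_starts=['3'], id='toy_0'))
#     bundle = _concat_clip(DataBundle(datasets={'train': ds}), max_len=510)
#     ins = bundle.get_dataset('train')[0]
#     # ins['raw_chars'] == list('小明在北京上学') + ['[SEP]'] + list('小明在哪里上学')
#     # ins['target_start'] == 3 and ins['target_end'] == 4 (closed interval over '北京')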


class CMRC2018BertPipe(Pipe):
    r"""
    The processed DataSet gains the following fields (the fields passed in are kept):

    .. csv-table::
        :header: "context_len", "raw_chars", "target_start", "target_end", "chars"

        492, "['范', '廷', '颂', ...]", 30, 34, "[21, 25, ...]"
        491, "['范', '廷', '颂', ...]", 41, 61, "[21, 25, ...]"
        "...", "...", "...", "...", "..."

    The raw_chars column is the concatenation of context and question (a [SEP] token is
    inserted at the join), and chars is the same sequence converted to indices;
    target_start is the index of the answer's start and target_end the index of its end
    (closed interval); context_len is the length of the context part of the chars column.

    The meta information of each column is as follows:

    .. code::

        +-------------+-------------+-----------+--------------+------------+-------+---------+
        | field_names | context_len | raw_chars | target_start | target_end | chars | answers |
        +-------------+-------------+-----------+--------------+------------+-------+---------+
        |  is_input   |    False    |   False   |    False     |   False    | True  |  False  |
        |  is_target  |    True     |   True    |     True     |    True    | False |  True   |
        | ignore_type |    False    |   True    |    False     |   False    | False |  True   |
        |  pad_value  |      0      |     0     |      0       |     0      |   0   |    0    |
        +-------------+-------------+-----------+--------------+------------+-------+---------+

    """

    def __init__(self, max_len=510):
        super().__init__()
        self.max_len = max_len
    def process(self, data_bundle: DataBundle) -> DataBundle:
        r"""
        The DataSet passed in should have the following fields:

        .. csv-table::
            :header: "title", "context", "question", "answers", "answer_starts", "id"

            "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0"
            "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"], "TRAIN_186_QUERY_1"
            "...", "...", "...", "...", "...", "..."

        :param data_bundle: the DataBundle to process
        :return: the processed DataBundle
        """
        # Tokenize, clip and concatenate context/question into the raw_chars field.
        data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars')

        # Build the character vocabulary on the training split(s); the remaining
        # splits are passed as no_create_entry_dataset.
        src_vocab = Vocabulary()
        src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
                               field_name='raw_chars',
                               no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
                                                        if 'train' not in name])
        src_vocab.index_dataset(*data_bundle.datasets.values(), field_name='raw_chars', new_field_name='chars')
        data_bundle.set_vocab(src_vocab, 'chars')

        data_bundle.set_ignore_type('raw_chars', 'answers', flag=True)
        data_bundle.set_input('chars')
        data_bundle.set_target('raw_chars', 'answers', 'target_start', 'target_end', 'context_len')

        return data_bundle
    def process_from_file(self, paths=None) -> DataBundle:
        data_bundle = CMRC2018Loader().load(paths)
        return self.process(data_bundle)
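

if __name__ == '__main__':
    # Minimal usage sketch. This assumes the CMRC2018 data can be resolved by
    # CMRC2018Loader when ``paths`` is None (e.g. via fastNLP's download cache);
    # the resulting fields follow the class docstring above.
    pipe = CMRC2018BertPipe(max_len=510)
    bundle = pipe.process_from_file()
    print(bundle)  # summary of the datasets and the 'chars' vocabulary
    train_ds = bundle.get_dataset('train')
    print(train_ds[0]['raw_chars'][:10], train_ds[0]['target_start'])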