Source code for fastNLP.io.pipe.matching

r"""undocumented"""

__all__ = [
    "MatchingBertPipe",
    "RTEBertPipe",
    "SNLIBertPipe",
    "QuoraBertPipe",
    "QNLIBertPipe",
    "MNLIBertPipe",
    "CNXNLIBertPipe",
    "BQCorpusBertPipe",
    "LCQMCBertPipe",
    "MatchingPipe",
    "RTEPipe",
    "SNLIPipe",
    "QuoraPipe",
    "QNLIPipe",
    "MNLIPipe",
    "LCQMCPipe",
    "CNXNLIPipe",
    "BQCorpusPipe",
    "RenamePipe",
    "GranularizePipe",
    "MachingTruncatePipe",
]

import warnings

from .pipe import Pipe
from .utils import get_tokenizer
from ..data_bundle import DataBundle
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \
    LCQMCLoader
from ...core._logger import logger
from ...core.const import Const
from ...core.vocabulary import Vocabulary


class MatchingBertPipe(Pipe):
    r"""
    Bert pipe for Matching tasks. The output DataSet will contain the following fields:

    .. csv-table::
       :header: "raw_words1", "raw_words2", "target", "words", "seq_len"

       "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10
       "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5
       "...", "...", ., "[...]", .

    The words field is obtained by joining raw_words1 (the premise) and raw_words2 (the
    hypothesis) with "[SEP]" and converting the result to indices. The words field is set as
    input; the target field is set as both target and input (it is set as input so that the
    loss can be computed inside the forward function; if the loss is not computed there,
    nothing breaks, since fastNLP passes arguments by matching the forward function's
    parameter names).

    The input/target status of each field, as printed by the dataset's print_field_meta()
    function, is::

        +-------------+------------+------------+--------+-------+---------+
        | field_names | raw_words1 | raw_words2 | target | words | seq_len |
        +-------------+------------+------------+--------+-------+---------+
        |   is_input  |   False    |   False    | False  | True  |  True   |
        |  is_target  |   False    |   False    |  True  | False |  False  |
        | ignore_type |            |            | False  | False |  False  |
        |  pad_value  |            |            |   0    |   0   |    0    |
        +-------------+------------+------------+--------+-------+---------+

    """
    def __init__(self, lower=False, tokenizer: str = 'raw'):
        r"""
        :param bool lower: whether to lowercase the words.
        :param str tokenizer: tokenizer used to split sentences into words. Supports
            spacy and raw; raw splits on whitespace.
        """
        super().__init__()

        self.lower = bool(lower)
        self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
    def _tokenize(self, data_bundle, field_names, new_field_names):
        r"""
        :param DataBundle data_bundle: DataBundle.
        :param list field_names: List[str], names of the fields to tokenize.
        :param list new_field_names: List[str], names of the tokenized fields; matched
            one-to-one with field_names.
        :return: the input DataBundle object
        """
        for name, dataset in data_bundle.datasets.items():
            for field_name, new_field_name in zip(field_names, new_field_names):
                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                    new_field_name=new_field_name)
        return data_bundle
    def process(self, data_bundle):
        r"""
        The datasets in the input data_bundle must have the following structure:

        .. csv-table::
           :header: "raw_words1", "raw_words2", "target"

           "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment"
           "...","..."

        :param data_bundle:
        :return:
        """
        for dataset in data_bundle.datasets.values():
            if dataset.has_field(Const.TARGET):
                dataset.drop(lambda x: x[Const.TARGET] == '-')

        for name, dataset in data_bundle.datasets.items():
            dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
            dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))

        if self.lower:
            for name, dataset in data_bundle.datasets.items():
                dataset[Const.INPUTS(0)].lower()
                dataset[Const.INPUTS(1)].lower()

        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                     [Const.INPUTS(0), Const.INPUTS(1)])

        # concatenate the two tokenized fields
        def concat(ins):
            words0 = ins[Const.INPUTS(0)]
            words1 = ins[Const.INPUTS(1)]
            words = words0 + ['[SEP]'] + words1
            return words

        for name, dataset in data_bundle.datasets.items():
            dataset.apply(concat, new_field_name=Const.INPUT)
            dataset.delete_field(Const.INPUTS(0))
            dataset.delete_field(Const.INPUTS(1))

        word_vocab = Vocabulary()
        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
                                field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items()
                                                         if 'train' not in name])
        word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)

        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
                                  field_name=Const.TARGET,
                                  no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
                                                           if ('train' not in name) and (ds.has_field(Const.TARGET))]
                                  )
        if len(target_vocab._no_create_word) > 0:
            warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                       f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                       f"data set but not in train data set!"
            warnings.warn(warn_msg)
            logger.warning(warn_msg)

        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items()
                               if dataset.has_field(Const.TARGET)]

        target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)

        data_bundle.set_vocab(word_vocab, Const.INPUT)
        data_bundle.set_vocab(target_vocab, Const.TARGET)

        input_fields = [Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET]

        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.INPUT)
            dataset.set_input(*input_fields, flag=True)
            for fields in target_fields:
                if dataset.has_field(fields):
                    dataset.set_target(fields, flag=True)

        return data_bundle
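
# A minimal usage sketch of the pipe above (toy sentences, illustrative only): build a
# two-instance DataBundle by hand, run MatchingBertPipe over it, and observe that the
# premise and hypothesis end up joined by "[SEP]" in a single indexed `words` field.
def _demo_matching_bert_pipe():
    from fastNLP import DataSet
    ds = DataSet({'raw_words1': ['The cat sat on the mat .', 'It is raining .'],
                  'raw_words2': ['A cat is sitting .', 'The weather is dry .'],
                  'target': ['entailment', 'not_entailment']})
    data_bundle = DataBundle(datasets={'train': ds})
    data_bundle = MatchingBertPipe(lower=True, tokenizer='raw').process(data_bundle)
    print(data_bundle.get_dataset('train'))  # words holds indices; seq_len is the joined length
    print(data_bundle.get_vocab('words').to_index('[SEP]'))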

class RTEBertPipe(MatchingBertPipe):
    def process_from_file(self, paths=None):
        data_bundle = RTELoader().load(paths)
        return self.process(data_bundle)


class SNLIBertPipe(MatchingBertPipe):
    def process_from_file(self, paths=None):
        data_bundle = SNLILoader().load(paths)
        return self.process(data_bundle)


class QuoraBertPipe(MatchingBertPipe):
    def process_from_file(self, paths):
        data_bundle = QuoraLoader().load(paths)
        return self.process(data_bundle)


class QNLIBertPipe(MatchingBertPipe):
    def process_from_file(self, paths=None):
        data_bundle = QNLILoader().load(paths)
        return self.process(data_bundle)


class MNLIBertPipe(MatchingBertPipe):
    def process_from_file(self, paths=None):
        data_bundle = MNLILoader().load(paths)
        return self.process(data_bundle)
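
# The concrete pipes above are normally driven through `process_from_file`, which chains
# the matching loader with `process`. A sketch (with paths=None the loader attempts to
# download the dataset; '/path/to/RTE' below is a placeholder):
#
#     data_bundle = RTEBertPipe(lower=True).process_from_file('/path/to/RTE')
#     print(data_bundle)  # prints the loaded splits and the vocabularies that were built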

class MatchingPipe(Pipe):
    r"""
    Pipe for Matching tasks. The output DataSet will contain the following fields:

    .. csv-table::
       :header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2"

       "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13
       "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7
       "...", "...", ., "[...]", "[...]", ., .

    words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and
    seq_len2 are set as input; target is set as both target and input (it is set as input
    so that the loss can be computed inside the forward function; if the loss is not
    computed there, nothing breaks, since fastNLP passes arguments by matching the forward
    function's parameter names).

    The input/target status of each field, as printed by the dataset's print_field_meta()
    function, is::

        +-------------+------------+------------+--------+--------+--------+----------+----------+
        | field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 |
        +-------------+------------+------------+--------+--------+--------+----------+----------+
        |   is_input  |   False    |   False    | False  |  True  |  True  |   True   |   True   |
        |  is_target  |   False    |   False    |  True  | False  | False  |  False   |  False   |
        | ignore_type |            |            | False  | False  | False  |  False   |  False   |
        |  pad_value  |            |            |   0    |   0    |   0    |    0     |    0     |
        +-------------+------------+------------+--------+--------+--------+----------+----------+

    """
    def __init__(self, lower=False, tokenizer: str = 'raw'):
        r"""
        :param bool lower: whether to lowercase all raw_words.
        :param str tokenizer: how to tokenize the raw data. Supports spacy and raw;
            spacy tokenizes with spacy, raw splits on whitespace.
        """
        super().__init__()

        self.lower = bool(lower)
        self.tokenizer = get_tokenizer(tokenize_method=tokenizer)
    def _tokenize(self, data_bundle, field_names, new_field_names):
        r"""
        :param ~fastNLP.DataBundle data_bundle: DataBundle.
        :param list field_names: List[str], names of the fields to tokenize.
        :param list new_field_names: List[str], names of the tokenized fields; matched
            one-to-one with field_names.
        :return: the input DataBundle object
        """
        for name, dataset in data_bundle.datasets.items():
            for field_name, new_field_name in zip(field_names, new_field_names):
                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                    new_field_name=new_field_name)
        return data_bundle
    def process(self, data_bundle):
        r"""
        The DataSets in the accepted DataBundle should have the following fields; the
        target column may be absent:

        .. csv-table::
           :header: "raw_words1", "raw_words2", "target"

           "The new rights are...", "Everyone really likes..", "entailment"
           "This site includes a...", "The Government Executive...", "not_entailment"
           "...", "..."

        :param ~fastNLP.DataBundle data_bundle: the data_bundle produced by a loader,
            containing the raw content of the datasets
        :return: data_bundle
        """
        data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
                                     [Const.INPUTS(0), Const.INPUTS(1)])

        for dataset in data_bundle.datasets.values():
            if dataset.has_field(Const.TARGET):
                dataset.drop(lambda x: x[Const.TARGET] == '-')

        if self.lower:
            for name, dataset in data_bundle.datasets.items():
                dataset[Const.INPUTS(0)].lower()
                dataset[Const.INPUTS(1)].lower()

        word_vocab = Vocabulary()
        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
                                field_name=[Const.INPUTS(0), Const.INPUTS(1)],
                                no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items()
                                                         if 'train' not in name])
        word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])

        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
                                  field_name=Const.TARGET,
                                  no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
                                                           if ('train' not in name) and (ds.has_field(Const.TARGET))]
                                  )
        if len(target_vocab._no_create_word) > 0:
            warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                       f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                       f"data set but not in train data set!"
            warnings.warn(warn_msg)
            logger.warning(warn_msg)

        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items()
                               if dataset.has_field(Const.TARGET)]

        target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)

        data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
        data_bundle.set_vocab(target_vocab, Const.TARGET)

        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
        target_fields = [Const.TARGET]

        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
            dataset.set_input(*input_fields, flag=True)
            for fields in target_fields:
                if dataset.has_field(fields):
                    dataset.set_target(fields, flag=True)

        return data_bundle
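
# A minimal sketch of the non-Bert pipe (toy data, illustrative only). Unlike
# MatchingBertPipe, the two sentences stay in separate `words1`/`words2` fields, each
# with its own seq_len field.
def _demo_matching_pipe():
    from fastNLP import DataSet
    ds = DataSet({'raw_words1': ['The cat sat .'],
                  'raw_words2': ['A cat sits .'],
                  'target': ['entailment']})
    data_bundle = MatchingPipe(lower=True).process(DataBundle(datasets={'train': ds}))
    train = data_bundle.get_dataset('train')
    print(train[0]['words1'], train[0]['words2'])  # two separately indexed sequences
    print(train[0]['seq_len1'], train[0]['seq_len2'])  # 4 4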

class RTEPipe(MatchingPipe):
    def process_from_file(self, paths=None):
        data_bundle = RTELoader().load(paths)
        return self.process(data_bundle)


class SNLIPipe(MatchingPipe):
    def process_from_file(self, paths=None):
        data_bundle = SNLILoader().load(paths)
        return self.process(data_bundle)


class QuoraPipe(MatchingPipe):
    def process_from_file(self, paths):
        data_bundle = QuoraLoader().load(paths)
        return self.process(data_bundle)


class QNLIPipe(MatchingPipe):
    def process_from_file(self, paths=None):
        data_bundle = QNLILoader().load(paths)
        return self.process(data_bundle)


class MNLIPipe(MatchingPipe):
    def process_from_file(self, paths=None):
        data_bundle = MNLILoader().load(paths)
        return self.process(data_bundle)


class LCQMCPipe(MatchingPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = LCQMCLoader().load(paths)
        data_bundle = RenamePipe().process(data_bundle)
        data_bundle = self.process(data_bundle)
        data_bundle = RenamePipe().process(data_bundle)
        return data_bundle


class CNXNLIPipe(MatchingPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = CNXNLILoader().load(paths)
        data_bundle = GranularizePipe(task='XNLI').process(data_bundle)
        data_bundle = RenamePipe().process(data_bundle)  # rename the fields of the Chinese data
        data_bundle = self.process(data_bundle)
        data_bundle = RenamePipe().process(data_bundle)
        return data_bundle


class BQCorpusPipe(MatchingPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = BQCorpusLoader().load(paths)
        data_bundle = RenamePipe().process(data_bundle)
        data_bundle = self.process(data_bundle)
        data_bundle = RenamePipe().process(data_bundle)
        return data_bundle
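
# The Chinese pipes above lean on RenamePipe (defined below) twice: before `process`, the
# loaders' raw_chars1/raw_chars2 fields are mapped onto raw_words1/raw_words2 so that
# MatchingPipe.process can run unchanged; afterwards the processed fields are mapped back
# to the chars1/chars2 naming convention. A toy round trip (illustrative only):
def _demo_rename_pipe():
    from fastNLP import DataSet
    ds = DataSet({'raw_chars1': ['今天天气很好'], 'raw_chars2': ['天气不错'], 'target': ['1']})
    data_bundle = DataBundle(datasets={'train': ds})
    RenamePipe(task='cn-nli').process(data_bundle)
    print(data_bundle.get_dataset('train').has_field('raw_words1'))  # True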

class RenamePipe(Pipe):
    def __init__(self, task='cn-nli'):
        super().__init__()
        self.task = task

    def process(self, data_bundle: DataBundle):  # rename field names for the Chinese Matching datasets
        if self.task == 'cn-nli':
            for name, dataset in data_bundle.datasets.items():
                if dataset.has_field(Const.RAW_CHARS(0)):
                    dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0))  # RAW_CHARS->RAW_WORDS
                    dataset.rename_field(Const.RAW_CHARS(1), Const.RAW_WORDS(1))
                elif dataset.has_field(Const.INPUTS(0)):
                    dataset.rename_field(Const.INPUTS(0), Const.CHAR_INPUTS(0))  # WORDS->CHARS
                    dataset.rename_field(Const.INPUTS(1), Const.CHAR_INPUTS(1))
                    dataset.rename_field(Const.RAW_WORDS(0), Const.RAW_CHARS(0))
                    dataset.rename_field(Const.RAW_WORDS(1), Const.RAW_CHARS(1))
                else:
                    raise RuntimeError(
                        "field name of dataset is not qualified. It should have either RAW_CHARS or WORDS.")
        elif self.task == 'cn-nli-bert':
            for name, dataset in data_bundle.datasets.items():
                if dataset.has_field(Const.RAW_CHARS(0)):
                    dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0))  # RAW_CHARS->RAW_WORDS
                    dataset.rename_field(Const.RAW_CHARS(1), Const.RAW_WORDS(1))
                elif dataset.has_field(Const.RAW_WORDS(0)):
                    dataset.rename_field(Const.RAW_WORDS(0), Const.RAW_CHARS(0))
                    dataset.rename_field(Const.RAW_WORDS(1), Const.RAW_CHARS(1))
                    dataset.rename_field(Const.INPUT, Const.CHAR_INPUT)
                else:
                    raise RuntimeError(
                        "field name of dataset is not qualified. It should have either RAW_CHARS or RAW_WORDS.")
        else:
            raise RuntimeError("Only support task='cn-nli' or 'cn-nli-bert'.")

        return data_bundle


class GranularizePipe(Pipe):
    def __init__(self, task=None):
        super().__init__()
        self.task = task

    def _granularize(self, data_bundle, tag_map):
        r"""
        Remap the contents of the 'target' column of every dataset in data_bundle.

        :param data_bundle:
        :param dict tag_map: mapping applied to the tags in the target column. For
            example, {"0":0, "1":0, "3":1, "4":1} drops every instance whose target is
            "2" and treats "1" as class 0.
        :return: the data_bundle passed in
        """
        for name in list(data_bundle.datasets.keys()):
            dataset = data_bundle.get_dataset(name)
            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
                                new_field_name=Const.TARGET)
            dataset.drop(lambda ins: ins[Const.TARGET] == -100)
            data_bundle.set_dataset(dataset, name)
        return data_bundle

    def process(self, data_bundle: DataBundle):
        task_tag_dict = {
            'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2}
        }
        if self.task in task_tag_dict:
            data_bundle = self._granularize(data_bundle=data_bundle, tag_map=task_tag_dict[self.task])
        else:
            raise RuntimeError(f"Only support tasks in {list(task_tag_dict.keys())}.")
        return data_bundle


class MachingTruncatePipe(Pipe):  # truncate sentences for Bert and update seq_len
    def __init__(self):
        super().__init__()

    def process(self, data_bundle: DataBundle):
        # no-op placeholder: no truncation is implemented yet; see TruncateBertPipe below
        for name, dataset in data_bundle.datasets.items():
            pass
        return data_bundle


class LCQMCBertPipe(MatchingBertPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = LCQMCLoader().load(paths)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        data_bundle = self.process(data_bundle)
        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        return data_bundle


class BQCorpusBertPipe(MatchingBertPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = BQCorpusLoader().load(paths)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        data_bundle = self.process(data_bundle)
        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        return data_bundle


class CNXNLIBertPipe(MatchingBertPipe):
    def __init__(self, tokenizer='cn-char'):
        super().__init__(tokenizer=tokenizer)

    def process_from_file(self, paths=None):
        data_bundle = CNXNLILoader().load(paths)
        data_bundle = GranularizePipe(task='XNLI').process(data_bundle)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        data_bundle = self.process(data_bundle)
        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
        return data_bundle
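
# GranularizePipe (defined above) maps string labels to integer ids and silently drops
# any instance whose label is missing from the map. A toy run with the XNLI tag map
# (illustrative only; 'unknown_label' is a made-up tag):
def _demo_granularize_pipe():
    from fastNLP import DataSet
    ds = DataSet({'target': ['neutral', 'entailment', 'contradictory', 'unknown_label']})
    data_bundle = GranularizePipe(task='XNLI').process(DataBundle(datasets={'train': ds}))
    print(len(data_bundle.get_dataset('train')))  # 3: the 'unknown_label' instance was dropped
    print(data_bundle.get_dataset('train')[0]['target'])  # 0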

class TruncateBertPipe(Pipe):
    def __init__(self, task='cn'):
        super().__init__()
        self.task = task

    def _truncate(self, sentence_index: list, sep_index_vocab):
        # locate [SEP] inside the dataset's field['words'], given its index in the vocabulary
        sep_index_words = sentence_index.index(sep_index_vocab)
        words_before_sep = sentence_index[:sep_index_words]
        words_after_sep = sentence_index[sep_index_words:]  # note: this part includes [SEP]
        if self.task == 'cn':
            # for Chinese tasks, truncate the text before and after [SEP] in Instance['words']
            # to at most 250 indices each
            words_before_sep = words_before_sep[:250]
            words_after_sep = words_after_sep[:250]
        elif self.task == 'en':
            # for English tasks, truncate the text before and after [SEP] in Instance['words']
            # to at most 215 indices each
            words_before_sep = words_before_sep[:215]
            words_after_sep = words_after_sep[:215]
        else:
            raise RuntimeError("Only support 'cn' or 'en' task.")

        return words_before_sep + words_after_sep

    def process(self, data_bundle: DataBundle) -> DataBundle:
        for name in data_bundle.datasets.keys():
            dataset = data_bundle.get_dataset(name)
            sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]')
            dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index,
                                                                  sep_index_vocab=sep_index_vocab),
                                field_name='words', new_field_name='words')

            # seq_len must be recomputed after truncation
            dataset.add_seq_len(field_name='words')

        return data_bundle
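
# TruncateBertPipe keeps at most 250 (cn) or 215 (en) indices on each side of [SEP] so
# that the joined sequence fits a Bert encoder. A sketch of the truncation rule on a raw
# index list (illustrative only; 10000 stands in for the vocabulary index of '[SEP]'):
def _demo_truncate_rule():
    pipe = TruncateBertPipe(task='cn')
    sentence = list(range(300)) + [10000] + list(range(300, 600))
    truncated = pipe._truncate(sentence_index=sentence, sep_index_vocab=10000)
    print(len(truncated))  # 500: 250 indices before [SEP], plus [SEP] and 249 after it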