r"""undocumented"""
__all__ = [
"CWSPipe"
]
import re
from itertools import chain
from .pipe import Pipe
from .utils import _indexize
from .. import DataBundle
from ..loader import CWSLoader
from ...core.const import Const
def _word_lens_to_bmes(word_lens):
r"""
:param list word_lens: List[int], 每个词语的长度
:return: List[str], BMES的序列
"""
tags = []
for word_len in word_lens:
if word_len == 1:
tags.append('S')
else:
tags.append('B')
tags.extend(['M'] * (word_len - 2))
tags.append('E')
return tags
def _word_lens_to_segapp(word_lens):
r"""
:param list word_lens: List[int], 每个词语的长度
:return: List[str], BMES的序列
"""
tags = []
for word_len in word_lens:
if word_len == 1:
tags.append('SEG')
else:
tags.extend(['APP'] * (word_len - 1))
tags.append('SEG')
return tags
def _alpha_span_to_special_tag(span):
r"""
将span替换成特殊的字符
:param str span:
:return:
"""
if 'oo' == span.lower(): # speical case when represent 2OO8
return span
if len(span) == 1:
return span
else:
return '<ENG>'
def _find_and_replace_alpha_spans(line):
r"""
传入原始句子,替换其中的字母为特殊标记
:param str line:原始数据
:return: str
"""
new_line = ''
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])'
prev_end = 0
for match in re.finditer(pattern, line):
start, end = match.span()
span = line[start:end]
new_line += line[prev_end:start] + _alpha_span_to_special_tag(span)
prev_end = end
new_line += line[prev_end:]
return new_line
def _digit_span_to_special_tag(span):
r"""
:param str span: 需要替换的str
:return:
"""
if span[0] == '0' and len(span) > 2:
return '<NUM>'
decimal_point_count = 0 # one might have more than one decimal pointers
for idx, char in enumerate(span):
if char == '.' or char == '﹒' or char == '·':
decimal_point_count += 1
if span[-1] == '.' or span[-1] == '﹒' or span[
-1] == '·': # last digit being decimal point means this is not a number
if decimal_point_count == 1:
return span
else:
return '<UNKDGT>'
if decimal_point_count == 1:
return '<DEC>'
elif decimal_point_count > 1:
return '<UNKDGT>'
else:
return '<NUM>'
def _find_and_replace_digit_spans(line):
r"""
only consider words start with number, contains '.', characters.
If ends with space, will be processed
If ends with Chinese character, will be processed
If ends with or contains english char, not handled.
floats are replaced by <DEC>
otherwise unkdgt
"""
new_line = ''
pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])'
prev_end = 0
for match in re.finditer(pattern, line):
start, end = match.span()
span = line[start:end]
new_line += line[prev_end:start] + _digit_span_to_special_tag(span)
prev_end = end
new_line += line[prev_end:]
return new_line
class CWSPipe(Pipe):
    r"""
    Pre-process Chinese word segmentation (CWS) data. After processing, each
    dataset has the following structure

    .. csv-table::
       :header: "raw_words", "chars", "target", "seq_len"

       "共同  创造  美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13
       "2001年  新年  钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20
       "...", "[...]","[...]", .

    The dataset's print_field_meta() reports the following input/target
    settings for each field::

        +-------------+-----------+-------+--------+---------+
        | field_names | raw_words | chars | target | seq_len |
        +-------------+-----------+-------+--------+---------+
        |   is_input  |   False   |  True |  True  |   True  |
        |  is_target  |   False   | False |  True  |   True  |
        | ignore_type |           | False | False  |  False  |
        |  pad_value  |           |   0   |   0    |    0    |
        +-------------+-----------+-------+--------+---------+

    """

    def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
        r"""
        :param str,None dataset_name: one of 'pku', 'msra', 'cityu', 'as' or None
        :param str encoding_type: either 'bmes' or 'segapp' (case-insensitive). For "我 来自 复旦大学...",
            the bmes tags are [S, B, E, B, M, M, E...] and the segapp tags are
            [seg, app, seg, app, app, app, seg, ...]
        :param bool replace_num_alpha: whether to replace digits and letters with special tags
        :param bool bigrams: whether to add a bigram column, built as
            ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]
        :param bool trigrams: whether to add a trigram column, built as
            ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]
        :raises ValueError: if ``encoding_type`` is neither 'bmes' nor 'segapp'
        """
        taggers = {'bmes': _word_lens_to_bmes, 'segapp': _word_lens_to_segapp}
        normalized_encoding = str(encoding_type).lower()
        if normalized_encoding not in taggers:
            # Any value other than 'bmes' used to select segapp silently, so a
            # typo or a different letter case produced the wrong tag scheme.
            raise ValueError(
                "Unsupported encoding_type {!r}; expected 'bmes' or 'segapp'.".format(encoding_type))
        self.word_lens_to_tags = taggers[normalized_encoding]

        self.dataset_name = dataset_name
        self.bigrams = bigrams
        self.trigrams = trigrams
        self.replace_num_alpha = replace_num_alpha

    def _tokenize(self, data_bundle):
        r"""
        Split the 'chars' column of every dataset in ``data_bundle`` into words
        made of single tokens, e.g. "共同  创造  美好.."->[[共, 同], [创, 造], [...], ].

        A ``<...>`` group inside a word (such as the special tags produced by
        the number/letter replacement) is kept as one token.

        :param data_bundle: the bundle whose datasets are modified in place
        :return: the same ``data_bundle``
        """

        def split_word_into_chars(raw_chars):
            words = raw_chars.split()
            chars = []
            for word in words:
                char = []
                # Characters collected since the last unmatched '<'.
                subchar = []
                for c in word:
                    if c == '<':
                        # A new '<' before the previous one was closed: the
                        # pending characters were not a tag, emit them singly.
                        if subchar:
                            char.extend(subchar)
                            subchar = []
                        subchar.append(c)
                        continue
                    if c == '>' and len(subchar) > 0 and subchar[0] == '<':
                        # Closing bracket: emit the whole '<...>' as one token.
                        subchar.append(c)
                        char.append(''.join(subchar))
                        subchar = []
                        continue
                    if subchar:
                        subchar.append(c)
                    else:
                        char.append(c)
                # An unclosed '<...' at the end of the word is emitted singly.
                char.extend(subchar)
                chars.append(char)
            return chars

        for _, dataset in data_bundle.datasets.items():
            dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)
        return data_bundle

    def process(self, data_bundle: DataBundle) -> DataBundle:
        r"""
        Process datasets that contain a raw_words column

        .. csv-table::
           :header: "raw_words"

           "上海 浦东 开发 与 法制 建设 同步"
           "新华社 上海 二月 十日 电 （ 记者 谢金虎 、 张持坚 ）"
           "..."

        :param data_bundle: the bundle to process; it is modified in place
        :return: the same ``data_bundle``
        """
        data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)

        if self.replace_num_alpha:
            data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
            data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)

        self._tokenize(data_bundle)

        for _, dataset in data_bundle.datasets.items():
            # Tags are derived from the per-word token lists, so this has to
            # run before the lists are flattened below.
            dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.TARGET)
            dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
                                new_field_name=Const.CHAR_INPUT)

        input_field_names = [Const.CHAR_INPUT]
        if self.bigrams:
            for _, dataset in data_bundle.datasets.items():
                dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
                                    field_name=Const.CHAR_INPUT, new_field_name='bigrams')
            input_field_names.append('bigrams')
        if self.trigrams:
            for _, dataset in data_bundle.datasets.items():
                dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
                                                   zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
                                    field_name=Const.CHAR_INPUT, new_field_name='trigrams')
            input_field_names.append('trigrams')

        _indexize(data_bundle, input_field_names, Const.TARGET)

        input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
        target_fields = [Const.TARGET, Const.INPUT_LEN]
        for _, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.CHAR_INPUT)

        data_bundle.set_input(*input_fields)
        data_bundle.set_target(*target_fields)

        return data_bundle

    def process_from_file(self, paths=None) -> DataBundle:
        r"""
        Load data with ``CWSLoader`` and run :meth:`process` on it.

        Exactly one of ``paths`` and the ``dataset_name`` given at
        initialization must be set.

        :param str paths: passed to ``CWSLoader.load``
        :return: the processed DataBundle
        :raises RuntimeError: if both or neither of ``paths`` and ``dataset_name`` are set
        """
        if self.dataset_name is None and paths is None:
            raise RuntimeError(
                "You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
        if self.dataset_name is not None and paths is not None:
            raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
        data_bundle = CWSLoader(self.dataset_name).load(paths)
        return self.process(data_bundle)