r"""
.. todo::
doc
"""
__all__ = [
'DataBundle',
]
from typing import Union, List
from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
from ..core._logger import logger
[文档]class DataBundle:
r"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种
Loader的load函数生成,可以通过以下的方法获取里面的内容
Example::
data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'})
train_vocabs = data_bundle.vocabs['train']
train_data = data_bundle.datasets['train']
dev_data = data_bundle.datasets['train']
"""
[文档] def __init__(self, vocabs: dict = None, datasets: dict = None):
r"""
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在
使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。
"""
self.vocabs = vocabs or {}
self.datasets = datasets or {}
[文档] def set_vocab(self, vocab, field_name):
r"""
向DataBunlde中增加vocab
:param ~fastNLP.Vocabulary vocab: 词表
:param str field_name: 这个vocab对应的field名称
:return: self
"""
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports."
self.vocabs[field_name] = vocab
return self
[文档] def set_dataset(self, dataset, name: str):
r"""
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
:param str name: dataset的名称
:return: self
"""
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports."
self.datasets[name] = dataset
return self
[文档] def get_dataset(self, name: str) -> DataSet:
r"""
获取名为name的dataset
:param str name: dataset的名称,一般为'train', 'dev', 'test'
:return: DataSet
"""
if name in self.datasets.keys():
return self.datasets[name]
else:
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
f'It should be one of {self.datasets.keys()}.'
logger.error(error_msg)
raise KeyError(error_msg)
[文档] def delete_dataset(self, name: str):
r"""
删除名为name的DataSet
:param str name:
:return: self
"""
self.datasets.pop(name, None)
return self
[文档] def get_vocab(self, field_name: str) -> Vocabulary:
r"""
获取field名为field_name对应的vocab
:param str field_name: 名称
:return: Vocabulary
"""
if field_name in self.vocabs.keys():
return self.vocabs[field_name]
else:
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
f'It should be one of {self.vocabs.keys()}.'
logger.error(error_msg)
raise KeyError(error_msg)
[文档] def delete_vocab(self, field_name: str):
r"""
删除vocab
:param str field_name:
:return: self
"""
self.vocabs.pop(field_name, None)
return self
@property
def num_dataset(self):
return len(self.datasets)
@property
def num_vocab(self):
return len(self.vocabs)
[文档] def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
r"""
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作::
data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True
data_bundle.set_target('target', flag=False) # 将target这个field的input属性设置为False
:param str field_names: field的名称
:param bool flag: 将field_name的target状态设置为flag
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
行的数据进行类型和维度推断本列的数据的类型和维度。
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:return: self
"""
for field_name in field_names:
for name, dataset in self.datasets.items():
if not ignore_miss_dataset and not dataset.has_field(field_name):
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
if not dataset.has_field(field_name):
continue
else:
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
return self
[文档] def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
r"""
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val.
:param str field_name:
:param int pad_val:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:return: self
"""
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
dataset.set_pad_val(field_name=field_name, pad_val=pad_val)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self
[文档] def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True):
r"""
将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态
:param str field_names:
:param bool flag:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:return: self
"""
for name, dataset in self.datasets.items():
for field_name in field_names:
if dataset.has_field(field_name=field_name):
dataset.set_ignore_type(field_name, flag=flag)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self
[文档] def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True):
r"""
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name.
:param str field_name:
:param str new_field_name:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:return: self
"""
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
dataset.copy_field(field_name=field_name, new_field_name=new_field_name)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self
[文档] def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True):
r"""
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name.
:param str field_name:
:param str new_field_name:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改
:return: self
"""
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
dataset.rename_field(field_name=field_name, new_field_name=new_field_name)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
if rename_vocab:
if field_name in self.vocabs:
self.vocabs[new_field_name] = self.vocabs.pop(field_name)
return self
[文档] def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True):
r"""
将DataBundle中所有DataSet中名为field_name的field删除掉.
:param str field_name:
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除
:return: self
"""
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
dataset.delete_field(field_name=field_name)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
if delete_vocab:
if field_name in self.vocabs:
self.vocabs.pop(field_name)
return self
[文档] def iter_datasets(self) -> Union[str, DataSet]:
r"""
迭代data_bundle中的DataSet
Example::
for name, dataset in data_bundle.iter_datasets():
pass
:return:
"""
for name, dataset in self.datasets.items():
yield name, dataset
[文档] def get_dataset_names(self) -> List[str]:
r"""
返回DataBundle中DataSet的名称
:return:
"""
return list(self.datasets.keys())
[文档] def get_vocab_names(self) -> List[str]:
r"""
返回DataBundle中Vocabulary的名称
:return:
"""
return list(self.vocabs.keys())
[文档] def iter_vocabs(self) -> Union[str, Vocabulary]:
r"""
迭代data_bundle中的DataSet
Example:
for field_name, vocab in data_bundle.iter_vocabs():
pass
:return:
"""
for field_name, vocab in self.vocabs.items():
yield field_name, vocab
[文档] def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
:param callable func: input是instance中名为 `field_name` 的field的内容。
:param str field_name: 传入func的是哪个field。
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param optional kwargs: 支持输入is_input,is_target,ignore_type
1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
4. use_tqdm: bool, 是否显示tqdm进度条
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
"""
tqdm_desc = kwargs.get('tqdm_desc', '')
for name, dataset in self.datasets.items():
if tqdm_desc != '':
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
if dataset.has_field(field_name=field_name):
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self
[文档] def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
.. note::
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param str field_name: 传入func的是哪个field。
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param optional kwargs: 支持输入is_input, is_target, ignore_type
1. is_input: bool, 如果为True则将被修改的field设置为input
2. is_target: bool, 如果为True则将被修改的field设置为target
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
4. use_tqdm: bool, 是否显示tqdm进度条
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
"""
res = {}
tqdm_desc = kwargs.get('tqdm_desc', '')
for name, dataset in self.datasets.items():
if tqdm_desc != '':
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
if dataset.has_field(field_name=field_name):
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name} .")
return res
[文档] def apply(self, func, new_field_name: str, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
对DataBundle中所有的dataset使用apply方法
:param callable func: input是instance中名为 `field_name` 的field的内容。
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。
:param optional kwargs: 支持输入is_input,is_target,ignore_type
1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
4. use_tqdm: bool, 是否显示tqdm进度条
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
"""
tqdm_desc = kwargs.get('tqdm_desc', '')
for name, dataset in self.datasets.items():
if tqdm_desc != '':
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
dataset.apply(func, new_field_name=new_field_name, **kwargs)
return self
[文档] def apply_more(self, func, modify_fields=True, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
.. note::
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param optional kwargs: 支持输入is_input,is_target,ignore_type
1. is_input: bool, 如果为True则将被修改的的field设置为input
2. is_target: bool, 如果为True则将被修改的的field设置为target
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
4. use_tqdm: bool, 是否显示tqdm进度条
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
"""
res = {}
tqdm_desc = kwargs.get('tqdm_desc', '')
for name, dataset in self.datasets.items():
if tqdm_desc!='':
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
return res
[文档] def add_collate_fn(self, fn, name=None):
r"""
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明.
:param callable fn:
:param name:
:return:
"""
for _, dataset in self.datasets.items():
dataset.add_collate_fn(fn=fn, name=name)
[文档] def delete_collate_fn(self, name=None):
r"""
删除DataSet中的collate_fn
:param name:
:return:
"""
for _, dataset in self.datasets.items():
dataset.delete_collate_fn(name=name)
def __repr__(self):
_str = ''
if len(self.datasets):
_str += 'In total {} datasets:\n'.format(self.num_dataset)
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
if len(self.vocabs):
_str += 'In total {} vocabs:\n'.format(self.num_vocab)
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str