r"""undocumented"""
from builtins import sorted
import torch
import numpy as np
from .field import _get_ele_type_and_dim
from .utils import logger
from copy import deepcopy
__all__ = ['ConcatCollateFn']
def _check_type(batch_dict, fields):
if len(fields) == 0:
raise RuntimeError
types = []
dims = []
for f in fields:
t, d = _get_ele_type_and_dim(batch_dict[f])
types.append(t)
dims.append(d)
diff_types = set(types)
diff_dims = set(dims)
if len(diff_types) > 1 or len(diff_dims) > 1:
raise ValueError
return types[0]
def batching(samples, max_len=0, padding_val=0):
    r"""
    Pad (or truncate) a list of 1-d sequences into one 2-d numpy array.

    :param samples: sized collection of 1-d array-likes; an empty collection is returned unchanged
    :param int max_len: width of the result; samples longer than this are truncated. A value
        ``<= 0`` means "use the length of the longest sample"
    :param padding_val: value stored in the positions not covered by a sample
    :return: array of shape ``(len(samples), max_len)``, or ``samples`` itself when it is empty
    """
    if len(samples) == 0:
        return samples

    arrays = [np.asarray(s) for s in samples]
    if max_len <= 0:
        max_len = max(a.shape[0] for a in arrays)

    # Promote over the padding value AND every sample. Deriving the dtype from the padding
    # value alone would silently truncate float samples when padding_val is an int.
    dtype = np.result_type(np.asarray(padding_val).dtype, *(a.dtype for a in arrays))
    batch = np.full((len(arrays), max_len), fill_value=padding_val, dtype=dtype)

    # Iterating a 2-d array yields row views, so writing to `row` fills `batch` in place.
    for row, arr in zip(batch, arrays):
        keep = min(arr.shape[0], max_len)
        row[:keep] = arr[:keep]
    return batch
class Collater:
    r"""
    Helper class that manages the collate_fn callables of a DataSet.

    Functions are stored in a dict keyed by name and are run in insertion order.
    """

    def __init__(self):
        # name -> collate_fn; dict insertion order is the execution order in collate_batch
        self.collate_fns = {}

    def add_fn(self, fn, name=None):
        r"""
        Add a collate_fn to this collater.

        :param callable fn: called as ``fn(ins_list)``; must return a ``(batch_x, batch_y)``
            pair of dicts
        :param str,int name: key to store ``fn`` under. An existing key is overwritten (with a
            warning). If None, an unused integer key is generated
        :return:
        """
        if name is None:
            # len() alone can collide with a live integer key after a deletion
            # (add, add, delete 0 -> len is 1 but key 1 is still in use), so skip used keys.
            name = len(self.collate_fns)
            while name in self.collate_fns:
                name += 1
        elif name in self.collate_fns:
            logger.warn(f"collate_fn:{name} will be overwritten.")
        self.collate_fns[name] = fn

    def is_empty(self):
        r"""
        Tell whether no collate_fn is registered.

        :return: True if there is no collate_fn, False otherwise
        """
        return len(self.collate_fns) == 0

    def delete_fn(self, name=None):
        r"""
        Remove a collate_fn.

        :param str,int name: key of the function to remove; an unknown key is ignored.
            If None, the most recently added collate_fn is removed
        :return:
        """
        if name is None:
            if self.collate_fns:
                # dict.popitem() removes the last-inserted entry, i.e. the newest collate_fn.
                self.collate_fns.popitem()
        else:
            self.collate_fns.pop(name, None)

    def collate_batch(self, ins_list):
        r"""
        Run every registered collate_fn on ``ins_list`` and merge the results.

        Later functions overwrite keys produced by earlier ones.

        :param ins_list: passed unchanged to each collate_fn
        :return: ``(bx, by)``, the merged batch_x and batch_y dicts
        """
        bx, by = {}, {}
        for name, fn in self.collate_fns.items():
            try:
                batch_x, batch_y = fn(ins_list)
            except BaseException as e:
                # Only log which collate_fn failed; the exception itself is propagated untouched.
                logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.")
                raise
            bx.update(batch_x)
            by.update(batch_y)
        return bx, by

    def copy_from(self, col):
        r"""
        Create a new Collater holding a deep copy of the collate_fns of ``col``.

        Despite the name, ``self`` is not modified; the copy is returned.

        :param Collater col: collater to copy
        :return: the newly created Collater
        """
        assert isinstance(col, Collater)
        new_col = Collater()
        new_col.collate_fns = deepcopy(col.collate_fns)
        return new_col
class ConcatCollateFn:
    r"""
    collate_fn that concatenates several fields of each instance, in the given order, and pads
    the concatenated sequences into one batch.

    :param List[str] inputs: names of the fields to concatenate; only 1-d fields are supported
    :param str output: name of the field that holds the concatenated result
    :param pad_val: value used for padding
    :param max_len: maximum length after concatenation (forwarded to ``batching``)
    :param is_input: whether the generated output is placed in the input dict
    :param is_target: whether the generated output is placed in the target dict
    """

    def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
        super().__init__()
        assert isinstance(inputs, list)
        self.inputs = inputs
        self.output = output
        self.pad_val = pad_val
        self.max_len = max_len
        self.is_input = is_input
        self.is_target = is_target

    @staticmethod
    def _to_numpy(seq):
        r"""
        Convert a torch tensor or any array-like into a numpy array.

        :param seq: torch tensor or array-like
        :return: numpy array with the content of ``seq``
        """
        if torch.is_tensor(seq):
            # Tensor.numpy() raises for tensors that require grad or live on another device;
            # detach and move to CPU first so those tensors convert as well.
            return seq.detach().cpu().numpy()
        return np.array(seq)

    def __call__(self, ins_list):
        r"""
        Build the concatenated and padded batch.

        :param ins_list: iterable of 2-item pairs; the second item is indexed by field name
            (presumably ``(index, instance)`` pairs -- confirm against the caller)
        :return: ``(b_x, b_y)`` dicts; each contains ``self.output`` only when ``is_input`` /
            ``is_target`` is set. When both are set they reference the same batch object
        """
        samples = []
        for _, ins in ins_list:
            parts = [self._to_numpy(ins[input_name]) for input_name in self.inputs]
            samples.append(np.concatenate(parts, axis=0))
        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)

        b_x, b_y = {}, {}
        if self.is_input:
            b_x[self.output] = batch
        if self.is_target:
            b_y[self.output] = batch
        return b_x, b_y