Source code for fastNLP.modules.decoder.utils

r"""undocumented"""

__all__ = [
    "viterbi_decode"
]
import torch


def viterbi_decode(logits, transitions, mask=None, unpad=False):
    r"""
    Given an emission (feature) matrix and a transition score matrix, compute the best tag path and its score.

    :param torch.FloatTensor logits: batch_size x max_len x num_tags, the emission (feature) matrix.
    :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is the score of transitioning
        from tag i to tag j; or (n_tags+2) x (n_tags+2), where index n_tags is the start tag and index n_tags+1
        is the end tag. To forbid the transition i->j, set transitions[i, j] to a very small negative number,
        e.g. -10000000.0.
    :param torch.ByteTensor mask: batch_size x max_len, positions with value 0 are treated as padding; if None,
        the input is assumed to contain no padding.
    :param bool unpad: whether to strip padding from the result. If False, a batch_size x max_len tensor is
        returned; if True, a List[List[int]] is returned, where each inner List[int] holds the labels of one
        sequence with the padding removed, i.e. its length equals the valid length of that sample.
    :return: a tuple (paths, scores).
        paths: the decoded paths; their format depends on unpad.
        scores: torch.FloatTensor of size (batch_size,), the score of each best path.
    """
    batch_size, seq_len, n_tags = logits.size()
    if transitions.size(0) == n_tags + 2:
        include_start_end_trans = True
    elif transitions.size(0) == n_tags:
        include_start_end_trans = False
    else:
        raise RuntimeError("The shapes of transitions and feats are not "
                           "compatible.")
    logits = logits.transpose(0, 1).data  # L, B, H
    if mask is not None:
        mask = mask.transpose(0, 1).data.eq(True)  # L, B
    else:
        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1)

    trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data

    # dynamic programming (Viterbi forward pass)
    vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
    vscore = logits[0]
    if include_start_end_trans:
        vscore += transitions[n_tags, :n_tags]

    for i in range(1, seq_len):
        prev_score = vscore.view(batch_size, n_tags, 1)
        cur_score = logits[i].view(batch_size, 1, n_tags)
        score = prev_score + trans_score + cur_score
        best_score, best_dst = score.max(1)
        vpath[i] = best_dst
        # update the score at valid positions, carry the previous score through padded positions
        vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
                 vscore.masked_fill(mask[i].view(batch_size, 1), 0)

    if include_start_end_trans:
        vscore += transitions[:n_tags, n_tags + 1].view(1, -1)

    # backtrace
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
    seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
    lens = (mask.long().sum(0) - 1)
    # idxes [L, B], batched idx from seq_len-1 to 0
    idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

    ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
    ans_score, last_tags = vscore.max(1)
    ans[idxes[0], batch_idx] = last_tags
    for i in range(seq_len - 1):
        last_tags = vpath[idxes[i], batch_idx, last_tags]
        ans[idxes[i + 1], batch_idx] = last_tags
    ans = ans.transpose(0, 1)
    if unpad:
        paths = []
        for idx, length in enumerate(lens):  # length is the index of the last valid token
            paths.append(ans[idx, :length + 1].tolist())
    else:
        paths = ans
    return paths, ans_score
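
A minimal usage sketch (not part of the module): it calls viterbi_decode on random scores to show the expected shapes; the tensor sizes, the random inputs, and the masked second sequence are illustrative assumptions, not taken from the source.

import torch
from fastNLP.modules.decoder.utils import viterbi_decode

batch_size, max_len, n_tags = 2, 5, 4            # illustrative sizes
logits = torch.randn(batch_size, max_len, n_tags)
transitions = torch.randn(n_tags, n_tags)        # no start/end rows, so include_start_end_trans is False
mask = torch.ones(batch_size, max_len, dtype=torch.uint8)
mask[1, 3:] = 0                                  # the second sequence has only 3 valid tokens

paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)
# paths is a List[List[int]]; len(paths[1]) == 3 because the padding is stripped
# scores is a torch.FloatTensor of shape (batch_size,)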