Source code for fastNLP.modules.encoder.transformer

r"""undocumented"""

__all__ = [
    "TransformerEncoder"
]
from torch import nn

from .attention import MultiHeadAttention


class TransformerEncoder(nn.Module):
    r"""
    Transformer encoder module; the embedding layer is not included.
    """

    class SubLayer(nn.Module):
        def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
            super(TransformerEncoder.SubLayer, self).__init__()
            self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
            self.norm1 = nn.LayerNorm(model_size, eps=1e-6)
            self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                     nn.ReLU(),
                                     nn.Dropout(dropout),
                                     nn.Linear(inner_size, model_size))
            self.norm2 = nn.LayerNorm(model_size, eps=1e-6)
            self.dropout = nn.Dropout(dropout)

        def forward(self, input, seq_mask=None, atte_mask_out=None):
            r"""
            :param input: [batch, seq_len, model_size]
            :param seq_mask: [batch, seq_len, 1] padding mask, 1 at real tokens and 0 at padding
            :param atte_mask_out: [batch, 1, seq_len] mask handed to the attention module
            :return: [batch, seq_len, model_size]
            """
            if seq_mask is None:  # fall back to a scalar so the multiplications below still work
                seq_mask = 1
            # self-attention sub-layer (pre-norm) with residual connection
            input = self.norm1(input)
            attention = self.atte(input, input, input, atte_mask_out)
            input = input + self.dropout(attention)
            input = input * seq_mask  # zero out padded positions
            # position-wise feed-forward sub-layer (pre-norm) with residual connection
            input = self.norm2(input)
            output = self.ffn(input)
            input = input + self.dropout(output)
            input = input * seq_mask  # zero out padded positions
            return input
    def __init__(self, num_layers, **kargs):
        r"""
        :param int num_layers: number of transformer layers
        :param int model_size: size of the input dimension; also the size of the output dimension
        :param int inner_size: hidden size of the FFN layer
        :param int key_size: key dimension of each head
        :param int value_size: value dimension of each head
        :param int num_head: number of heads
        :param float dropout: dropout probability. Default: 0.1
        """
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6)
    def forward(self, x, seq_mask=None):
        r"""
        :param x: [batch, seq_len, model_size] input sequence
        :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``,
            an all-ones mask is assumed. Default: ``None``
        :return: [batch, seq_len, model_size] output sequence
        """
        output = x
        if seq_mask is None:
            atte_mask_out = None
        else:
            atte_mask_out = (seq_mask.eq(False))[:, None, :]  # [batch, 1, seq_len], True at padded positions
            seq_mask = seq_mask[:, :, None]                   # [batch, seq_len, 1]
        for layer in self.layers:
            output = layer(output, seq_mask, atte_mask_out)
        return self.norm(output)
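A minimal usage sketch of the module above; it is not part of the library source. The hyperparameter values and tensor sizes are illustrative assumptions, and the keyword arguments are forwarded to SubLayer via **kargs, so they must match its signature.

import torch

# illustrative hyperparameters (assumptions, not defaults of the library)
encoder = TransformerEncoder(num_layers=2, model_size=64, inner_size=256,
                             key_size=16, value_size=16, num_head=4, dropout=0.1)

x = torch.randn(2, 5, 64)                    # [batch, seq_len, model_size]
mask = torch.ones(2, 5, dtype=torch.long)    # [batch, seq_len], 1 for real tokens
mask[1, 3:] = 0                              # mark the tail of the second sequence as padding
out = encoder(x, seq_mask=mask)              # [2, 5, 64]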