r"""
.. todo::
doc
"""
__all__ = [
"StackEmbedding",
]

from typing import List

import torch
from torch import nn

from .embedding import TokenEmbedding
from .utils import _check_vocab_has_same_index


class StackEmbedding(TokenEmbedding):
    r"""
    Stack several embeddings into a single embedding whose outputs are
    concatenated along the last dimension.

    Example::

        >>> from fastNLP import Vocabulary
        >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding
        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
        >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
        >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
        >>> embed = StackEmbedding([embed_1, embed_2])

    """
    def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
        r"""
        :param embeds: a list of :class:`TokenEmbedding` instances. Every embedding
            in the list must share the same vocabulary.
        :param float word_dropout: the probability of replacing a word with unk. This
            both trains the unk representation and provides some regularization. All
            stacked embeddings are set to unknown at the same positions. If dropout is
            set here, do not set dropout on the component embeddings as well.
        :param float dropout: the probability of applying Dropout to the embedded
            representation; 0.1 means 10% of the values are randomly set to 0.
        """
        assert isinstance(embeds, list), "embeds must be a list of TokenEmbedding."
        for embed in embeds:
            assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."

        # All stacked embeddings must share the same vocabulary, or at least map
        # words to the same indices.
        vocabs = []
        for embed in embeds:
            if hasattr(embed, 'get_word_vocab'):
                vocabs.append(embed.get_word_vocab())
        _vocab = vocabs[0]
        for vocab in vocabs[1:]:
            if _vocab != vocab:
                _check_vocab_has_same_index(_vocab, vocab)

        super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)

        self.embeds = nn.ModuleList(embeds)
        # The output size is the sum of the sizes of the stacked embeddings.
        self._embed_size = sum(embed.embed_size for embed in self.embeds)

    def append(self, embed: TokenEmbedding):
        r"""
        Append an embedding to the end of the stack.

        :param embed: the :class:`TokenEmbedding` to append. Its vocabulary must map
            words to the same indices as the current one.
        :return: self
        """
        assert isinstance(embed, TokenEmbedding)
        _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
        self._embed_size += embed.embed_size
        self.embeds.append(embed)
        return self

    def pop(self):
        r"""
        Remove and return the last embedding in the stack.

        :return: the popped :class:`TokenEmbedding`
        """
        embed = self.embeds.pop()
        self._embed_size -= embed.embed_size
        return embed

    @property
    def embed_size(self):
        r"""
        The size of the last dimension of the vectors this embedding outputs, i.e.
        the sum of the sizes of the stacked embeddings.
        """
        return self._embed_size

    def forward(self, words):
        r"""
        Run each embedding on ``words`` and concatenate the results in order along
        the last dimension.

        :param words: batch_size x max_len
        :return: a tensor of shape batch_size x max_len x embed_size; the last
            dimension depends on the embeddings this StackEmbedding is composed of
        """
        outputs = []
        # Word dropout is applied to the ids once, so every stacked embedding sees
        # the same positions replaced by unk.
        words = self.drop_word(words)
        for embed in self.embeds:
            outputs.append(embed(words))
        outputs = self.dropout(torch.cat(outputs, dim=-1))
        return outputs
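

# A minimal usage sketch, not part of the library itself. It assumes the
# pretrained weights 'en-glove-6b-50d' (50d) and 'en-word2vec-300' (300d) can be
# resolved by StaticEmbedding; substitute any weights available to you.
if __name__ == "__main__":
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary().add_word_lst("The weather is good .".split())
    glove = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
    word2vec = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300')

    embed = StackEmbedding([glove, word2vec], word_dropout=0.01, dropout=0.1)
    print(embed.embed_size)  # 350: the stacked sizes 50 + 300

    # One sentence of 5 tokens -> output shape (1, 5, 350).
    words = torch.LongTensor([[vocab.to_index(w) for w in "The weather is good .".split()]])
    print(embed(words).shape)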