fastNLP.models.biaffine_parser¶
Biaffine Dependency Parser 的 Pytorch 实现.
-
class
fastNLP.models.biaffine_parser.
BiaffineParser
(embed, pos_vocab_size, pos_emb_dim, num_label, rnn_layers=1, rnn_hidden_size=200, arc_mlp_size=100, label_mlp_size=100, dropout=0.3, encoder='lstm', use_greedy_infer=False)[源代码]¶ -
别名
fastNLP.models.BiaffineParser
fastNLP.models.biaffine_parser.BiaffineParser
Biaffine Dependency Parser 实现. 论文参考 Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) .-
__init__
(embed, pos_vocab_size, pos_emb_dim, num_label, rnn_layers=1, rnn_hidden_size=200, arc_mlp_size=100, label_mlp_size=100, dropout=0.3, encoder='lstm', use_greedy_infer=False)[源代码]¶ 参数: - embed -- 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
- pos_vocab_size -- part-of-speech 词典大小
- pos_emb_dim -- part-of-speech 向量维度
- num_label -- 边的类别个数
- rnn_layers -- rnn encoder的层数
- rnn_hidden_size -- rnn encoder 的隐状态维度
- arc_mlp_size -- 边预测的MLP维度
- label_mlp_size -- 类别预测的MLP维度
- dropout -- dropout概率.
- encoder -- encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm
- use_greedy_infer -- 是否在inference时使用贪心算法.
若
False
, 使用更加精确但相对缓慢的MST算法. Default:False
-
forward
(words1, words2, seq_len, target1=None)[源代码]¶ 模型forward阶段
参数: - words1 -- [batch_size, seq_len] 输入word序列
- words2 -- [batch_size, seq_len] 输入pos序列
- seq_len -- [batch_size, seq_len] 输入序列长度
- target1 -- [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效,
用于训练label分类器. 若为
None
, 使用预测的heads输入到label分类器 Default:None
Return dict: parsing 结果:
pred1: [batch_size, seq_len, seq_len] 边预测logits pred2: [batch_size, seq_len, num_label] label预测logits pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测
-
static
loss
(pred1, pred2, target1, target2, seq_len)[源代码]¶ 计算parser的loss
参数: - pred1 -- [batch_size, seq_len, seq_len] 边预测logits
- pred2 -- [batch_size, seq_len, num_label] label预测logits
- target1 -- [batch_size, seq_len] 真实边的标注
- target2 -- [batch_size, seq_len] 真实类别的标注
- seq_len -- [batch_size, seq_len] 真实目标的长度
Return loss: scalar
-
-
class
fastNLP.models.biaffine_parser.
GraphParser
[源代码]¶ 基类
fastNLP.models.BaseModel
别名
fastNLP.models.GraphParser
fastNLP.models.biaffine_parser.GraphParser
基于图的parser base class, 支持贪婪解码和最大生成树解码