fastNLP.models.cnn_text_classification 源代码

r"""
CNN-based text classification model.

This module exposes :class:`CNNText`, a convolution + max-pooling sentence
classifier.
"""

__all__ = [
    "CNNText"
]

import torch
import torch.nn as nn

from ..core.const import Const as C
from ..core.utils import seq_len_to_mask
from ..embeddings import embedding
from ..modules import encoder


class CNNText(torch.nn.Module):
    r"""
    Text classification model built on a CNN, following
    'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'

    The pipeline is: embedding -> convolution + max-pooling -> dropout -> linear layer.
    """

    def __init__(self, embed,
                 num_classes,
                 kernel_nums=(30, 40, 50),
                 kernel_sizes=(1, 3, 5),
                 dropout=0.5):
        r"""
        :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: size of the
            embedding (when a tuple(int, int) is passed, the first int is vocab_size and the
            second int is embed_dim); when a Tensor, Embedding, ndarray, etc. is passed, that
            value is used directly to initialise the embedding
        :param int num_classes: total number of classes
        :param tuple(int) kernel_nums: number of output channels of each convolution; forwarded
            to ``ConvMaxpool`` as ``out_channels``, and their sum is the input width of the
            final linear layer
        :param int,tuple(int) kernel_sizes: kernel size of each convolution, forwarded to
            ``ConvMaxpool``
        :param float dropout: dropout probability applied to the pooled features
        """
        super().__init__()

        self.embed = embedding.Embedding(embed)
        self.conv_pool = encoder.ConvMaxpool(
            in_channels=self.embed.embedding_dim,
            out_channels=kernel_nums,
            kernel_sizes=kernel_sizes)
        self.dropout = nn.Dropout(dropout)
        # The pooled features of all convolutions are concatenated, hence the sum.
        self.fc = nn.Linear(sum(kernel_nums), num_classes)

    def forward(self, words, seq_len=None):
        r"""
        :param torch.LongTensor words: [batch_size, seq_len], word indices of each sentence
        :param torch.LongTensor seq_len: [batch,] length of each sentence; when omitted no
            padding mask is applied
        :return output: dict mapping ``C.OUTPUT`` to the class scores produced by the final
            linear layer, [batch_size, num_classes]
        """
        x = self.embed(words)  # [N,L] -> [N,L,C]
        if seq_len is None:
            x = self.conv_pool(x)  # [N,L,C] -> [N,C]
        else:
            # Size the mask to the padded width of the batch, not to max(seq_len):
            # the two differ when the batch is padded beyond its longest sentence
            # (e.g. fixed-length padding, or a shard produced by nn.DataParallel),
            # and a narrower mask would no longer line up with ``x``.
            mask = seq_len_to_mask(seq_len, max_len=x.size(1))
            x = self.conv_pool(x, mask)  # [N,L,C] -> [N,C]
        x = self.dropout(x)
        x = self.fc(x)  # [N,C] -> [N, N_class]
        return {C.OUTPUT: x}

    def predict(self, words, seq_len=None):
        r"""
        :param torch.LongTensor words: [batch_size, seq_len], word indices of each sentence
        :param torch.LongTensor seq_len: [batch,] length of each sentence
        :return predict: dict mapping ``C.OUTPUT`` to a torch.LongTensor of predicted class
            indices, [batch_size, ]
        """
        scores = self(words, seq_len)[C.OUTPUT]
        # Index of the highest-scoring class for every sample.
        return {C.OUTPUT: scores.argmax(dim=1)}