fastNLP.models.cnn_text_classification

class fastNLP.models.cnn_text_classification.CNNText(embed, num_classes, kernel_nums=(30, 40, 50), kernel_sizes=(1, 3, 5), dropout=0.5)[源代码]

别名 fastNLP.models.CNNText fastNLP.models.cnn_text_classification.CNNText

使用CNN进行文本分类的模型 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'

__init__(embed, num_classes, kernel_nums=(30, 40, 50), kernel_sizes=(1, 3, 5), dropout=0.5)[源代码]
参数:
  • embed (tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray) -- Embedding的大小(传入tuple(int, int), 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
  • num_classes (int) -- 一共有多少类
  • kernel_sizes (int,tuple(int)) -- 输出channel的kernel大小。
  • dropout (float) -- Dropout的大小
forward(words, seq_len=None)[源代码]
参数:
  • words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index
  • seq_len (torch.LongTensor) -- [batch,] 每个句子的长度
Return output:

dict of torch.LongTensor, [batch_size, num_classes]

predict(words, seq_len=None)[源代码]
参数:
  • words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index
  • seq_len (torch.LongTensor) -- [batch,] 每个句子的长度
Return predict:

dict of torch.LongTensor, [batch_size, ]