fastNLP.models.cnn_text_classification

class fastNLP.models.cnn_text_classification.CNNText(embed, num_classes, kernel_nums=30, 40, 50, kernel_sizes=1, 3, 5, dropout=0.5)[源代码]

别名 fastNLP.models.CNNText fastNLP.models.cnn_text_classification.CNNText

使用CNN进行文本分类的模型 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'

__init__(embed, num_classes, kernel_nums=30, 40, 50, kernel_sizes=1, 3, 5, dropout=0.5)[源代码]
参数
  • embed (tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray) -- Embedding的大小(传入tuple(int, int), 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding

  • num_classes (int) -- 一共有多少类

  • kernel_sizes (int,tuple(int)) -- 输出channel的kernel大小。

  • dropout (float) -- Dropout的大小

forward(words, seq_len=None)[源代码]
参数
  • words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index

  • seq_len (torch.LongTensor) -- [batch,] 每个句子的长度

Return output

dict of torch.LongTensor, [batch_size, num_classes]

predict(words, seq_len=None)[源代码]
参数
  • words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index

  • seq_len (torch.LongTensor) -- [batch,] 每个句子的长度

Return predict

dict of torch.LongTensor, [batch_size, ]

training: bool