fastNLP.models.cnn_text_classification¶
-
class
fastNLP.models.cnn_text_classification.
CNNText
(embed, num_classes, kernel_nums=30, 40, 50, kernel_sizes=1, 3, 5, dropout=0.5)[源代码]¶ 别名
fastNLP.models.CNNText
fastNLP.models.cnn_text_classification.CNNText
使用CNN进行文本分类的模型 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'
-
__init__
(embed, num_classes, kernel_nums=30, 40, 50, kernel_sizes=1, 3, 5, dropout=0.5)[源代码]¶ - 参数
embed (tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray) -- Embedding的大小(传入tuple(int, int), 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
num_classes (int) -- 一共有多少类
kernel_sizes (int,tuple(int)) -- 输出channel的kernel大小。
dropout (float) -- Dropout的大小
-
forward
(words, seq_len=None)[源代码]¶ - 参数
words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index
seq_len (torch.LongTensor) -- [batch,] 每个句子的长度
- Return output
dict of torch.LongTensor, [batch_size, num_classes]
-
predict
(words, seq_len=None)[源代码]¶ - 参数
words (torch.LongTensor) -- [batch_size, seq_len],句子中word的index
seq_len (torch.LongTensor) -- [batch,] 每个句子的长度
- Return predict
dict of torch.LongTensor, [batch_size, ]
-
training
: bool¶
-