fastNLP.core.tester

tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。

import numpy as np
import torch
from torch import nn
from fastNLP import Tester
from fastNLP import DataSet
from fastNLP import AccuracyMetric

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
    def forward(self, a):
        return {'pred': self.fc(a.unsqueeze(1)).squeeze(1)}

model = Model()

dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

dataset.set_input('a')
dataset.set_target('b')

tester = Tester(dataset, model, metrics=AccuracyMetric())
eval_results = tester.test()

这里Metric的映射规律是和 fastNLP.Trainer 中一致的,具体使用请参考 trainer 模块 的1.3部分。 Tester在验证进行之前会调用model.eval()提示当前进入了evaluation阶段,即会关闭nn.Dropout()等,在验证结束之后会调用model.train()恢复到训练状态。

class fastNLP.core.tester.Tester(data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True, fp16=False, **kwargs)[源代码]

别名 fastNLP.Tester fastNLP.core.tester.Tester

Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。

__init__(data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True, fp16=False, **kwargs)[源代码]
参数
  • data (DataSet,BatchIter) -- 需要测试的数据集

  • model (torch.nn.Module) -- 使用的模型

  • metrics (MetricBase,List[MetricBase]) -- 测试时使用的metrics

  • batch_size (int) -- evaluation时使用的batch_size有多大。

  • device (str,int,torch.device,list(int)) --

    将模型load到哪个设备。默认为None,即Trainer不对模型 的计算位置进行管理。支持以下的输入:

    1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中;

    2. torch.device:将模型装载到torch.device上。

    3. int: 将使用device_id为该值的gpu进行训练

    4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。

    5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。

    如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。

  • verbose (int) -- 如果为0不输出任何信息; 如果为1,打印出验证结果。

  • use_tqdm (bool) -- 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。

  • fp16 (bool) -- 是否使用float16进行验证

  • kwargs -- Sampler sampler: 支持传入sampler控制测试顺序 bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快数据速度。

test()[源代码]

开始进行验证,并返回验证结果。

Return Dict[Dict]

dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。