Source code for fastNLP.core.trainer

r"""
In fastNLP, Trainer organizes the training procedure of a single task, so that users do not have to rewrite the code for the following steps in every training task:

    (1) the epoch loop;

    (2) splitting the data into batches;

    (3) padding each batch;

    (4) evaluating on the validation set at the end of each epoch or after a given number of steps;

    (5) saving the model whenever it achieves a better validation performance.


----------------------------
1. Basic usage of Trainer
----------------------------

The following example uses a neural network to predict whether a sequence contains an even number of 1s.

.. code-block:: python

    import numpy as np
    from torch import nn
    import torch
    import torch.nn.functional as F
    from torch.optim import SGD

    from fastNLP import DataSet
    from fastNLP import Trainer
    from fastNLP import CrossEntropyLoss
    from fastNLP import AccuracyMetric
    from fastNLP.modules.decoder import MLP

    # the model
    class Model(nn.Module):
        def __init__(self, input_num):
            super().__init__()
            self.fcs = MLP([input_num, 40, 40, 2], 'relu')

        def forward(self, x):
            x = self.fcs(x)
            return {'pred': x}
    model = Model(10)

    # generate the data
    def generate_pseudo_dataset(num_samples):
        data = np.random.randint(2, size=(num_samples, 10))
        label = np.sum(data, axis=1) % 2
        dataset = DataSet({'x': data.astype(float), 'label': label})
        dataset.set_input('x')
        dataset.set_target('label')
        return dataset
    tr_dataset = generate_pseudo_dataset(1000)
    dev_data = generate_pseudo_dataset(100)

    # training
    trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
                      optimizer=SGD(model.parameters(), lr=0.1), n_epochs=1000,
                      dev_data=dev_data, metrics=AccuracyMetric(target='label'))
    trainer.train()

As the example above shows, using Trainer greatly reduces the amount of code needed for the training loop.
Using Trainer requires the following conditions to be satisfied:

1.1 The model
----------------------------

1 The parameter names of the model's forward() must correspond to the field names in the DataSet. In fact, when fastNLP passes
the data in the DataSet to the model's forward(), it matches them by name. So in the example above, if Model's forward were
changed to forward(self, data), the field 'x' in the DataSet would have to be renamed to 'data'.

2 The arguments passed to forward() are the DataSet fields that are set as input. If forward() has no corresponding parameter,
the data is not passed: for example, if 'x1' and 'x2' are both input fields in the DataSet but the model's signature is
forward(self, x1), then 'x2' is not passed to forward() (see the sketch after this list).

3 The return value of the model's forward() must be a dict.
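
To make these matching rules concrete, here is a minimal sketch; the field and parameter names are purely illustrative and not taken from the example above.

.. code-block:: python

    from torch import nn
    from fastNLP import DataSet

    dataset = DataSet({'x1': [[1.0, 2.0]] * 4, 'x2': [[3.0, 4.0]] * 4})
    dataset.set_input('x1', 'x2')         # both fields are set as input ...

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x1):            # ... but only 'x1' matches a parameter name,
            return {'pred': self.fc(x1)}  # so 'x2' is never passed in; the result is a dict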

1.2 Loss
----------------------------

In order not to restrict how many values the forward function may return (some complex tasks, such as dependency parsing, need to return several),
both :mod:`Loss<fastNLP.core.losses>` and :mod:`Metric<fastNLP.core.metrics>` in fastNLP match the values they need by name. For instance, in the example above

.. code-block:: python

    trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
                      optimizer=SGD(model.parameters(), lr=0.1), n_epochs=1000,
                      dev_data=dev_data, metrics=AccuracyMetric(target='label'))

the loss is set to :class:`~fastNLP.CrossEntropyLoss` and is constructed with the argument target='label'.
The init signature of :class:`~fastNLP.CrossEntropyLoss` is (pred=None, target=None, padding_idx=-100).

These two parameters are the model prediction and the ground truth needed to compute the cross entropy.
`pred` usually comes from the dict returned by the model's forward(), while `target` usually comes from a DataSet field that is set as target.
Since everyone names their ground truth and their model outputs differently, fastNLP's :mod:`Loss<fastNLP.core.losses>` provides a mapping-like mechanism to match the corresponding values:
here :class:`~fastNLP.CrossEntropyLoss` will look for a value named 'label' to use as the ground truth when computing the loss;
and since pred is left as None, :class:`~fastNLP.CrossEntropyLoss` uses the name 'pred' to match the prediction,
which happens to be exactly the key returned by forward(), so pred does not need to be specified here.
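
If the names do not happen to line up, the mapping can always be stated explicitly when constructing the loss. A minimal sketch, assuming (hypothetically) that forward() returned its prediction under the key 'output' rather than 'pred':

.. code-block:: python

    from fastNLP import CrossEntropyLoss

    # `pred` is looked up under the forward() key 'output',
    # `target` under the DataSet field 'label'
    loss = CrossEntropyLoss(pred='output', target='label')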

Although this name-matching mechanism makes loss computation in fastNLP quite flexible, in some cases the loss has to be computed inside the model, for example in models that use a CRF.
For this, fastNLP provides the loss :class:`~fastNLP.LossInForward`.
It simply looks up, in the dict returned by forward(), the tensor specified by loss_key (by default 'loss') and uses it as the loss.
If no loss is provided when Trainer is initialized, :class:`~fastNLP.LossInForward` is used by default.

.. todo::
    Add an example here; for a detailed example, refer to
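
A minimal sketch of such a model, with the loss computed inside forward(); the layer and field names here are only illustrative:

.. code-block:: python

    import torch.nn.functional as F
    from torch import nn
    from fastNLP import LossInForward

    class ModelWithLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x, label):
            pred = self.fc(x)
            # the loss is computed inside the model and returned under the key 'loss'
            return {'loss': F.cross_entropy(pred, label), 'pred': pred}

    # LossInForward fetches the tensor stored under loss_key (default: 'loss');
    # this is also what Trainer falls back to when loss=None
    loss = LossInForward(loss_key='loss')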

1.3 Metric
----------------------------

:mod:`Metric<fastNLP.core.metrics>` uses the same strategy as the Loss above, i.e. matching by name.
AccuracyMetric(target='label') works in the same way as the CrossEntropyLoss case.

The computation needed for validation may differ from that in forward(), so the prediction cannot always be obtained directly from the result of forward().
In that case the model can provide a predict() method: if the model has one, it is called during validation to obtain the predictions.
The arguments passed to predict() are, just like for forward(), selected from the DataSet fields set as input;
and, like forward(), it must return a dict.

.. todo::
    Add an example here; for a concrete example, refer to
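
A minimal sketch of a model that is trained through forward() but evaluated through predict(); the names are only illustrative:

.. code-block:: python

    import torch.nn.functional as F
    from torch import nn
    from fastNLP import AccuracyMetric

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x, label):
            # used during training; the loss is picked up by LossInForward
            return {'loss': F.cross_entropy(self.fc(x), label)}

        def predict(self, x):
            # used during validation instead of forward(); it must also return a dict
            return {'output': self.fc(x).argmax(dim=-1)}

    # tell the metric that the prediction is stored under the key 'output'
    metric = AccuracyMetric(pred='output', target='label')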
    
----------------------------
2. Code checking in Trainer
----------------------------

Because fastNLP relies on this name-matching mechanism, mismatches can occasionally occur. Trainer therefore provides a checking
mechanism for the mapping, controlled through check_code_level. The examples below show errors raised for various reasons.

Example 2.1
----------------------------

.. code-block:: python

    import numpy as np
    from torch import nn
    import torch
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x, b):
            loss = torch.mean((self.fc(x)-b)**2)
            return {'loss': loss}
    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dataset.set_input('a', 'b')

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))

    trainer = Trainer(dataset, model, SGD(model.parameters()))
    # the following error will be raised
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    #     b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    # There is no target field.
    # ....
    # NameError:
    # Problems occurred when calling Model.forward(self, x, b)
    #     missing param: ['x']
    #     unused field: ['a']
    #     Suggestion: You need to provide ['x'] in DataSet and set it as input.

Here, when Trainer is initialized, fastNLP tries to run forward() and backward() once with a batch of batch_size=2. Two kinds
of information in the output can help you:

1 The 'input fields after batch...' part shows, for the train dataset after the Batch operation, the type and shape of each
field. Because the train dataset has no target here, no target fields are shown. From this you can check whether the required
content has been correctly set as input or target.

2 The NameError. A NameError is raised when the name matching fails. The error here occurred while trying to run the forward
computation (Model.forward(self, x, b) in the message tells you that forward is being called), because the 'x' required by
forward() could not be found; the message points out both that 'x' is missing and that 'a' is unused, so a field name is
probably wrong. Renaming the field 'a' in the dataset to 'x', or changing the model parameter from 'x' to 'a', solves the
problem (see the sketch below).
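
For this example, either of the following changes (a sketch; only one of the two is needed) makes the check pass:

.. code-block:: python

    # option 1: rename the DataSet field so that it matches the forward() parameter
    dataset.rename_field('a', 'x')

    # option 2: keep the field name and change the model instead, i.e. define
    # forward(self, a, b) and use `a` in place of `x` inside the method body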

The next example shows an error caused by the loss computation not finding the values it needs.

Example 2.2
----------------------------

.. code-block:: python

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP import L1Loss
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a):
            return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}

    model = Model()

    dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

    dataset.set_input('a')
    dataset.set_target('b')

    trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
    # the error message is as follows
    # input fields after batch(if batch size is 2):
    #     a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # target fields after batch(if batch size is 2):
    #     b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
    # ....
    # NameError:
    # Problems occurred when calling L1Loss.get_loss(self, pred, target)
    #     missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
    #     unused field: ['b']
    #     unused param: ['pred_b', 'No use']
    #     target field: ['b']
    #     param from Model.forward(self, a): ['pred_b', 'No use']
    #     Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
    #             (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).

The error message again consists of two parts:

1 The first part is the same as above.

2 The error here is raised because the values needed to compute the loss could not be found (which you can tell from
L1Loss.get_loss(self, pred, target)); both `pred` and `label` (we set target to 'label' when initializing L1Loss) are missing.
Here 'unused field' lists fields that appear in the DataSet but are set as neither input nor target;
'unused param' lists keys returned by forward() that were not used; 'target field' lists the fields set as target;
'param from Model.forward(self, a)' lists all keys returned by forward(). 'Suggestion' gives advice on how to fix the current error (see the sketch below).
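
Following the two suggestions, explicitly mapping both keys when constructing the loss resolves this example (a sketch):

.. code-block:: python

    # `pred` comes from the forward() key 'pred_b', `target` from the target field 'b'
    trainer = Trainer(dataset, model, loss=L1Loss(pred='pred_b', target='b'),
                      optimizer=SGD(model.parameters(), lr=0.001))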

However, in some cases, for example when forward() returns only one value and there is only one target, fastNLP skips the name
matching and directly uses the forward() result as pred and the DataSet target as target. The extra 'No use' entry in the
return value above was added only so that the Loss is forced to do the matching.


Below is the error that occurs when something goes wrong while a dev dataset is provided.

Example 2.3
----------------------------

.. code-block:: python

    import numpy as np
    from torch import nn
    from torch.optim import SGD
    from fastNLP import Trainer
    from fastNLP import DataSet
    from fastNLP import AccuracyMetric
    import torch

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, a, b):
            loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
            return {'loss': loss}
        def predict(self, a):  # predict() is used for validation
            return {'output': self.fc(a.float().unsqueeze(1))}  # note: the returned dict has no 'pred' key
    model = Model()

    dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
    dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})

    dataset.set_input('a', 'b')
    dev_data.set_input('a')  # note that no target is set here

    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
                     dev_data=dev_data, metrics=AccuracyMetric())

    # the error message
    # ...
    # NameError:
    # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
    #     missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
    #     unused param: ['output']
    #     target field: []
    #     param from Model.predict(self, a): ['output']
    #     Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
    #             (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

The error message is similar to the previous ones, but 'AccuracyMetric.evaluate(self, pred, target, seq_len=None)' shows that
the error happened during evaluation. This saves you from having to finish a whole epoch of training before discovering that the
evaluation is mis-configured. The fix here is to state, when initializing the metric, that `pred` should be taken from 'output',
i.e. AccuracyMetric(pred='output') (see the sketch below).
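
Concretely, the following two changes make the example above pass the check (a sketch; the accuracy value itself is not meaningful for this toy regression model):

.. code-block:: python

    # the dev set needs a target field for the metric to compare against ...
    dev_data.set_target('b')

    # ... and the metric needs to know that predict() stores its result under 'output'
    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
                      dev_data=dev_data, metrics=AccuracyMetric(pred='output', target='b'))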

The strictness of the check can be adjusted through check_code_level. The default is 0, which means the check is performed.

----------------------------
3. Trainer and callbacks
----------------------------

Although Trainer already integrates a number of features, they cannot cover everything that may be needed during training, such as negative sampling, learning rate decay, or early stopping.
To address this, fastNLP introduces the callback mechanism: a :class:`~fastNLP.Callback` is a collection of functions that run at specific stages of Trainer's training loop.
Every :class:`~fastNLP.Callback` has on_* methods (such as on_train_start, on_backward_begin).
If a Callback implements such a method, Trainer calls it when it reaches the corresponding stage, for example::

    import time

    from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
    from fastNLP.models import CNNText

    start_time = time.time()
    
    class MyCallback(Callback):
        def on_epoch_end(self):
            print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000)))
    
    model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1)
    trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(),
                      metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)])
    trainer.train()
    
Here we defined our own callback by subclassing :class:`~fastNLP.Callback` and passed it, together with the built-in :class:`~fastNLP.EarlyStopCallback`,
to the :class:`~fastNLP.Trainer`, extending the :class:`~fastNLP.Trainer`'s functionality.

fastNLP already ships with many callbacks ready to use; see :mod:`fastNLP.core.callback`.

"""
__all__ = [
    "Trainer"
]

import os
import time
from datetime import datetime, timedelta

import numpy as np
import torch
import torch.nn as nn

try:
    from tqdm.auto import tqdm
except ImportError:
    from .utils import _pseudo_tqdm as tqdm
import warnings
from pkg_resources import parse_version

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException, Callback
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
from .optimizer import Optimizer
from .sampler import Sampler
from .sampler import RandomSampler, ConstTokenNumSampler
from .tester import Tester
from .utils import _CheckError
from .utils import _build_args
from .utils import _check_forward_error
from .utils import _check_loss_evaluate
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
from .utils import _build_fp16_env
from .utils import _can_use_fp16
from ._parallel_utils import _model_contains_inner_module
from ._logger import logger


class Trainer(object):
    r"""
    In fastNLP, Trainer organizes the training procedure of a single task, so that users do not have to rewrite, for every training task,

        (1) the epoch loop;

        (2) splitting the data into batches;

        (3) padding each batch;

        (4) evaluating on the validation set at the end of each epoch or after a given number of steps;

        (5) saving the model whenever it achieves a better validation performance, etc.

    See :mod:`fastNLP.core.trainer` for a detailed introduction.
    """
[文档] def __init__(self, train_data, model, optimizer=None, loss=None, batch_size=32, sampler=None, drop_last=False, update_every=1, num_workers=0, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, validate_every=-1, save_path=None, use_tqdm=True, device=None, callbacks=None, check_code_level=0, fp16=False, **kwargs): r""" :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型或 :class:`~fastNLP.BatchIter` 的子类 :param nn.modules model: 待训练的模型 :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 :param int batch_size: 训练和验证的时候的batch大小。 :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` :param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch :param num_workers: int, 有多少个线程来进行数据pad处理。 :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 :param int n_epochs: 需要优化迭代多少次。 :param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 :param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 :param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , 也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, 则保存当前模型。Metric种类详见 :mod:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 :param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, 比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 的计算位置进行管理。支持以下的输入: 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; 2. torch.device:将模型装载到torch.device上。 3. int: 将使用device_id为该值的gpu进行训练 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。 :param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 通过callback机制实现。 可使用的callback参见 :mod:`callback模块 <fastNLP.core.callback>` :param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, 报告警告信息; 2: 有任何field没有被使用都报错. 
检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; (2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 :param bool fp16: 是否使用fp16进行训练。 :param kwargs: 支持配置可选参数 bool test_use_tqdm: 在dev上验证的时候是否开启tqdm Sampler test_sampler: 在evaluate的时候使用的sampler bool test_use_fp16: evalute的时候是否使用fp16测试,默认与fp16相同的取值。 bool set_grad_to_none: 在zero_grad的时候是否将gradient设置为None,而不是设置为zero GradScaler grad_scaler: 仅在fp16为True时有效,如果不使用torch.cuda.amp.GradScaler的初始化参数,可传入一个已经初始化后的 grad_scaler。 bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快数据速度。 """ super(Trainer, self).__init__() if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") # check metrics and dev_data if (not metrics) and dev_data is not None: raise ValueError("No metric for dev_data evaluation.") if metrics and (dev_data is None): raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") # check update every assert update_every >= 1, "update_every must be no less than 1." self.update_every = int(update_every) # check save_path if not (save_path is None or isinstance(save_path, str)): raise ValueError("save_path can only be None or `str`.") # prepare evaluate metrics = _prepare_metrics(metrics) # parse metric_key # increase_better is True. It means the exp result gets better if the indicator increases. # It is true by default. self.increase_better = True if metric_key is not None: self.increase_better = False if metric_key[0] == "-" else True self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key else: self.metric_key = None # prepare loss losser = _prepare_losser(loss) if isinstance(train_data, BatchIter): if sampler is not None: warnings.warn("sampler is ignored when train_data is a BatchIter.") if num_workers>0: warnings.warn("num_workers is ignored when train_data is BatchIter.") if drop_last: warnings.warn("drop_last is ignored when train_data is BatchIter.") # concerning issue from https://github.com/pytorch/pytorch/issues/57273 self.pin_memory = kwargs.get('pin_memory', False if parse_version(torch.__version__)==parse_version('1.9') else True) if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的 # device为None if device is not None: warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.") device = None # Sampler要是分布式的 if sampler is None: sampler = torch.utils.data.DistributedSampler(train_data) elif not isinstance(sampler, torch.utils.data.DistributedSampler): raise TypeError("When using nn.parallel.DistributedDataParallel, " "sampler must be None or torch.utils.data.DistributedSampler.") # 不能保存模型 if save_path: raise RuntimeError("Saving model in Distributed situation is not allowed right now.") else: # sampler check if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)): raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}") if sampler is None: sampler = RandomSampler() elif hasattr(sampler, 'set_batch_size'): sampler.set_batch_size(batch_size) if isinstance(sampler, ConstTokenNumSampler): # 直接使用固定token数量的Sampler assert isinstance(train_data, DataSet), f"When sampler is `ConstTokenNumSampler`, the train_data must" \ f" be `DataSet`." 
sampler(train_data) train_data = DataSetIter(train_data, batch_size=1, sampler=None, as_numpy=False, num_workers=num_workers, pin_memory=self.pin_memory, drop_last=drop_last, timeout=0, worker_init_fn=None, batch_sampler=sampler) if isinstance(train_data, DataSet): self.data_iterator = DataSetIter(dataset=train_data, batch_size=batch_size, sampler=sampler, num_workers=num_workers, drop_last=drop_last, pin_memory=self.pin_memory) elif isinstance(train_data, BatchIter): self.data_iterator = train_data train_data = train_data.dataset check_code_level = -1 # 强制跳过校验 else: raise TypeError("train_data type {} not support".format(type(train_data))) model.train() self.model = _move_model_to_device(model, device=device) if _model_contains_inner_module(self.model): self._forward_func = self.model.module.forward else: self._forward_func = self.model.forward self.fp16 = fp16 self.verbose = kwargs.get('verbose', 0) # check fp16相关的设置 self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) self.grad_scaler = _grad_scaler() if self.fp16: _can_use_fp16(device=device, model=model, func=self._forward_func) grad_scaler = kwargs.get('grad_scaler', None) if grad_scaler is not None: self.grad_scaler = grad_scaler else: self.grad_scaler = _grad_scaler() self.test_use_fp16 = kwargs.get('test_use_fp16', fp16) self.set_grad_to_none = kwargs.get('set_grad_to_none', True) if check_code_level > -1: # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入 # 名是否匹配 dev_dataset = dev_data if isinstance(dev_data, BatchIter): dev_dataset = None warnings.warn("dev_data is of BatchIter type, ignore validation checking.") check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE) if isinstance(self.model, nn.DataParallel): _num_devices = len(self.model.device_ids) if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample check_batch_size = max(len(self.model.device_ids)*2, check_batch_size) else: check_batch_size = max(len(self.model.device_ids), check_batch_size) _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, batch_size=check_batch_size) self.train_data = train_data self.dev_data = dev_data # If None, No validation. 
self.losser = losser self.metrics = metrics self.n_epochs = int(n_epochs) self.batch_size = int(batch_size) self.save_path = save_path self.print_every = int(print_every) self.validate_every = int(validate_every) if validate_every != 0 else -1 self.best_metric_indicator = None self.best_dev_epoch = None self.best_dev_step = None self.best_dev_perf = None self.n_steps = len(self.data_iterator) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) elif optimizer is None: self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) else: if not (hasattr(optimizer, 'step') and callable(optimizer.step)): raise TypeError("optimizer must have a callable step() function.") else: self.optimizer = optimizer self.logger = logger self.use_tqdm = use_tqdm self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm) self.pbar = None self.print_every = abs(self.print_every) self.kwargs = kwargs if self.dev_data is not None: self.tester = Tester(model=self.model, data=self.dev_data, metrics=self.metrics, batch_size=kwargs.get("dev_batch_size", self.batch_size), device=None, # 由上面的部分处理device verbose=0, use_tqdm=self.test_use_tqdm, sampler=kwargs.get('test_sampler', None), fp16=self.test_use_fp16, num_workers=num_workers, pin_memory=self.pin_memory) self.start_time = None # start timestamp if isinstance(callbacks, Callback): callbacks = [callbacks] self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
[文档] def train(self, load_best_model=True, on_exception='auto', **kwargs): r""" 使用该函数使Trainer开始训练。 :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. :param kwargs: int verbose: 为1时在发生异常时会打印异常发生时batch中的数据在dataset中的index :return dict: 返回一个字典类型的数据, 内含以下内容:: seconds: float, 表示训练时长 以下三个内容只有在提供了dev_data的情况下会有。 best_eval: Dict of Dict, 表示evaluation的结果。第一层的key为Metric的名称, 第二层的key为具体的Metric best_epoch: int,在第几个epoch取得的最佳值 best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} verbose = kwargs.get('verbose', 0) if self.n_epochs <= 0: self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.") results['seconds'] = 0. return results try: self._model_device = _get_model_device(self.model) self._mode(self.model, is_test=False) self._load_best_model = load_best_model # 加上millsecond,防止两个太接近的保存 self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')) start_time = time.time() self.logger.info("training epochs started " + self.start_time) self.step = 0 self.epoch = 1 try: self.callback_manager.on_train_begin() self._train() self.callback_manager.on_train_end() except BaseException as e: self.callback_manager.on_exception(e) if verbose>0: self.logger.info(f"The data indices for current batch are: {self.data_iterator.cur_batch_indices}.") if on_exception == 'auto': if not isinstance(e, (CallbackException, KeyboardInterrupt)): raise e elif on_exception == 'raise': raise e if self.dev_data is not None and self.best_dev_perf is not None and load_best_model: model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) load_succeed = self._load_model(self.model, model_name) if load_succeed: self.logger.info("Reloaded the best model.") else: self.logger.info("Fail to reload best model.") if self.dev_data is None and self.save_path is not None: model_name = "_".join([self.model.__class__.__name__, self.start_time]) self._save_model(self.model, model_name) finally: if self.dev_data is not None and self.best_dev_perf is not None: self.logger.info( "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) self.logger.info(self.tester._format_eval_results(self.best_dev_perf)) results['best_eval'] = self.best_dev_perf results['best_epoch'] = self.best_dev_epoch results['best_step'] = self.best_dev_step results['seconds'] = round(time.time() - start_time, 2) return results
def _train(self): if not self.use_tqdm: from .utils import _pseudo_tqdm as inner_tqdm else: inner_tqdm = tqdm start = time.time() with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar: self.pbar = pbar avg_loss = 0 self.batch_per_epoch = self.data_iterator.num_batches for epoch in range(self.epoch, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping self.callback_manager.on_epoch_begin() for batch_x, batch_y in self.data_iterator: self.step += 1 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) indices = self.data_iterator.get_batch_indices() # negative sampling; replace unknown; re-weight batch_y self.callback_manager.on_batch_begin(batch_x, batch_y, indices) prediction = self._data_forward(self.model, batch_x) # edit prediction self.callback_manager.on_loss_begin(batch_y, prediction) with self.auto_cast(): loss = self._compute_loss(prediction, batch_y).mean() loss = loss / self.update_every avg_loss += loss.item() # Is loss NaN or inf? requires_grad = False self.callback_manager.on_backward_begin(loss) self._grad_backward(loss) self.callback_manager.on_backward_end() self._update() self.callback_manager.on_step_end() if self.step % self.print_every == 0: avg_loss = float(avg_loss) / self.print_every if self.use_tqdm: print_output = "loss:{:<6.5f}".format(avg_loss) pbar.update(self.print_every) else: end = time.time() diff = timedelta(seconds=round(end - start)) print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( epoch, self.step, avg_loss, diff) pbar.set_postfix_str(print_output) avg_loss = 0 self.callback_manager.on_batch_end() if (self.validate_every > 0 and self.step % self.validate_every == 0) \ and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, self.n_steps) # pbar.write(eval_str + '\n') self.logger.info(eval_str) self.logger.info(self.tester._format_eval_results(eval_res)+'\n') # ================= mini-batch end ==================== # if self.validate_every<0 and self.dev_data is not None: # 在epoch结束之后的evaluate eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, self.n_steps) # pbar.write(eval_str + '\n') self.logger.info(eval_str) self.logger.info(self.tester._format_eval_results(eval_res) + '\n') # lr decay; early stopping self.callback_manager.on_epoch_end() # =============== epochs end =================== # if self.dev_data is not None and (self.validate_every>0 and self.n_steps%self.validate_every!=0): eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation on dev at Epoch {}/{}. 
Step:{}/{}: ".format(epoch, self.n_epochs, self.step, self.n_steps) # pbar.write(eval_str + '\n') self.logger.info(eval_str) self.logger.info(self.tester._format_eval_results(eval_res) + '\n') pbar.close() self.pbar = None # ============ tqdm end ============== # def _do_validation(self, epoch, step): self.callback_manager.on_valid_begin() res = self.tester.test() is_better_eval = False if self._better_eval_result(res): if self.save_path is not None: self._save_model(self.model, "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) elif self._load_best_model: self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()} self.best_dev_perf = res self.best_dev_epoch = epoch self.best_dev_step = step is_better_eval = True # get validation results; adjust optimizer self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) return res def _mode(self, model, is_test=False): r"""Train mode or Test mode. This is for PyTorch currently. :param model: a PyTorch model :param bool is_test: whether in test mode or not. """ if is_test: model.eval() else: model.train() def _update(self): r"""Perform weight update on a model. """ if self.step % self.update_every == 0: self.grad_scaler.step(self.optimizer) self.grad_scaler.update() def _data_forward(self, network, x): x = _build_args(self._forward_func, **x) with self.auto_cast(): y = network(**x) if not isinstance(y, dict): raise TypeError( f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") return y def _grad_backward(self, loss): r"""Compute gradient with link rules. :param loss: a scalar where back-prop starts For PyTorch, just do "loss.backward()" """ if (self.step-1) % self.update_every == 0: self._clear_grad(self.optimizer, self.set_grad_to_none) self.grad_scaler.scale(loss).backward() def _clear_grad(self, optimizer, set_to_none=True): param_groups = optimizer.param_groups for group in param_groups: for p in group['params']: if p.grad is not None: if set_to_none: p.grad = None else: if p.grad.grad_fn is not None: p.grad.detach_() else: p.grad.requires_grad_(False) p.grad.zero_() def _compute_loss(self, predict, truth): r"""Compute loss given prediction and ground truth. 
:param predict: prediction dict, produced by model.forward :param truth: ground truth dict, produced by batch_y :return: a scalar """ return self.losser(predict, truth) def _save_model(self, model, model_name, only_param=False): r""" 存储不含有显卡信息的state_dict或model :param model: :param model_name: :param only_param: :return: """ if self.save_path is not None: model_path = os.path.join(self.save_path, model_name) if not os.path.exists(self.save_path): os.makedirs(self.save_path, exist_ok=True) if _model_contains_inner_module(model): model = model.module if only_param: state_dict = model.state_dict() for key in state_dict: state_dict[key] = state_dict[key].cpu() torch.save(state_dict, model_path) else: model.cpu() torch.save(model, model_path) model.to(self._model_device) def _load_model(self, model, model_name, only_param=False): # 返回bool值指示是否成功reload模型 if self.save_path is not None: model_path = os.path.join(self.save_path, model_name) if only_param: states = torch.load(model_path) else: states = torch.load(model_path).state_dict() if _model_contains_inner_module(model): model.module.load_state_dict(states) else: model.load_state_dict(states) elif hasattr(self, "_best_model_states"): model.load_state_dict(self._best_model_states) else: return False return True def _better_eval_result(self, metrics): r"""Check if the current epoch yields better validation results. :return bool value: True means current results on dev set is the best. """ indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) if self.metric_key is None: self.metric_key = indicator is_better = True if self.best_metric_indicator is None: # first-time validation self.best_metric_indicator = indicator_val else: if self.increase_better is True: if indicator_val > self.best_metric_indicator: self.best_metric_indicator = indicator_val else: is_better = False else: if indicator_val < self.best_metric_indicator: self.best_metric_indicator = indicator_val else: is_better = False return is_better @property def is_master(self): r"""是否是主进程""" return True
DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 def _get_value_info(_dict): # given a dict value, return information about this dict's value. Return list of str strs = [] for key, value in _dict.items(): _str = '' if isinstance(value, torch.Tensor): _str += "\t{}: (1)type:torch.Tensor (2)dtype:{}, (3)shape:{} ".format(key, value.dtype, value.size()) elif isinstance(value, np.ndarray): _str += "\t{}: (1)type:numpy.ndarray (2)dtype:{}, (3)shape:{} ".format(key, value.dtype, value.shape) else: _str += "\t{}: type:{}".format(key, type(value)) strs.append(_str) return strs def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 model_device = _get_model_device(model=model) _iter = DataSetIter(dataset, batch_size=batch_size, sampler=None) for batch_count, (batch_x, batch_y) in enumerate(_iter): _move_dict_value_to_device(batch_x, batch_y, device=model_device) # forward check if batch_count == 0: info_str = "" input_fields = _get_value_info(batch_x) target_fields = _get_value_info(batch_y) if len(input_fields) > 0: info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(input_fields) info_str += '\n' else: raise RuntimeError("There is no input field.") if len(target_fields) > 0: info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(target_fields) info_str += '\n' else: info_str += 'There is no target field.' logger.info(info_str) _check_forward_error(forward_func=forward_func, dataset=dataset, batch_x=batch_x, check_level=check_level) refined_batch_x = _build_args(forward_func, **batch_x) pred_dict = model(**refined_batch_x) func_signature = _get_func_signature(forward_func) if not isinstance(pred_dict, dict): raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") # loss check try: loss = losser(pred_dict, batch_y) # check loss output if batch_count == 0: if not isinstance(loss, torch.Tensor): raise TypeError( f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, " f"but got `{type(loss)}`.") if len(loss.size()) != 0: raise ValueError( f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, " f"should be torch.size([])") loss.backward() except _CheckError as e: # TODO: another error raised if _CheckError caught pre_func_signature = _get_func_signature(forward_func) _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, dataset=dataset, check_level=check_level) model.zero_grad() if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: break if dev_data is not None: tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, batch_size=batch_size, verbose=-1, use_tqdm=False) evaluate_results = tester.test() _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) def _check_eval_results(metrics, metric_key, metric_list): # metrics: tester返回的结果 # metric_key: 一个用来做筛选的指标,来自Trainer的初始化 # metric_list: 多个用来做评价的指标,来自Trainer的初始化 if isinstance(metrics, tuple): loss, metrics = metrics if isinstance(metrics, dict): metric_dict = list(metrics.values())[0] # 取第一个metric if metric_key is None: indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] else: # metric_key is set if metric_key not in 
metric_dict: raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") indicator_val = metric_dict[metric_key] indicator = metric_key else: raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) return indicator, indicator_val