Source code for synapse.ml.dl.LitDeepTextModel

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import inspect

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.utilities import _module_available

# Resolve the ``transformers`` dependency at import time. This module cannot
# work without it, and only the exact release 4.32.1 is accepted.
_TRANSFORMERS_AVAILABLE = _module_available("transformers")
if not _TRANSFORMERS_AVAILABLE:
    raise ModuleNotFoundError("module not found: transformers")

import transformers

# Exact-equality pin: both older and newer releases are rejected.
_TRANSFORMERS_EQUAL_4_32_1 = transformers.__version__ == "4.32.1"
if not _TRANSFORMERS_EQUAL_4_32_1:
    raise RuntimeError(
        "transformers should be == 4.32.1, found: {}".format(
            transformers.__version__
        )
    )

from transformers import AutoModelForSequenceClassification


class LitDeepTextModel(pl.LightningModule):
    """PyTorch Lightning wrapper around a Hugging Face sequence-classification model.

    The wrapped model is loaded with
    ``AutoModelForSequenceClassification.from_pretrained``. Training,
    validation and test steps all use the loss the model itself returns
    (``outputs.loss``).
    """

    def __init__(
        self,
        checkpoint,
        text_col,
        label_col,
        num_labels,
        additional_layers_to_train,
        optimizer_name,
        loss_name,
        learning_rate=None,
        train_from_scratch=True,
    ):
        """
        :param checkpoint: Checkpoint for pre-trained model. This is expected to
            be a checkpoint you could find on
            [HuggingFace](https://huggingface.co/models) and is of type
            `AutoModelForSequenceClassification`.
        :param text_col: Text column name.
        :param label_col: Label column name.
        :param num_labels: Number of labels for classification.
        :param additional_layers_to_train: Additional number of layers to train
            on. For Deep text model it is better to choose a positive number for
            better performance. Must be non-negative.
        :param optimizer_name: Name of a `torch.optim` optimizer class
            (case-insensitive).
        :param loss_name: Name of a `torch.nn.functional` function
            (case-insensitive).
        :param learning_rate: Learning rate for the optimizer. If falsy (None or
            0), it defaults to 5e-5 for checkpoints whose name contains "bert"
            (case-insensitive) and to 0.01 otherwise.
        :param train_from_scratch: Whether to train the model from scratch or
            not. If this is set to true then the additional_layers_to_train
            param will be ignored. Defaults to True.
        :raises ValueError: if the checkpoint cannot be loaded or the optimizer
            / loss name is unknown.
        """
        super().__init__()
        self.checkpoint = checkpoint
        self.text_col = text_col
        self.label_col = label_col
        self.num_labels = num_labels
        self.additional_layers_to_train = additional_layers_to_train
        self.optimizer_name = optimizer_name
        self.loss_name = loss_name
        self.learning_rate = learning_rate
        self.train_from_scratch = train_from_scratch
        self._check_params()
        self.save_hyperparameters(
            "checkpoint",
            "text_col",
            "label_col",
            "num_labels",
            "additional_layers_to_train",
            "optimizer_name",
            "loss_name",
            "learning_rate",
            "train_from_scratch",
        )

    def _check_params(self):
        """Load the model and resolve the default learning rate, loss and optimizer.

        Sets ``self.model``, ``self.learning_rate`` (if it was unset),
        ``self.loss_fn`` and ``self.optimizer_fn``.

        :raises ValueError: if the checkpoint cannot be loaded, or if
            ``loss_name`` / ``optimizer_name`` is not found in
            ``torch.nn.functional`` / ``torch.optim``.
        """
        try:
            # TODO: Add other types of models here
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.checkpoint, num_labels=self.num_labels
            )
        except Exception as err:
            # Chain the original error so the real cause stays in the traceback.
            raise ValueError(
                f"No checkpoint {self.checkpoint} found: {err=}, {type(err)=}"
            ) from err
        self._update_learning_rate()

        loss_key = self.loss_name.lower()
        if loss_key not in F.__dict__:
            raise ValueError("No loss function: {} found".format(self.loss_name))
        # NOTE: _step uses the model's own outputs.loss; loss_fn is only stored.
        self.loss_fn = F.__dict__[loss_key]

        # Lower-cased class name -> class, for every torch.optim.Optimizer
        # subclass, so that optimizer_name lookups are case-insensitive.
        optimizers_mapping = {
            key.lower(): value
            for key, value in optim.__dict__.items()
            if inspect.isclass(value) and issubclass(value, optim.Optimizer)
        }
        optimizer_key = self.optimizer_name.lower()
        if optimizer_key not in optimizers_mapping:
            raise ValueError("No optimizer: {} found".format(self.optimizer_name))
        self.optimizer_fn = optimizers_mapping[optimizer_key]

    def forward(self, **inputs):
        """Run the wrapped Hugging Face model on keyword ``inputs``."""
        return self.model(**inputs)

    def configure_optimizers(self):
        """Apply the freezing policy and build the optimizer.

        When ``train_from_scratch`` is False, every ``base_model`` parameter is
        frozen. Then the last ``additional_layers_to_train`` parameterised
        children of ``base_model`` are unfrozen. Parameters outside
        ``base_model`` (presumably the classification head) are not touched.
        Only parameters with ``requires_grad`` are passed to the optimizer.
        """
        if not self.train_from_scratch:
            # Freeze those weights
            for p in self.model.base_model.parameters():
                p.requires_grad = False
            self._fine_tune_layers()
        params_to_update = [p for p in self.model.parameters() if p.requires_grad]
        # Pass lr by keyword rather than relying on its positional slot.
        return self.optimizer_fn(params_to_update, lr=self.learning_rate)

    def _fine_tune_layers(self):
        """Unfreeze the last ``additional_layers_to_train`` children of ``base_model``.

        Children are visited from last to first, down to and including the
        first child. Children without parameters are skipped and do not count
        towards the requested number.

        :raises ValueError: if ``additional_layers_to_train`` is negative.
        """
        if self.additional_layers_to_train < 0:
            raise ValueError(
                "additional_layers_to_train has to be non-negative: {} found".format(
                    self.additional_layers_to_train
                )
            )
        remaining = self.additional_layers_to_train
        # base_model contains the real model to fine tune
        for child in reversed(list(self.model.base_model.children())):
            if remaining <= 0:
                break
            params = list(child.parameters())
            # only tune those layers that contain parameters
            if not params:
                continue
            for p in params:
                p.requires_grad = True
            remaining -= 1

    def _update_learning_rate(self):
        """Fill in a default learning rate when none (or 0) was supplied."""
        ## TODO: add more default values for different models
        if not self.learning_rate:
            # str() tolerates path-like checkpoints; lower() also catches
            # names such as "...PubMedBERT...".
            if "bert" in str(self.checkpoint).lower():
                self.learning_rate = 5e-5
            else:
                self.learning_rate = 0.01

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._step(batch, False)
        self.log("train_loss", loss)
        return loss

    def _step(self, batch, validation):
        """Forward ``batch`` (a mapping of model kwargs) and return the model's loss.

        ``validation`` is currently unused.
        """
        outputs = self(**batch)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        Returns ``{"val_loss": loss}`` so that ``validation_epoch_end`` can
        average the per-batch losses.
        """
        loss = self._step(batch, True)
        self.log("val_loss", loss)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch ``val_loss`` values (inf if there are none)."""
        avg_loss = (
            torch.stack([x["val_loss"] for x in outputs]).mean()
            if len(outputs) > 0
            else float("inf")
        )
        self.log("avg_val_loss", avg_loss)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        loss = self._step(batch, False)
        self.log("test_loss", loss)
        return loss