Before we start extending anything, let's understand how fairseq training works for seq2seq models. In this tutorial I will only use the hydra-train module, which makes it possible to load YAML configs.
Follow the installation guide on the GitHub page.
Suppose you want to train a translation model using hydra training. The command syntax is the following:
uname@host ~ $ fairseq-hydra-train --config-dir . --config-name config.yaml
Obviously, you need to prepare config.yaml first. All missing parameters will be replaced by defaults. The defaults can be found in fairseq/dataclass/configs.py, where you can also see the available parameters. Here is an example config.yaml for a translation task:
# @package _group_
task:
  _name: translation
  data: fairseq-bin-data
  source_lang: ro
  target_lang: en
  eval_bleu: true
  eval_bleu_args: '{"beam":1}'
  eval_bleu_detok: "moses"
  eval_bleu_remove_bpe: subword_nmt
  eval_bleu_print_samples: false
criterion:
  _name: cross_entropy
model:
  _name: transformer_base
  decoder:
    output_dim: 100
    learned_pos: true
  encoder:
    learned_pos: true
  dropout: 0.3
optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
lr_scheduler:
  _name: inverse_sqrt
  warmup_updates: 10000
  warmup_init_lr: 1e-07
dataset:
  max_tokens: 4000
  validate_interval_updates: 2000
optimization:
  lr: [0.0005]
  update_freq: [16]
  max_update: 50000
  stop_min_lr: 1e-09
checkpoint:
  best_checkpoint_metric: bleu
  maximize_best_checkpoint_metric: true
common:
  wandb_project: ?
  log_format: simple
  log_interval: 100
You can replace or add parameters on the fly while keeping the base config. If a parameter is not in the base config, prefix it with + to add it:
#checkpoint.no_save is not in config.yaml so add +
#common.wandb_project is in base config, so no need to add +
uname@host ~ $ fairseq-hydra-train --config-dir . --config-name config.yaml common.wandb_project=myproject +checkpoint.no_save=True
Look at the diagram to see what happens when you call training! (warning: it's huge).
The easiest way to extend fairseq is to fork the fairseq repository and add your own files (model/criterion/task, etc.). Register each new component with the corresponding decorator, giving it a unique name and its config dataclass:
@register_task(name, dataclass=TaskConfigClass)
@register_model(name, dataclass=ModelConfigClass)
@register_criterion(name, dataclass=CriterionConfigClass)
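For instance, here is a minimal sketch of a task registration. It assumes a recent fairseq where the translation task exposes TranslationConfig; the names MyNewTask, MyNewTaskConfig and my_task_param are made up for illustration:
#minimal sketch: register a custom task that reuses the existing translation task
from dataclasses import dataclass, field

from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationConfig, TranslationTask


@dataclass
class MyNewTaskConfig(TranslationConfig):
    #hypothetical extra parameter, exposed as task.my_task_param in the config
    my_task_param: int = field(
        default=1,
        metadata={"help": "an extra knob for my task"},
    )


@register_task("my_new_task", dataclass=MyNewTaskConfig)
class MyNewTask(TranslationTask):
    #everything else is inherited from TranslationTask; override methods as needed
    pass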
New models go under the fairseq/models/ folder. E.g., we create fairseq/models/my_new_fancy_model.py with a model that extends transformer_base:
#imports needed for the model file
from dataclasses import dataclass, field
from typing import Optional

from fairseq.models import register_model
from fairseq.models.transformer import TransformerConfig, TransformerModelBase

#dataclass defines all additional params for the model
@dataclass
class MyNewFancyModelConfig(TransformerConfig):
    new_param: Optional[str] = field(
        default=None,
        metadata={"help": "It will be needed in my model"},
    )

#very important! don't forget to register your model with a unique name
@register_model("fancy_model", dataclass=MyNewFancyModelConfig)
class MyNewFancyModel(TransformerModelBase):
    def __init__(self, cfg, encoder, decoder):
        super().__init__(cfg, encoder, decoder)
        self.cfg = cfg
        #now you can access the params from config
        #e.g. you want to access new_param from MyNewFancyModelConfig
        #self.cfg.new_param voila!
    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens):
        decoder = MyNewFancyDecoderBecauseIcan(
            cfg,
            tgt_dict,
            embed_tokens,
        )
        return decoder

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = True,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        #all magic of your model happens here
        pass
To train with your new model, point the config at it:
model._name=fancy_model
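For example, the model section of config.yaml could then look like this (the value of new_param is just a placeholder):
model:
  _name: fancy_model
  new_param: some_value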
Similarly to the model, you can create a new criterion under fairseq/criterions:
from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyNewCriterionConfig(FairseqDataclass):
    pass


@register_criterion("my_new_criterion", dataclass=MyNewCriterionConfig)
class MyNewCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        #loss calculation happens here
        #loss, sample_size, logging_output = f()
        return loss, sample_size, logging_output
And select it in the config:
criterion._name=my_new_criterion
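To make the skeleton concrete, here is a minimal sketch of a cross-entropy-style body for MyNewCriterion.forward. It assumes standard fairseq seq2seq batches with net_input, target and ntokens, and mirrors the stock cross_entropy criterion rather than anything model-specific:
#sketch: one possible body for MyNewCriterion.forward
import torch.nn.functional as F

def forward(self, model, sample, reduce=True):
    #run the model on the batch
    net_output = model(**sample["net_input"])
    #log-probabilities over the vocabulary, flattened to (batch*tgt_len, vocab)
    lprobs = model.get_normalized_probs(net_output, log_probs=True)
    lprobs = lprobs.view(-1, lprobs.size(-1))
    target = model.get_targets(sample, net_output).view(-1)
    #negative log-likelihood, ignoring padding positions
    loss = F.nll_loss(
        lprobs,
        target,
        ignore_index=self.padding_idx,
        reduction="sum" if reduce else "none",
    )
    sample_size = sample["ntokens"]
    logging_output = {
        "loss": loss.data,
        "ntokens": sample["ntokens"],
        "nsentences": sample["target"].size(0),
        "sample_size": sample_size,
    }
    return loss, sample_size, logging_output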
Don't forget to add an import for your new class to the __init__ file of the respective module.
So you decided to create your own lib built on top of fairseq. Good luck (/jk). There are several reasons you might want to do it.
To make your own lib on top of fairseq, you mainly need to replace the fairseq imports with your own. That might require "monkeypatching".
Example of a patched hydra-train:
"""
First import main hydra-train module from fairseq.
And all modules needed from your own lib
"""
import fairseq_cli.hydra_train as fairseq_hydra_train
from cdgm_textgen.dataclass.utils import omegaconf_no_object_check as cdgm_omegaconf_no_object_check
from cdgm_textgen.dataclass.configs import CDGMTextgenConfig
from cdgm_textgen.dataclass.initialize import add_defaults as cdgm_add_defaults
from cdgm_textgen.dataclass.initialize import hydra_init as cdgm_hydra_init
from cdgm_textgen_cli.train import main as cdgm_pre_main
"""
"replacing" functions/modules of hydra_train
"""
fairseq_hydra_train.omegaconf_no_object_check = cdgm_omegaconf_no_object_check
fairseq_hydra_train.FairseqConfig = CDGMTextgenConfig
fairseq_hydra_train.add_defaults = cdgm_add_defaults
fairseq_hydra_train.hydra_init = cdgm_hydra_init
fairseq_hydra_train.pre_main = cdgm_pre_main
#call original functions
def cli_main():
    fairseq_hydra_train.cli_main()

if __name__ == "__main__":
    fairseq_hydra_train.cli_main()
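If you want your lib to expose its own command (the way fairseq exposes fairseq-hydra-train), one option is a console_scripts entry point in setup.py. This is just a sketch; the package name and the module path cdgm_textgen_cli.hydra_train are assumptions based on the example above:
#sketch of a setup.py entry point exposing the patched trainer as a CLI command
#(package/module names follow the example above and are assumptions)
from setuptools import setup, find_packages

setup(
    name="cdgm_textgen",
    packages=find_packages(),
    entry_points={
        "console_scripts": [
            "cdgm-hydra-train = cdgm_textgen_cli.hydra_train:cli_main",
        ],
    },
)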
You can still inherit from fairseq modules when needed. E.g., a new criterion with cosine similarity inherits from CrossEntropyCriterion:
import math
from dataclasses import field
from typing import Optional
import torch
import torch.nn.functional as F
from torch.distributions import multivariate_normal
from dataclasses import dataclass
from fairseq import metrics, utils
from fairseq.criterions import register_criterion, FairseqCriterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig
import power_spherical
from fairseq.dataclass import FairseqDataclass
@register_criterion("cosine_ar_criterion")
class CosineARCriterion(CrossEntropyCriterion):
def forward(self, model, sample, reduce=True):
pass
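As an illustration of what this forward could look like, here is a sketch under my own assumptions: the model is run with features_only=True so the decoder returns continuous vectors, and those are compared against the target token embeddings. It reuses F and self.padding_idx from the imports and base class above and is not the author's actual implementation:
    #sketch: cosine-distance loss between decoder features and target embeddings
    def forward(self, model, sample, reduce=True):
        #decoder output as continuous features instead of vocabulary logits
        net_output = model(**sample["net_input"], features_only=True)
        features = net_output[0]                              #(batch, tgt_len, dim)
        #embeddings of the gold target tokens
        target = sample["target"]
        tgt_emb = model.decoder.embed_tokens(target)          #(batch, tgt_len, dim)
        #cosine distance per position, with padding masked out
        cos = F.cosine_similarity(features, tgt_emb, dim=-1)  #(batch, tgt_len)
        pad_mask = target.ne(self.padding_idx)
        loss = ((1.0 - cos) * pad_mask).sum()
        sample_size = sample["ntokens"]
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": target.size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output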