
Extending Fairseq: Incomplete Guide

Oct 17 2022 · 4 min read
#fairseq #nlp #machine-learning

Table of contents

  1. Fairseq How To
  2. Easy Mode
  3. Not So Easy Mode

Fairseq How To

Before we start extending fairseq, let's look at how fairseq training works for seq2seq models. In this tutorial I use only the hydra-train entry point (fairseq-hydra-train), because it can load YAML configs.

Install fairseq

Follow the installation guide on the fairseq GitHub page.

Training with Hydra

Suppose you want to train a translation model with Hydra. The command syntax is:

$ fairseq-hydra-train --config-dir . --config-name config.yaml

First you need to prepare config.yaml. Any parameter you leave out is replaced by its default. The defaults, and the full list of available parameters, are in fairseq/dataclass/configs.py.
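To make the fallback concrete, here is a toy sketch with plain Python dataclasses. ToyOptimizationConfig, ToyCheckpointConfig and build_config are hypothetical stand-ins, not the real classes in fairseq/dataclass/configs.py, and their default values are only illustrative. Every field has a default, so any key missing from the user's values keeps it.

```python
from dataclasses import dataclass, field, fields


# Hypothetical stand-ins for two config groups (illustrative defaults only).
@dataclass
class ToyOptimizationConfig:
    max_update: int = 0
    lr: list = field(default_factory=lambda: [0.25])


@dataclass
class ToyCheckpointConfig:
    no_save: bool = False
    best_checkpoint_metric: str = "loss"


@dataclass
class ToyConfig:
    optimization: ToyOptimizationConfig = field(default_factory=ToyOptimizationConfig)
    checkpoint: ToyCheckpointConfig = field(default_factory=ToyCheckpointConfig)


def build_config(user_values: dict) -> ToyConfig:
    """Build each group from user_values; keys the user omits keep their defaults."""
    groups = {}
    for group in fields(ToyConfig):
        given = user_values.get(group.name, {})
        # group.type is the group's dataclass; an unknown key raises TypeError here
        groups[group.name] = group.type(**given)
    return ToyConfig(**groups)


cfg = build_config({"optimization": {"max_update": 50000}})
print(cfg.optimization.max_update)  # 50000 (set by the user)
print(cfg.optimization.lr)  # [0.25] (default)
print(cfg.checkpoint.best_checkpoint_metric)  # loss (default)
```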

Translation Config Example
# @package _group_
task:
  _name: translation
  data: fairseq-bin-data
  source_lang: ro
  target_lang: en
  eval_bleu: true
  eval_bleu_args: '{"beam":1}'
  eval_bleu_detok: "moses"
  eval_bleu_remove_bpe: subword_nmt
  eval_bleu_print_samples: false
criterion:
  _name: cross_entropy
model:
  _name: transformer_base
  decoder:
    output_dim: 100
    learned_pos: true
  encoder:
    learned_pos: true
  dropout: 0.3
optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
lr_scheduler:
  _name: inverse_sqrt
  warmup_updates: 10000
  warmup_init_lr: 1e-07
dataset:
  max_tokens: 4000
  validate_interval_updates: 2000
optimization:
  lr: [0.0005]
  update_freq: [16]
  max_update: 50000
  stop_min_lr: 1e-09
checkpoint:
  best_checkpoint_metric: bleu
  maximize_best_checkpoint_metric: true
common:
  wandb_project: ?
  log_format: simple
  log_interval: 100

You can override or add parameters on the fly while keeping the base config. If a parameter is not in the base config, prefix it with +:

# checkpoint.no_save is not in config.yaml, so prefix it with +
# common.wandb_project is already in the base config, so no + is needed
$ fairseq-hydra-train --config-dir . --config-name config.yaml common.wandb_project=myproject +checkpoint.no_save=True
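The + rule can be modelled in a few lines. This is a toy re-implementation of the behaviour described above using nested dicts, not Hydra's actual override parser. parse_value and apply_override are hypothetical helpers, and the sketch additionally refuses + for a key that already exists.

```python
def parse_value(text: str):
    """Convert 'True'/'false', ints and floats; leave anything else as a string."""
    if text.lower() in ("true", "false"):
        return text.lower() == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'group.key=value' or '+group.key=value' override in place."""
    add = override.startswith("+")
    path, _, raw = (override[1:] if add else override).partition("=")
    *parents, leaf = path.split(".")
    node = cfg
    for name in parents:
        # plain overrides must follow existing groups; '+' may create them
        node = node.setdefault(name, {}) if add else node[name]
    if add and leaf in node:
        raise KeyError(f"{path} is already in the config; drop the '+'")
    if not add and leaf not in node:
        raise KeyError(f"{path} is not in the base config; prefix it with '+'")
    node[leaf] = parse_value(raw)


base = {
    "common": {"wandb_project": "?", "log_interval": 100},
    "checkpoint": {"best_checkpoint_metric": "bleu"},
}
apply_override(base, "common.wandb_project=myproject")
apply_override(base, "+checkpoint.no_save=True")
print(base["common"]["wandb_project"])  # myproject
print(base["checkpoint"]["no_save"])  # True
```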

Fairseq Training Flow

The diagram below shows what happens when you start training (warning: it's huge).

[Figure: Fairseq Training Flow]

Extension Easy Mode

The easiest way to extend fairseq is to fork the fairseq repository and add your own files (model, criterion, task, etc.). Each new component is registered with the matching decorator:

@register_task(name, dataclass=TaskConfigClass)
@register_model(name, dataclass=ModelConfigClass)
@register_criterion(name, dataclass=CriterionConfigClass)
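All three decorators follow the same registry pattern. Each one maps a unique name to a class (and to its config dataclass), so the _name field in the config can select that class. Below is a simplified, hypothetical re-creation of that pattern in plain Python. It mimics the shape of fairseq's register_model but is not its implementation.

```python
import dataclasses
from typing import Callable, Dict, Optional

# Toy name -> class tables (hypothetical; not fairseq's actual code).
MODEL_REGISTRY: Dict[str, type] = {}
MODEL_DATACLASS_REGISTRY: Dict[str, type] = {}


def register_model(name: str, dataclass: Optional[type] = None) -> Callable[[type], type]:
    """Return a class decorator that records cls (and its config) under name."""
    def wrapper(cls: type) -> type:
        if name in MODEL_REGISTRY:
            raise ValueError(f"Cannot register duplicate model ({name})")
        MODEL_REGISTRY[name] = cls
        if dataclass is not None:
            MODEL_DATACLASS_REGISTRY[name] = dataclass
        return cls  # the class itself is returned unchanged

    return wrapper


def build_model(model_cfg: dict):
    """Look up the class named by model_cfg['_name'] and build it from its config."""
    name = model_cfg["_name"]
    params = {k: v for k, v in model_cfg.items() if k != "_name"}
    return MODEL_REGISTRY[name](MODEL_DATACLASS_REGISTRY[name](**params))


@dataclasses.dataclass
class FancyConfig:
    new_param: str = "default"


@register_model("fancy_model", dataclass=FancyConfig)
class FancyModel:
    def __init__(self, cfg: FancyConfig):
        self.cfg = cfg


model = build_model({"_name": "fancy_model", "new_param": "hello"})
print(type(model).__name__, model.cfg.new_param)  # FancyModel hello
```

Registration happens when the decorator runs. A class is therefore only registered once the module that defines it has been imported.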

Add new model

  1. Add a file under the fairseq/models/ folder, e.g. fairseq/models/my_new_fancy_model.py.
  2. Open the file in your editor and add the model code. Let's assume you extend transformer_base:
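from dataclasses import dataclass, field
from typing import Optional

from fairseq.models import register_model
from fairseq.models.transformer import TransformerConfig, TransformerModelBase

# dataclass defines all additional params for the model
@dataclass
class MyNewFancyModelConfig(TransformerConfig):
    new_param: Optional[str] = field(
        default=None,  # inherited fields have defaults, so new ones need one too
        metadata={"help": "It will be needed in my model"},
    )

# very important! don't forget to register your model with a unique name
@register_model("fancy_model", dataclass=MyNewFancyModelConfig)
class MyNewFancyModel(TransformerModelBase):
    def __init__(self, cfg, encoder, decoder):
        # TransformerModelBase takes cfg first and stores it as self.cfg
        super().__init__(cfg, encoder, decoder)
        # now you can access the params from config,
        # e.g. new_param from MyNewFancyModelConfig:
        # self.cfg.new_param, voila!

    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens):
        # MyNewFancyDecoderBecauseIcan is your own decoder class (not shown);
        # the arguments below are assumed to mirror TransformerDecoderBase
        decoder = MyNewFancyDecoderBecauseIcan(cfg, tgt_dict, embed_tokens)
        return decoder

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = True,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # all magic of your model happens here, e.g. a plain encoder-decoder pass
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        return self.decoder(
            prev_output_tokens, encoder_out=encoder_out, features_only=features_only
        )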
  3. Select the new model by passing model._name=fancy_model
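A new field in a config subclass generally needs a default. Fields inherited from a fairseq config carry defaults, and Python dataclasses reject a field without a default that follows fields with defaults. The toy below uses a hypothetical ToyTransformerConfig in place of fairseq's TransformerConfig to show both the working case and the failing one.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class ToyTransformerConfig:
    # hypothetical stand-in for a parent config: every field has a default
    dropout: float = 0.1
    encoder_layers: int = 6


@dataclass
class ToyFancyConfig(ToyTransformerConfig):
    # a new field with a default is accepted, and the parent fields are inherited
    new_param: Optional[str] = field(
        default=None, metadata={"help": "It will be needed in my model"}
    )


cfg = ToyFancyConfig(new_param="x")
print([f.name for f in fields(cfg)])  # ['dropout', 'encoder_layers', 'new_param']
print(cfg.dropout)  # 0.1

# Without a default, defining the subclass itself fails.
try:
    @dataclass
    class BrokenConfig(ToyTransformerConfig):
        new_param: str  # no default after defaulted parent fields
    broken_definition_failed = False
except TypeError as err:
    broken_definition_failed = True
    print("TypeError:", err)
```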

Add new criterion

As with the model, you can create a new criterion under fairseq/criterions/.

  1. Create your criterion file
from dataclasses import dataclass

from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyNewCriterionConfig(FairseqDataclass):
    # criterion-specific hyperparameters go here as dataclass fields
    pass


@register_criterion("my_new_criterion", dataclass=MyNewCriterionConfig)
class MyNewCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        # loss calculation happens here; the trainer expects the tuple
        # (loss, sample_size, logging_output), e.g.
        # loss, sample_size, logging_output = f(model, sample)
        raise NotImplementedError("compute loss, sample_size, logging_output")
  2. Implement the loss calculation
  3. Run training with criterion._name=my_new_criterion
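The following sketch shows what step 2 has to produce, using a token-level negative log-likelihood in plain Python. toy_nll_criterion is a hypothetical helper that works on lists instead of tensors. It only illustrates the (loss, sample_size, logging_output) contract and how padding positions are skipped. fairseq's built-in cross_entropy criterion follows a similar pattern on tensors inside forward.

```python
import math
from typing import Dict, List, Tuple


def toy_nll_criterion(
    lprobs: List[List[float]], target: List[int], padding_idx: int
) -> Tuple[float, int, Dict[str, float]]:
    """Summed negative log-likelihood over non-padding positions.

    lprobs[t][v] is the log-probability of vocabulary item v at position t.
    Returns (loss, sample_size, logging_output), like a criterion's forward.
    """
    loss = 0.0
    ntokens = 0
    for row, tok in zip(lprobs, target):
        if tok == padding_idx:
            continue  # padded positions contribute nothing
        loss -= row[tok]
        ntokens += 1
    sample_size = ntokens  # what the summed loss is normalized by (tokens here)
    logging_output = {"loss": loss, "ntokens": ntokens, "sample_size": sample_size}
    return loss, sample_size, logging_output


# Two real tokens followed by one padding token (padding_idx = 1).
lprobs = [
    [math.log(0.5), math.log(0.25), math.log(0.25)],
    [math.log(0.1), math.log(0.1), math.log(0.8)],
    [math.log(1 / 3)] * 3,
]
target = [0, 2, 1]
loss, sample_size, logging_output = toy_nll_criterion(lprobs, target, padding_idx=1)
print(sample_size)  # 2
print(round(loss, 4))  # 0.9163, i.e. -(log 0.5 + log 0.8)
```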


  1. You can add tasks and other components in a similar way.
  2. Don't forget to also import your new model/criterion in the __init__.py file of the respective module.

Extension Not So Easy Mode

So you decided to create your own library built on top of fairseq. Good luck (/jk). There are several reasons you might want to do this.

To build your own library on top of fairseq, you mainly need to replace fairseq's imports with your own. That may require "monkeypatching".
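Monkeypatching works here because a function resolves module-level names in its module's globals when it is called, not when it is defined. The toy below is plain Python, with a made-up toy_cli module standing in for fairseq_cli.hydra_train. It shows both sides: rebinding the module attribute changes what cli_main calls, while references taken earlier keep pointing at the original.

```python
import types

# A throwaway module standing in for fairseq_cli.hydra_train (toy, hypothetical).
toy_cli = types.ModuleType("toy_cli")
exec(
    "def hydra_init():\n"
    "    return 'library hydra_init'\n"
    "\n"
    "def cli_main():\n"
    "    return hydra_init()  # global lookup happens at call time\n",
    toy_cli.__dict__,
)

# Grab a reference the way `from toy_cli import hydra_init` would.
early_reference = toy_cli.hydra_init


def my_hydra_init():
    return "my hydra_init"


print(toy_cli.cli_main())  # library hydra_init
toy_cli.hydra_init = my_hydra_init  # the monkeypatch: rebind the module attribute
print(toy_cli.cli_main())  # my hydra_init
print(early_reference())  # library hydra_init (earlier references are untouched)
```

In practice, the patch has to target the module whose functions do the lookup, and it must happen before that code runs.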

Example of a patched hydra_train:

First, import the main hydra_train module from fairseq, along with all the modules you need from your own library:
import fairseq_cli.hydra_train as fairseq_hydra_train 
from cdgm_textgen.dataclass.utils import omegaconf_no_object_check as cdgm_omegaconf_no_object_check
from cdgm_textgen.dataclass.configs import CDGMTextgenConfig
from cdgm_textgen.dataclass.initialize import add_defaults as cdgm_add_defaults
from cdgm_textgen.dataclass.initialize import hydra_init as cdgm_hydra_init
from cdgm_textgen_cli.train import main as cdgm_pre_main

# "replace" functions/modules of hydra_train with your own versions
fairseq_hydra_train.omegaconf_no_object_check = cdgm_omegaconf_no_object_check
fairseq_hydra_train.FairseqConfig = CDGMTextgenConfig
fairseq_hydra_train.add_defaults = cdgm_add_defaults
fairseq_hydra_train.hydra_init = cdgm_hydra_init
fairseq_hydra_train.pre_main = cdgm_pre_main

# call the original entry point; it looks up the module-level names patched above
def cli_main():
    fairseq_hydra_train.cli_main()


if __name__ == "__main__":
    cli_main()

import math
from dataclasses import dataclass, field
from typing import Optional

import power_spherical
import torch
import torch.nn.functional as F
from torch.distributions import multivariate_normal

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig
from fairseq.dataclass import FairseqDataclass

class CosineARCriterion(CrossEntropyCriterion):
    def forward(self, model, sample, reduce=True):