After hours of research and attempts to understand all of the parts required to train a custom BERT-like model from scratch using HuggingFace’s Transformers library, I came to the conclusion that existing blog posts and notebooks are really vague and either do not cover important parts or skip them as if they weren’t there. I will point out a few examples as we go through the post.

I’ve decided to get my hands dirty and train a transformer language model (BERT or another architecture) for the Polish language. Here you will find all of the missing pieces I encountered along the way.

Many thanks to Egnyte Inc. for providing me with the resources to run the training on a powerful machine with 4x NVIDIA P100 GPUs.

TL;DR

This post covers:

  • what mistakes to avoid
  • how to prepare your own dataset, in detail - I use the Polish Wikipedia and Wolne Lektury datasets
  • how to prepare the training configuration
  • how to train new tokenizers using HuggingFace’s Rust tokenizers
  • how to use the run_language_modeling.py script from the HuggingFace library to train your model - which parameters to set and how to adapt the script
  • how to monitor training process
  • how to share the model on HuggingFace’s model hub

Getting the dataset

The most important part of working with language models is having a solid dataset of text in the language you will be modeling. In my proof of concept I started with Polish Wikipedia and then extended the dataset with Wolne Lektury (a repository of public domain Polish books and poems).

Preparing Polish Wikipedia dump

There is a “just get Wikipedia” step that no one covers in detail, but I find it important. First of all, the Wikipedia dumps page offers many files you can download. I decided to download the dump from 2020-02-20 labeled as:

Articles, templates, media/file descriptions, and primary meta-pages, in multiple bz2 streams, 100 pages per stream

which consisted of 7 files with names like plwiki-20200220-pages-articles-multistream*.xml*.bz2. Why? This kind of split simplifies preprocessing - I don’t need lots of RAM to process smaller files. Another benefit is that I can hold out one or two of those files for the validation phase after the language model training is complete.

Each file is a huge XML file containing articles in MediaWiki-specific markup. I used this tool to extract well-formatted JSON from them: https://github.com/attardi/wikiextractor (clone the repo).

So now I have a list of files for both training and evaluation:

    # train.txt
    plwiki-20200220-pages-articles-multistream1.xml-p1p169750
    plwiki-20200220-pages-articles-multistream2.xml-p169751p510662
    plwiki-20200220-pages-articles-multistream3.xml-p510663p1056310
    plwiki-20200220-pages-articles-multistream4.xml-p1056311p1831508
    plwiki-20200220-pages-articles-multistream5.xml-p1831509p3070393
    plwiki-20200220-pages-articles-multistream6.xml-p4570394p4720470

and

    # eval.txt
    plwiki-20200220-pages-articles-multistream6.xml-p3070394p4570393

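To extract them, run:

cat train.txt | xargs -I@ python wikiextractor/WikiExtractor.py @ --bytes=100M --json --output="./plwiki-json/train/@"
cat eval.txt | xargs -I@ python wikiextractor/WikiExtractor.py @ --bytes=100M --json --output="./plwiki-json/eval/@"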

After extracting you will get the following structure:

├── eval
│   └── plwiki-20200220-pages-articles-multistream6.xml-p3070394p4570393
│       └── AA
│           ├── wiki_00
│           ├── wiki_01
│           ├── wiki_02
│           └── wiki_03
└── train
    ├── plwiki-20200220-pages-articles-multistream1.xml-p1p169750
    │   └── AA
    │       ├── wiki_00
    │       ├── wiki_01
    │       └── wiki_02
    ├── plwiki-20200220-pages-articles-multistream2.xml-p169751p510662
    │   └── AA
    │       ├── wiki_00
    │       ├── wiki_01
    │       └── wiki_02
    ├── plwiki-20200220-pages-articles-multistream3.xml-p510663p1056310
    │   └── AA
    │       ├── wiki_00
    │       ├── wiki_01
    │       └── wiki_02
    ├── plwiki-20200220-pages-articles-multistream4.xml-p1056311p1831508
    │   └── AA
    │       ├── wiki_00
    │       ├── wiki_01
    │       ├── wiki_02
    │       └── wiki_03
    ├── plwiki-20200220-pages-articles-multistream5.xml-p1831509p3070393
    │   └── AA
    │       ├── wiki_00
    │       ├── wiki_01
    │       ├── wiki_02
    │       └── wiki_03
    └── plwiki-20200220-pages-articles-multistream6.xml-p4570394p4720470
        └── AA
            └── wiki_00

Extracting text from wikipedia

After you get JSON lines from the Wikipedia dump, you need to transform the articles stored in JSON into plain text that the language model training script can consume.
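A minimal sketch of reading the extracted articles back, assuming WikiExtractor's --json mode writes files named wiki_NN with one JSON object per line and the article body under a "text" key (check this against your own output, as the format may differ between WikiExtractor versions). The helper name iter_articles is hypothetical.

```python
import json
from pathlib import Path
from typing import Iterator


def iter_articles(root: str) -> Iterator[str]:
    """Yield the text of every article in WikiExtractor --json output under `root`.

    Assumes files named wiki_NN (e.g. plwiki-json/train/<dump>/AA/wiki_00),
    each holding one JSON object per line with the article body under "text".
    """
    for path in sorted(Path(root).rglob("wiki_*")):
        if not path.is_file():
            continue
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                article = json.loads(line)
                # Collapse newlines and repeated spaces so one article fits on one line
                text = " ".join(article.get("text", "").split())
                if text:
                    yield text


# Usage (path from the directory tree above):
# for text in iter_articles("./plwiki-json/train"):
#     ...  # split into sentences, then write to the training file
```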

Important!

I probably made a mistake here, and it matters later on! Because the available resources on dataset preparation were ambiguous, I first split the text into sentences and wrote each sentence on a separate line. I think this hurts the quality of both tokenization and training. Do not make that mistake.

After digging through HuggingFace’s GitHub I finally got an answer: there should be one document per line. OK - but what if the text is really long? We will be dealing not only with long wiki articles but later on also with books that are thousands of characters long, while current SOTA language models have a limit of around 512-768 tokens (depending on the model). A possible solution is to use a sliding window over the sentences in the input text. For more details see the following thread: https://github.com/huggingface/transformers/issues/2693

To extract the text and split the input into sentences I used nltk with custom Polish abbreviations - thanks to Krzysztof Sopyła’s awesome gist: https://gist.github.com/ksopyla/f05fe2f48bbc9de895368b8a7863b5c3.

You can find my extraction notebook here: https://gist.github.com/marrrcin/e383b75a5d0dad42048847d97965e037. I left it out of this post to keep it uncluttered.

Key takeaways from preprocessing:

  • tokenize each article into sentences
  • use sliding window over sentences to handle long texts - I used sliding window of size 4
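The sliding window can be sketched as below. It assumes the sentences were already split (e.g. with nltk and the Polish abbreviations mentioned above) and uses a stride of one sentence, which is what the sample output that follows suggests: each line shifts by exactly one sentence. The function name sliding_window is hypothetical.

```python
from typing import Iterator, List


def sliding_window(sentences: List[str], window: int = 4) -> Iterator[str]:
    """Yield overlapping chunks of `window` consecutive sentences (stride 1).

    Every sentence lands in at least one chunk, so no text is dropped.
    Documents with at most `window` sentences become a single chunk.
    """
    if not sentences:
        return
    if len(sentences) <= window:
        yield " ".join(sentences)
        return
    for start in range(len(sentences) - window + 1):
        yield " ".join(sentences[start:start + window])


sentences = ["Zdanie 1.", "Zdanie 2.", "Zdanie 3.", "Zdanie 4.", "Zdanie 5."]
for line in sliding_window(sentences, window=4):
    print(line)
# Zdanie 1. Zdanie 2. Zdanie 3. Zdanie 4.
# Zdanie 2. Zdanie 3. Zdanie 4. Zdanie 5.
```

The overlap means most sentences appear in several training lines; that is the price for keeping each line short enough for the model while not cutting any context away.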

Once the extraction completes, you will have a training file like this:

Sumatra Północna Sumatra Północna (indonez. "Sumatera Utara") – prowincja w Indonezji w północnej części Sumatry. Obejmuje również wyspę Nias i wyspy Batu leżące na Oceanie Indyjskim. Powierzchnia 70 787 km²; 12 391 tys. mieszkańców (2005); stolica Medan.
"Sumatera Utara") – prowincja w Indonezji w północnej części Sumatry. Obejmuje również wyspę Nias i wyspy Batu leżące na Oceanie Indyjskim. Powierzchnia 70 787 km²; 12 391 tys. mieszkańców (2005); stolica Medan. Większą część powierzchni zajmują góry Barisan (Sinabung 2460 m n.p.m.) z licznymi wulkanami i jeziorem Toba.
Obejmuje również wyspę Nias i wyspy Batu leżące na Oceanie Indyjskim. Powierzchnia 70 787 km²; 12 391 tys. mieszkańców (2005); stolica Medan. Większą część powierzchni zajmują góry Barisan (Sinabung 2460 m n.p.m.) z licznymi wulkanami i jeziorem Toba. Bliżej wybrzeży nizinne tereny w dużym stopniu pokryte bagnami.
Powierzchnia 70 787 km²; 12 391 tys. mieszkańców (2005); stolica Medan. Większą część powierzchni zajmują góry Barisan (Sinabung 2460 m n.p.m.) z licznymi wulkanami i jeziorem Toba. Bliżej wybrzeży nizinne tereny w dużym stopniu pokryte bagnami. Ludność prowincji tworzą Malajowie wyznający islam, Batakowie będący w większości chrześcijanami (gł. protestantami) oraz przybysze z Jawy, Chin, Indii.
Większą część powierzchni zajmują góry Barisan (Sinabung 2460 m n.p.m.) z licznymi wulkanami i jeziorem Toba. Bliżej wybrzeży nizinne tereny w dużym stopniu pokryte bagnami. Ludność prowincji tworzą Malajowie wyznający islam, Batakowie będący w większości chrześcijanami (gł. protestantami) oraz przybysze z Jawy, Chin, Indii. Główne miasta: Medan, Binjai, Pematang Siantar, Tebing Tinggi

Prepare Wolne Lektury dump

To get Polish books I crawled the https://wolnelektury.pl/katalog/ site. It’s a pretty straightforward process - all of the books can be downloaded as well-formatted TXT files. The preprocessing steps are exactly the same as for Wikipedia: split each book into sentences and use a sliding window to turn long books into smaller chunks without dropping any data - this time with a window of 8 sentences. Gist of the preprocessing notebook: https://gist.github.com/marrrcin/bcc115fbadf79eba9d9c8ca711da9e20.

Training new tokenizer

At this point, I decided to go with the RoBERTa model. The model you choose determines the tokenizer you will have to train: for RoBERTa it’s a ByteLevelBPETokenizer, for BERT it would be a BertWordPieceTokenizer (both from the tokenizers library). Training the tokenizer is super fast thanks to the Rust implementation that the folks at HuggingFace have prepared (great job!).


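import os
from glob import glob

from tokenizers import ByteLevelBPETokenizer

# All plain-text training files (one document / sliding window per line)
paths = list(glob("your_input_dataset_files/*.txt"))

# Initialize a byte-level BPE tokenizer; keep casing (cased model)
tokenizer = ByteLevelBPETokenizer(lowercase=False)

# Customize training; the order of special_tokens determines their ids (0..4)
tokenizer.train(files=paths, vocab_size=32000, min_frequency=3, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
OUT_DIR = "polish_tokenizer_bpe_32k"
os.makedirs(OUT_DIR, exist_ok=True)
# tokenizers API of that time: writes pl-vocab.json and pl-merges.txt into OUT_DIR
# (newer releases use tokenizer.save_model(OUT_DIR, "pl") for this instead)
tokenizer.save(OUT_DIR, "pl")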

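Here, the special_tokens part is super important: the tokens are assigned ids in the order they are listed (<s> = 0, <pad> = 1, </s> = 2, <unk> = 3, <mask> = 4), and RoBERTa relies on this layout (for example, its embedding layer treats id 1 as padding). Keep exactly this order, otherwise the model will crash!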
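Since a wrong order only shows up once training starts, a quick check on the saved vocabulary can save a wasted run. A minimal sketch, assuming the vocabulary file is a JSON mapping from token to id (as ByteLevelBPETokenizer's vocab.json is); the helper check_special_tokens is hypothetical, and the exact file name depends on how your tokenizers version saved it.

```python
import json
from typing import Dict

# Ids implied by the order of special_tokens passed to tokenizer.train(...)
EXPECTED_SPECIAL_IDS: Dict[str, int] = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "<mask>": 4}


def check_special_tokens(vocab_path: str, expected: Dict[str, int] = EXPECTED_SPECIAL_IDS) -> None:
    """Raise ValueError if a special token is missing from the vocab or has an unexpected id."""
    with open(vocab_path, encoding="utf-8") as f:
        vocab: Dict[str, int] = json.load(f)  # token -> id
    for token, want in expected.items():
        got = vocab.get(token)
        if got != want:
            raise ValueError(f"{token!r}: expected id {want}, found {got}")


# Hypothetical path; adjust to the file your tokenizer actually wrote
# check_special_tokens("polish_tokenizer_bpe_32k/pl-vocab.json")
```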

For a BERT tokenizer, however, you will have:

from tokenizers import BertWordPieceTokenizer

# Initialize a WordPiece tokenizer; keep casing and skip the special handling of Chinese characters
tokenizer = BertWordPieceTokenizer(lowercase=False, handle_chinese_chars=False)

Adjusting the run_language_modeling.py script

The version I used and modified comes from the v2.5.1 release of HuggingFace transformers: https://github.com/huggingface/transformers/blob/v2.5.1/examples/run_language_modeling.py

Dataset - check, tokenizer - check. Now it’s time for model training. Initially, I wanted to use the run_language_modeling.py script from the HuggingFace git repo without changes, but it does not support the fast Rust-based tokenizers (as of 2020-03-16). Here are the minimal changes I made to make it work.

Fast tokenizers support in run_language_modeling.py

To plug the fast tokenizer trained above into this script, I had to modify the LineByLineTextDataset class provided there. The final version looks like this:

# Drop-in replacement for LineByLineTextDataset in run_language_modeling.py;
# os, torch, Dataset, PreTrainedTokenizer and logger are already imported by the script.
class LineByLineTextDataset(Dataset):
    def __init__(self, t: PreTrainedTokenizer, args, file_path: str, block_size=512):
        # `t` is the slow tokenizer built by the script; it's ignored in favor of a fast one
        assert os.path.isfile(file_path)
        logger.info("Creating features from dataset file at %s", file_path)

        # -------------------------- CHANGES START
        bert_tokenizer = os.path.join(args.tokenizer_name, "vocab.txt")
        if os.path.exists(bert_tokenizer):
            # A vocab.txt file means a WordPiece (BERT) tokenizer
            logger.info("Loading BERT tokenizer")
            from tokenizers import BertWordPieceTokenizer
            tokenizer = BertWordPieceTokenizer(bert_tokenizer, handle_chinese_chars=False, lowercase=False)
            tokenizer.enable_truncation(max_length=512)
        else:
            # Otherwise expect byte-level BPE (RoBERTa) files: vocab.json + merges.txt
            from tokenizers import ByteLevelBPETokenizer
            from tokenizers.processors import BertProcessing
            logger.info("Loading RoBERTa tokenizer")

            tokenizer = ByteLevelBPETokenizer(
                os.path.join(args.tokenizer_name, "vocab.json"),
                os.path.join(args.tokenizer_name, "merges.txt")
            )
            # Wrap every sequence as <s> ... </s>
            tokenizer._tokenizer.post_processor = BertProcessing(
                ("</s>", tokenizer.token_to_id("</s>")),
                ("<s>", tokenizer.token_to_id("<s>")),
            )
            tokenizer.enable_truncation(max_length=512)

        logger.info("Reading file %s", file_path)
        with open(file_path, encoding="utf-8") as f:
            # Skip empty and whitespace-only lines
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        logger.info("Running tokenization")
        # encode_batch tokenizes all lines in parallel
        self.examples = tokenizer.encode_batch(lines)

        # -------------------------- CHANGES END

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i].ids, dtype=torch.long)

So the only change is in how the args.tokenizer_name parameter is processed: if the folder contains vocab.txt, it’s treated as a BERT tokenizer, otherwise as RoBERTa. Thanks to tokenizer.encode_batch(lines), the whole dataset is tokenized in parallel.

Important!

If you don’t have enough RAM to fit the whole dataset into memory, you have to implement your own dataset. My input dataset - Polish Wikipedia with a sliding window of 4 + Wolne Lektury with a sliding window of 8 - took ~235GB of RAM.
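One way to do that is a lazy dataset that keeps only the byte offset of each line in memory and reads and tokenizes lines on demand. A minimal sketch under these assumptions: one example per line, and a hypothetical tokenize callable (with the fast tokenizer, something like lambda s: tokenizer.encode(s).ids). The class name LazyLineByLineDataset is hypothetical. PyTorch's DataLoader accepts any object with __len__ and __getitem__, but inside run_language_modeling.py __getitem__ would need to return torch.tensor(ids, dtype=torch.long) like the class above, because the script pads batches of tensors.

```python
from array import array
from typing import Callable, List


class LazyLineByLineDataset:
    """Map-style dataset that stores only line offsets, not the text itself.

    Trades some CPU time per step (disk read + tokenization) for a much
    smaller RAM footprint than tokenizing the whole file up front.
    """

    def __init__(self, file_path: str, tokenize: Callable[[str], List[int]]):
        self.file_path = file_path
        self.tokenize = tokenize
        self.offsets = array("q")  # compact 64-bit integers instead of Python ints
        # One pass in binary mode to record where each non-blank line starts
        with open(file_path, "rb") as f:
            offset = 0
            for raw in f:
                if raw.strip():
                    self.offsets.append(offset)
                offset += len(raw)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, i: int) -> List[int]:
        # Open per access so the object stays safe to use from DataLoader worker processes
        with open(self.file_path, "rb") as f:
            f.seek(self.offsets[i])
            line = f.readline().decode("utf-8").rstrip("\r\n")
        return self.tokenize(line)
```

Re-tokenizing on every access costs throughput compared to encode_batch, so this is mainly worth it when the eager version no longer fits in memory.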

Preparing training configuration

This is the trickiest part, as the posts / notebooks I’ve found, even the official one (https://huggingface.co/blog/how-to-train), are full of magic numbers. They just skip explaining them… Start by creating a directory for the configuration, say MyRoBERTaConfig. Copy the tokenizer files into it and rename them (e.g. pl-vocab.json to vocab.json and pl-merges.txt to merges.txt). Make sure you end up with:

    config.json
    tokenizer_config.json
    merges.txt
    vocab.json

in this directory.

Training configuration for RoBERTa

You know what? You don’t have to specify any model-specific parameters if you want to go with the defaults! The only thing you actually need to specify is the vocabulary size, because it determines the size of the output layer (the language model is predicting tokens, after all…). For my RoBERTa I have the following config.json:

{"architectures": ["RobertaForMaskedLM"], "max_position_embeddings": 514, "vocab_size": 32000}

The architectures field is self-explanatory - it’s the language model you will be training; vocab_size comes from the tokenizer you have created.

Important!

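The tricky one is max_position_embeddings - no one will tell you this, but if you don’t set it here, the script will just crash. Awesome. The reason is that RoBERTa numbers positions starting from padding_idx + 1 = 2, so a 512-token block needs position ids up to 513, i.e. 514 embeddings, while the default is 512. BERT numbers positions from 0, so I believe it’s not required for a BERT model.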

The tokenizer configuration for RoBERTa (tokenizer_config.json) is simple. It’s probably not even required, though, because our overridden LineByLineTextDataset loads the tokenizer itself.

{"max_len": 512}
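The whole configuration directory can also be assembled with a short script. A minimal sketch, assuming the tokenizer was saved with a prefix (pl-vocab.json, pl-merges.txt) as in the tokenizer training code above; the function build_roberta_config_dir is hypothetical, and it writes the same config.json and tokenizer_config.json values shown above, deriving max_position_embeddings from the block size.

```python
import json
import shutil
from pathlib import Path


def build_roberta_config_dir(tokenizer_dir: str, prefix: str, out_dir: str,
                             vocab_size: int, block_size: int = 512) -> Path:
    """Create config.json, tokenizer_config.json, vocab.json and merges.txt in out_dir."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    src = Path(tokenizer_dir)
    # Rename the prefixed tokenizer files, e.g. pl-vocab.json -> vocab.json
    shutil.copy(src / f"{prefix}-vocab.json", out / "vocab.json")
    shutil.copy(src / f"{prefix}-merges.txt", out / "merges.txt")
    config = {
        "architectures": ["RobertaForMaskedLM"],
        # RoBERTa position ids start at padding_idx + 1 = 2, hence block_size + 2
        "max_position_embeddings": block_size + 2,
        "vocab_size": vocab_size,
    }
    (out / "config.json").write_text(json.dumps(config), encoding="utf-8")
    (out / "tokenizer_config.json").write_text(json.dumps({"max_len": block_size}), encoding="utf-8")
    return out


# build_roberta_config_dir("polish_tokenizer_bpe_32k", "pl", "MyRoBERTaConfig", vocab_size=32000)
```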

Training your own RoBERTa language model

Here I will show two ways of running the run_language_modeling.py script - one for training from scratch and one for fine-tuning.

Training from scratch

This is the usual case for any new language model. First, the script:

#!/bin/bash
export TRAIN_FILE=plwiki.train.txt
export EVAL_FILE=plwiki.eval.txt
# run_language_modeling_with_tokenizers.py -- it's the version with support for fast tokenizers, see above
python run_language_modeling_with_tokenizers.py \
    --train_data_file $TRAIN_FILE \
    --eval_data_file $EVAL_FILE \
    --output_dir ./MyRoBERTa \
    --model_type roberta \
    --mlm \
    --config_name ./MyRoBERTaConfig \
    --tokenizer_name ./MyRoBERTaConfig \
    --do_train \
    --do_eval \
    --line_by_line \
    --learning_rate 1e-5 \
    --num_train_epochs 5 \
    --save_total_limit 20 \
    --save_steps 5000 \
    --per_gpu_train_batch_size 8 \
    --warmup_steps=10000 \
    --logging_steps=100 \
    --gradient_accumulation_steps=4 \
    --seed 666 --block_size=512

Arguments explanation: (only the non-trivial)

  • --output_dir - here you will get all of the model checkpoints, saved as often as specified in --save_steps argument
  • --mlm - configures the script to run the masked language modeling task. Fortunately, the script will yell at you if --model_type is not appropriate for this task
  • --line_by_line - tells the script to use LineByLineTextDataset, i.e. my overridden version with tokenizers
  • --save_total_limit - how many checkpoints to keep on disk; the oldest one is removed once this limit is exceeded. It’s good to set it to avoid clogging your disk
  • --per_gpu_train_batch_size - finding the right value may take a few failed runs. If you don’t know how big a batch fits on your GPU, sample your training data down to, say, 10k lines and launch the script a few times, monitoring GPU memory with the watch -n 1 nvidia-smi command. Set it as high as possible. For an NVIDIA P100 and a --block_size of 512, it’s 8.
  • --gradient_accumulation_steps - how many batches go through the forward and backward passes, accumulating gradients, before a single optimizer update. It’s useful when your GPU cannot handle large batches: the effective batch size is per_gpu_train_batch_size × number of GPUs × gradient_accumulation_steps
  • --warmup_steps - configures the learning rate scheduler; following the BERT paper’s recommendation, it’s good to warm the learning rate up linearly and then decay it linearly. The number of steps depends on your training dataset size
  • --logging_steps - how often (in steps) to write TensorBoard logs - the script only logs loss and learning rate
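A small worked example of the arithmetic behind these flags. With the values from the script above on the 4x P100 machine, one optimizer update sees 8 × 4 × 4 = 128 sequences. The step-count helper assumes the scheduler length is computed as len(dataloader) // gradient_accumulation_steps * num_train_epochs (check this against the script version you use); the function names and the 1M-line file size are hypothetical.

```python
import math


def effective_batch_size(per_gpu_batch: int, n_gpu: int, grad_accum: int) -> int:
    """Number of sequences contributing to one optimizer update."""
    return per_gpu_batch * max(1, n_gpu) * grad_accum


def total_optimizer_steps(num_lines: int, per_gpu_batch: int, n_gpu: int,
                          grad_accum: int, epochs: int) -> int:
    """Approximate number of optimizer updates (= scheduler length) for the run."""
    batches_per_epoch = math.ceil(num_lines / (per_gpu_batch * max(1, n_gpu)))
    return batches_per_epoch // grad_accum * epochs


print(effective_batch_size(8, 4, 4))                       # 128
print(total_optimizer_steps(1_000_000, 8, 4, 4, epochs=5))  # 39060
```

For such a hypothetical 1M-line file, --warmup_steps=10000 would spend about a quarter of the run warming up, which is why the warmup length should be chosen relative to the dataset size.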

Training from checkpoint (finetuning)

After pretraining my model on Wikipedia only, I decided to fine-tune it further on the Wolne Lektury dataset. The training script produced several checkpoints, each with the following structure:

├── checkpoint-20000
│   ├── config.json
│   ├── merges.txt
│   ├── optimizer.pt
│   ├── pytorch_model.bin
│   ├── scheduler.pt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── training_args.bin
│   └── vocab.json

Important!

If you’re changing the input dataset, you should not reuse the existing scheduler.pt and optimizer.pt, since you’re not “resuming” the training: the script loads them automatically when they are present in the --model_name_or_path directory. Remove those files (but back them up first!) if you want to fine-tune on another dataset.
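A minimal sketch of the "back up, then remove" step: it moves optimizer.pt and scheduler.pt out of a checkpoint into a separate backup directory, so the original run can still be resumed later by moving them back. The function name strip_optimizer_state and the backup directory name are hypothetical.

```python
import shutil
from pathlib import Path
from typing import List


def strip_optimizer_state(checkpoint_dir: str, backup_dir: str) -> List[str]:
    """Move optimizer.pt and scheduler.pt from checkpoint_dir into backup_dir.

    Returns the names of the files that were moved.
    """
    ckpt = Path(checkpoint_dir)
    backup = Path(backup_dir)
    backup.mkdir(parents=True, exist_ok=True)
    moved = []
    for name in ("optimizer.pt", "scheduler.pt"):
        src = ckpt / name
        if src.exists():
            # Move rather than delete, so nothing is lost
            shutil.move(str(src), str(backup / name))
            moved.append(name)
    return moved


# strip_optimizer_state("./MyRoBERTa/checkpoint-30000", "./MyRoBERTa/checkpoint-30000-optim-backup")
```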

OK, now the script itself:

#!/bin/bash
export TRAIN_FILE=wolne-lektury.sliding8.txt
export EVAL_FILE=plwiki.eval.txt

python run_language_modeling_with_tokenizers.py \
    --train_data_file $TRAIN_FILE \
    --eval_data_file $EVAL_FILE \
    --output_dir ./MyRoBERTa_with_another_dataset \
    --model_type roberta \
    --model_name_or_path ./MyRoBERTa/checkpoint-30000 \
    --mlm \
    --config_name ./MyRoBERTaConfig \
    --tokenizer_name ./MyRoBERTaConfig \
    --do_train \
    --do_eval \
    --line_by_line \
    --learning_rate 5e-5 \
    --num_train_epochs 3 \
    --save_total_limit 20 \
    --save_steps 5000 \
    --per_gpu_train_batch_size 8 \
    --warmup_steps=5000 \
    --logging_steps=100 \
    --gradient_accumulation_steps=4 \
    --mlm_probability=0.2 \
    --seed 666 --block_size=512

Arguments - notable mentions:

  • --mlm_probability=0.2 - the fraction of tokens selected for masking during training; the default is 0.15, and I decided to raise it to make the training task harder for the model
  • --model_name_or_path ./MyRoBERTa/checkpoint-30000 - this should point to the checkpoint of your previously trained model; remember the note above!
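To make the effect of --mlm_probability concrete, here is a pure-Python sketch of BERT-style masking for a single sequence. It follows the scheme from the BERT paper, which is also what the script's mask_tokens function applies to batched torch tensors: each non-special token is selected with probability mlm_probability, and selected tokens are replaced by <mask> 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time. The function name and the non-special token ids in the usage line are illustrative only.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips this label by default


def mask_tokens(ids: Sequence[int], mask_id: int, vocab_size: int, special_ids: Set[int],
                mlm_probability: float = 0.15,
                rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Return (inputs, labels) for one sequence with BERT-style masking.

    Selected positions keep their original id as the label; all other
    positions get IGNORE_INDEX so they don't contribute to the loss.
    """
    rng = rng or random.Random()
    inputs = list(ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, tok in enumerate(ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)
        # else: keep the original token, the model still has to predict it
    return inputs, labels


# Special ids from the RoBERTa tokenizer above: <s>=0, <pad>=1, </s>=2, <unk>=3, <mask>=4
inputs, labels = mask_tokens([0, 517, 2048, 93, 7781, 2], mask_id=4, vocab_size=32000,
                             special_ids={0, 1, 2, 3, 4}, mlm_probability=0.2,
                             rng=random.Random(666))
```

Raising mlm_probability from 0.15 to 0.2 simply means about one token in five (instead of one in roughly seven) becomes a prediction target, with less surrounding context left intact.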

Training phase

I highly recommend launching the script with the screen tool, which lets you easily spawn and manage multiple terminal sessions that survive SSH disconnects - you will be able to disconnect from the server while the training runs. And it will run for days… so brace yourself. You can use the screen -UL command (-U enables UTF-8 mode for the progress bars and -L logs the session output to a file, which is helpful when something crashes).

During training, the script writes TensorBoard logs to the runs directory. You can monitor them by running the tensorboard command:

tensorboard --logdir=runs

It will launch a small web server on localhost:6006 where you can monitor the training.

Training summary for Polish RoBERTa a.k.a PolBERTa

I’ve run my training in three phases:

  1. From scratch on Polish Wikipedia only (1 sentence per line) for 370k steps, using learning rate 1e-5 with 10k warmup steps.
  2. Fine-tuning on Wolne Lektury only (8 sentences per line) for 60k steps, starting from the checkpoint from phase 1. Learning rate 5e-5 with 5k warmup steps.
  3. Fine-tuning on Wolne Lektury (8 sentences per line) and Polish Wikipedia (4 sentences per line) for 30k steps. Learning rate 5e-5 with 10k warmup steps.

TensorBoard for all of the runs can be viewed here: https://tensorboard.dev/experiment/aYrAE9uMTxGKRtLUOFjprg/

[Figure: PolBERTa training runs in TensorBoard]

Exporting the model

If you want to share the model with the NLP community (which I highly encourage!), you need to export it in the appropriate format. Set the path to your model’s latest checkpoint and run the following code:

import os

from transformers import AutoModelWithLMHead, AutoTokenizer

# Checkpoint produced by run_language_modeling.py
directory = "/path/to/your/model/checkpoint-30000"
# AutoModelWithLMHead matches transformers 2.x; newer releases deprecate it in favor of AutoModelForMaskedLM
model = AutoModelWithLMHead.from_pretrained(directory)
tokenizer = AutoTokenizer.from_pretrained(directory)

# Write model weights, config and tokenizer files into a clean directory
out = "MyBERTa-base-cased-v1"
os.makedirs(out, exist_ok=True)
model.save_pretrained(out)
tokenizer.save_pretrained(out)

Then use the transformers-cli utility to upload the model (after logging in with transformers-cli login):

transformers-cli upload ./MyBERTa-base-cased-v1/

Using my pre-trained model

I’ve uploaded my pre-trained RoBERTa to HuggingFace’s model hub as marrrcin/PolBERTa-base-polish-cased-v1. Usage:

from transformers import AutoModelWithLMHead, AutoTokenizer, pipeline

model = AutoModelWithLMHead.from_pretrained("marrrcin/PolBERTa-base-polish-cased-v1")
tokenizer = AutoTokenizer.from_pretrained("marrrcin/PolBERTa-base-polish-cased-v1")

fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)
fill_mask("Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad <mask>, a postępujesz tak jakbyś miał niecałe dziesięć.")

Output:

[{'sequence': '<s> Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad sto, a postępujesz tak jakbyś miał niecałe dziesięć.</s>',
  'score': 0.1013106107711792,
  'token': 1675},
 {'sequence': '<s> Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad tysiąc, a postępujesz tak jakbyś miał niecałe dziesięć.</s>',
  'score': 0.09420681744813919,
  'token': 31144},
 {'sequence': '<s> Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad połowę, a postępujesz tak jakbyś miał niecałe dziesięć.</s>',
  'score': 0.0891065001487732,
  'token': 13627},
 {'sequence': '<s> Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad dwadzieścia, a postępujesz tak jakbyś miał niecałe dziesięć.</s>',
  'score': 0.07204979658126831,
  'token': 18811},
 {'sequence': '<s> Wiadomym jest mi, że masz lat blisko czterdzieści, wyglądasz na blisko trzydzieści, wyobrażasz sobie, że masz nieco ponad godzinę, a postępujesz tak jakbyś miał niecałe dziesięć.</s>',
  'score': 0.059559762477874756,
  'token': 20775}]

Follow up and TODO

Here are a few follow up things and my planned future work on Polish RoBERTa:

  • prepare a model card for HuggingFace's model hub
  • evaluate PolBERTa on PolEmo2.0 and other downstream tasks
  • re-train PolBERTa with a bigger tokenizer vocabulary
  • fine-tune the model further on other open datasets

If you find any bugs in the post or something is unclear, feel free to post a comment; I will be glad to improve this short practical guide.

Thanks to…

I would like to thank Egnyte Inc. for providing me with required resources to train the language model and also Darek Kłeczek (creator of original PolBERT) for tips related to training BERT.
