Chinese Text Simplification for Reading Comprehension (Part 3)

Part 3 Summary

I describe my process of fine-tuning and applying a Chinese language BART model to achieve generalized text simplification.

Fine-tuning a BART model for sentence-level text simplification

The Bidirectional and Auto-Regressive Transformer (BART) is a sequence-to-sequence model proposed by Lewis et al. [1] that pairs a bidirectional encoder with an autoregressive decoder for NLP tasks. For more in-depth explanations, check out these resources on sequence-to-sequence models, BERT, and BART. I will be fine-tuning an existing BART model for a sequence generation task: generating a simplified Chinese sentence from a complex sentence input!

BART model as explained in Lewis et al. [1]

Fine-tuning BART for text simplification has been demonstrated in English [2] as well as Chinese [3], and properly fine-tuned models have shown better performance than large language models like GPT-3.5. I will be using a BART model trained on Chinese-language data, CPT/Chinese BART, available on HuggingFace: https://huggingface.co/fnlp/bart-base-chinese. The model implements the architecture described in Shao et al. [4].

Chinese BART is pre-trained by masking parts of the input text and learning to predict the masked tokens; it is not trained for text simplification, so it will need to be fine-tuned. Let's get into how to do that!
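
To see what that pre-training objective looks like before any fine-tuning, here is a minimal sketch that asks the off-the-shelf model to fill in a masked sentence. The masked sentence is just an illustrative example of mine, and the reconstruction in the comment is what I would expect rather than a guaranteed output:

from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

# load the pre-trained (not yet fine-tuned) Chinese BART
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")

generator = Text2TextGenerationPipeline(model, tokenizer)

# ask the model to reconstruct the masked span; I would expect something like
# '北 京 是 中 国 的 首 都' ("Beijing is the capital of China")
print(generator("北京是[MASK]的首都", max_length=50, do_sample=False))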

Data for fine-tuning

The concept of fine-tuning is simple enough: you take the pre-trained model with its existing weights and biases, and continue training it from its current state on task-specific data. In this case, I want to train Chinese BART on parallel text simplification data, where a complex sentence is the input and a simple sentence is the target output.

As I discussed in Part 1 of this project, the MCTS paper [3] provides a dataset of 691,474 pseudo-simplified sentence pairs, generated by machine-translating complex Chinese sentences from the People's Daily Corpus to English, applying English text simplification (TS) tools, and translating the simplified sentences back to Chinese.

Since this dataset is "pseudo" data coming from machine translation, the fluency of the output is unfortunately not guaranteed. However, at the time of writing there is a significant lack of datasets for Chinese language TS compared to other languages like English, so the MCTS dataset is one of the best options available.

Fine-tuning Chinese BART

Since our BART model is hosted on HuggingFace, we can use the HuggingFace transformers package and framework for our fine-tuning. Let's start by preparing our data. Rather than feed the MCTS data directly into the tokenizer, I want to tag the sentence pairs with some custom tokens for HSK vocabulary.

Let's start by importing some base packages, loading our model, tokenizer, and the HSK vocabulary dataset that I processed in Part 2:


import numpy as np
import jieba    # Chinese word segmentation
import pickle
from datasets import Dataset
from transformers import BertTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer

# Chinese BART uses a BERT-style tokenizer
tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")

# word-to-HSK-level dictionary built in Part 2
with open("HSK_levels.pickle", 'rb') as handle:
    HSK_dict = pickle.load(handle)

Now define our preprocessing function to add HSK tokens for HSK levels 4, 5, 6, and 7-9:

def tokenize_with_HSK(sentence, HSK_dict):
    # segment the sentence into words and append an HSK level tag
    # (e.g. "[4]") after any word at HSK level 4 or above
    split_sentence = jieba.lcut(sentence)
    HSK_sentence = ""
    for word in split_sentence:
        score = HSK_dict.get(word, 0)
        if score > 3:
            HSK_sentence += f"{word}[{score}]"
        else:
            HSK_sentence += f"{word}"
    return HSK_sentence

# register the HSK tags as tokens ("[7]" covers levels 7-9)
# and resize the embedding layer to match
tokenizer.add_tokens(["[4]", "[5]", "[6]", "[7]"])
model.resize_token_embeddings(len(tokenizer))

def preprocess_data(filename: str, start: int, stop: int):
    # read sentences [start, stop) from a file and tag them with HSK levels
    lines_HSK = []
    with open(filename, encoding="utf8") as f:
        lines_orig = f.read().splitlines()
        for line in lines_orig[start:stop]:
            lines_HSK.append(tokenize_with_HSK(line, HSK_dict))
    return lines_HSK

Now we can load the pseudo data from the MCTS paper, run it through our preprocessing function, and convert it into a HuggingFace Dataset object. Here, I am using 500,000 sentence pairs with an 80/20 split for training and evaluation. As the dataset is already shuffled, we can load the train and eval sets directly:


start = 0
stop = 500000
split = 400000    # 80/20 train/eval split

lines_complex = preprocess_data('zh_selected.ori', start, stop)   # complex sentences
lines_simple = preprocess_data('zh_selected.sim', start, stop)    # simplified sentences

# first 400,000 pairs for training
data_dict = {'complex': lines_complex[start:split], 'simple': lines_simple[start:split]}
ds_train = Dataset.from_dict(data_dict)

# remaining 100,000 pairs for evaluation
data_dict = {'complex': lines_complex[split:stop], 'simple': lines_simple[split:stop]}
ds_eval = Dataset.from_dict(data_dict)

Here's a look at the first training example in the Dataset. You can see that the HSK tokens were added to a few words, which are at HSK level 4:


ds_train['complex'][0]: '75公斤级比赛3个项目[4]的第一名均为中国选手李顺柱获得[4]。'
(The first place in all three events[4] of the 75 kg competition was won[4] by China's Li Shunzhu.)

ds_train['simple'][0]: '中国选手李顺珠在75公斤级三个项目[4]中获得[4]第一名。'
(China's Li Shunzhu won[4] first place in three events[4] in the 75 kg category.)

Now, tokenize the preprocessed data. Here I use a max_length of 128, with truncation for larger inputs and padding for smaller inputs:


# tokenize the preprocessed sentence pairs
max_length = 128

def batch_tokenize_data(data):
    inputs = data['complex']
    targets = data['simple']

    # pad/truncate both inputs and targets to max_length
    model_inputs = tokenizer(inputs, max_length=max_length, padding='max_length', truncation=True)
    labels = tokenizer(targets, max_length=max_length, padding='max_length', truncation=True)

    # the tokenized simple sentences become the labels
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_data_train = ds_train.map(batch_tokenize_data, batched=True)
tokenized_data_eval = ds_eval.map(batch_tokenize_data, batched=True)

Now, set up the trainer and begin model fine-tuning with trainer.train():


training_args = TrainingArguments(
    output_dir="./bart_simplification",
    evaluation_strategy="epoch",       # evaluate at the end of every epoch
    learning_rate=3e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    eval_accumulation_steps=1,         # move eval predictions off the GPU at every step
    num_train_epochs=5,
    weight_decay=0.01,
    save_total_limit=2,                # keep only the two most recent checkpoints
    logging_steps=500
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data_train,
    eval_dataset=tokenized_data_eval,
    tokenizer=tokenizer
)

trainer.train()
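
As an aside, I trained for a fixed five epochs and watched the loss curves myself, but the transformers Trainer can also stop on its own once the evaluation loss stops improving. Here is a minimal sketch using EarlyStoppingCallback (not the configuration I actually ran; the patience value is just an example):

from transformers import EarlyStoppingCallback, TrainingArguments, Trainer

early_stop_args = TrainingArguments(
    output_dir="./bart_simplification",
    evaluation_strategy="epoch",
    save_strategy="epoch",              # must match evaluation_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    learning_rate=3e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    num_train_epochs=10,                # upper bound; training may stop earlier
    weight_decay=0.01,
    save_total_limit=2,
)

early_stop_trainer = Trainer(
    model=model,
    args=early_stop_args,
    train_dataset=tokenized_data_train,
    eval_dataset=tokenized_data_eval,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # stop after 2 epochs with no improvement
)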
									

I used Google Colab with an A100 GPU to perform the training step. Let's look at the loss during training:

Training and evaluation loss during fine-tuning
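
As a side note, the values behind a plot like this are available in trainer.state.log_history after trainer.train() finishes. Here is a minimal sketch for extracting and plotting them, assuming matplotlib is installed and the logging settings above:

import matplotlib.pyplot as plt

# trainer.state.log_history is a list of dicts: training logs contain "loss"
# (every logging_steps), evaluation logs contain "eval_loss" (every epoch)
history = trainer.state.log_history
train_logs = [(log["epoch"], log["loss"]) for log in history if "loss" in log]
eval_logs = [(log["epoch"], log["eval_loss"]) for log in history if "eval_loss" in log]

plt.plot(*zip(*train_logs), label="training loss")
plt.plot(*zip(*eval_logs), marker="o", label="evaluation loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()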

Five epochs looks like a good stopping point, as the validation loss is leveling off and I want to avoid overfitting. Now that we have a trained model, let's try making a prediction on the same example I showed earlier:


from transformers import Text2TextGenerationPipeline

text2text_generator = Text2TextGenerationPipeline(model, tokenizer)

# tag the input with HSK tokens, the same preprocessing used for the training data
sentence = tokenize_with_HSK('75公斤级比赛3个项目的第一名均为中国选手李顺柱获得。', HSK_dict)

# generate, then strip the spaces the tokenizer inserts between characters
output = text2text_generator(sentence, max_length=128, do_sample=False)[0]['generated_text'].replace(" ", "")

Example ('complex'): '75公斤级比赛3个项目的第一名均为中国选手李顺柱获得。'
(The first place in all three events of the 75 kg competition was won by China's Li Shunzhu.)

Example ('simple'): '中国选手李顺珠在75公斤级三个项目[4]中获得[4]第一名。'
(China's Li Shunzhu won first place in three events in the 75 kg category.)

Fine-tuned BART output: '75公斤级三个项目的第一名均由中国选手李顺柱获得。'
(The first place in all three events in the 75 kg category went to China's Li Shunzhu.)

Not bad! The fine-tuned model produces output that is at least fluent and reasonable. But it can be pretty difficult to tell whether the output is sufficiently simple, or whether it still needs further simplification.
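
One quick sanity check (just a rough heuristic of mine, not the proper evaluation coming in Part 4) is to reuse the HSK dictionary from Part 2 and count how many words above HSK level 3 appear in the original sentence versus the model output:

def count_advanced_words(sentence, HSK_dict, threshold=3):
    # count segmented words above the HSK threshold (i.e. level 4 and up)
    return sum(1 for word in jieba.lcut(sentence) if HSK_dict.get(word, 0) > threshold)

complex_sentence = '75公斤级比赛3个项目的第一名均为中国选手李顺柱获得。'
print(count_advanced_words(complex_sentence, HSK_dict))   # advanced words in the complex input
print(count_advanced_words(output, HSK_dict))             # advanced words in the fine-tuned BART output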

Conclusion

In Part 3, I've described a way to preprocess and tokenize complex/simple sentence pairs from a large dataset and use them to fine-tune the Chinese BART model. The output looks fluent and captures the meaning of the original complex sentence, but how can we tell whether the model is actually simplifying the sentence rather than simply rewording it? Let's tackle that in Part 4, where I evaluate the performance of this BART model as well as the LS pipeline described in Part 2.

Continued in Part 4!

References

[1] Lewis et al., "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension," https://doi.org/10.48550/arXiv.1910.13461 (2019)

[2] Sun et al., "Teaching the Pre-trained Model to Generate Simple Texts for Text Simplification," https://doi.org/10.48550/arXiv.2305.12463 (2023)

[3] Chong et al., "MCTS: A Multi-Reference Chinese Text Simplification Dataset," https://doi.org/10.48550/arXiv.2306.02796 (2024)

[4] Shao et al., "CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Generation," https://doi.org/10.48550/arXiv.2109.05729 (2021)