Training thêm data cho GPT-2 model, version thử nghiệm thực tế :smile
Mở
Vừa qua cũng thấy nhiều bạn tìm hiểu về chatGPT, các kiến thức chủ đạo như mô hình Tranformer https://viblo.asia/p/tu-transformer-den-language-model-bai-1-bat-dau-voi-kien-truc-mo-hinh-transformer-38X4EN1gJN2 . Mình dù không hiểu nhiều lắm, nhưng dựa trên các thông tin trên internet, cũng vọc vạch làm một số demo nhỏ cho việc train model GPT-2, ở mức newbie , hy vọng a.e chạy phát cho vui để lấy cảm hứng học được nhiều hơn
Load the dataframe
import pandas as pd
df = pd.read_csv('../data/Restaurant_Reviews.tsv', sep='\t')
df = df.rename(columns={'Review': 'text'})
df
Clean text một chút.
import nltk
nltk.download('stopwords')
import re
from nltk.corpus import stopwords
def clean_text(text):
# Make text lowercase
text = text.lower()
# Remove text in square brackets
text = re.sub('\[.*?\]', '', text)
# Remove links
text = re.sub('https?://\S+|www\.\S+', '', text)
# Remove punctuation
text = re.sub('[^a-zA-Z0-9\s]+', '', text)
# Remove words containing numbers
text = re.sub('\w*\d\w*', '', text)
# Remove stop words
stop_words = set(stopwords.words('english'))
words = text.split()
filtered_words = [word for word in words if word not in stop_words]
text = ' '.join(filtered_words)
# Remove extra whitespace
text = re.sub('\s+', ' ', text).strip()
return text
# Apply the clean_text function to all text in the 'text' column
df['text'] = df['text'].apply(clean_text)
# Show the updated dataframe
df.head()
Phần tách thành những token nhỏ hơn
import re
contents = df['text'].values.tolist()
def santilize(x):
t = x.split(' ')
new_list = [item for item in t if item is not None]
return new_list
content_tokens = list(map(santilize, contents))
print('content_tokens', content_tokens[0])
Load model pretrained của GPT2
from transformers import AutoTokenizer, AutoModelWithLMHead
modelMaskedLM = AutoModelWithLMHead.from_pretrained('gpt2')
tokenizerVI = AutoTokenizer.from_pretrained('gpt2')
Add thêm vocabulary mới cho model.
def flatten_list(lst):
flattened = []
for item in lst:
if isinstance(item, list):
flattened.extend(flatten_list(item))
else:
flattened.append(item)
return flattened
sentence_tokens = flatten_list(content_tokens)
print('sentence_tokens', set(sentence_tokens[0:30]))
new_tokens = set(sentence_tokens) - set(tokenizerVI.get_vocab().keys())
print('length before add:', len(tokenizerVI.vocab))
tokenizerVI.add_tokens(list(new_tokens))
print('length after add:', len(tokenizerVI.vocab))
modelMaskedLM.resize_token_embeddings(len(tokenizerVI)+1)
Viết file những câu + token sẽ train để review.
# Open the file in write mode
with open("../data/news_output.txt", "w") as file:
# Truncate the file to the current position of the file pointer
file.truncate()
print(content_tokens[:20])
with open("../data/news_output.txt", "w") as file:
for t in content_tokens[:2]:
file.write(' '.join(t) + "\n")
tokenizerVI.add_special_tokens({'pad_token': '[PAD]'})
Encode thành list, seq lấy max=512
max_seq_length = 512
encoded_texts = [tokenizerVI.encode(text, truncation=True, max_length=max_seq_length) for text in contents]
print(encoded_texts[:2])
Train thôi
from transformers import LineByLineTextDataset, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
# training_data = LineByLineTextDataset(
# tokenizer=tokenizerVI,
# file_path='../data/news_output.txt',
# block_size=1024,
# )
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizerVI, mlm=False, mlm_probability=0.15
)
training_args = TrainingArguments(
output_dir="./results-text",
overwrite_output_dir=True,
num_train_epochs=12,
per_device_train_batch_size=16,
per_device_eval_batch_size=8,
warmup_steps=1000,
logging_steps=500,
)
trainer = Trainer(
model=modelMaskedLM,
args=training_args,
train_dataset=encoded_texts,
data_collator=data_collator
)
trainer.train()
trainer.save_model()
Và rồi xem thử model hoạt động thế nào
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
text = """
write comment about "Loved this", for restaurant
"""
input_ids = tokenizerVI.encode(text, return_tensors='pt')
print('input_ids',input_ids)
max_length = 100
sample_outputs = modelMaskedLM.generate(input_ids,pad_token_id=tokenizerVI.eos_token_id,
bos_token_id=tokenizerVI.bos_token_id,
eos_token_id=tokenizerVI.eos_token_id,
do_sample=True,
max_length=max_length,
min_length=max_length,
top_k=40,
num_beams=5,
early_stopping=True,
no_repeat_ngram_size=2,
num_return_sequences=3)
for i, sample_output in enumerate(sample_outputs):
print(">> Generated text {}\n\n{}".format(i+1, tokenizerVI.decode(sample_output.tolist())))
print('\n---')
- Kết quả cùng thầy zui zui, dù không chắc có hiệu quả không
Thank for reading
All rights reserved