Fine-Tuning Large Language Models for Named Entity Recognition
The goal of large language models is to understand and generate text that resembles human language. Trained at massive scale, they can analyze input text and produce responses that are grammatically and contextually appropriate. These models can be applied to a wide range of tasks, including question answering, conversational agents, text generation, and translation.
Named Entity Recognition (NER) is one common application: the model learns to identify named entities in text, such as names of people, places, and organizations.
During training, large language models absorb rich linguistic structure and contextual information from vast amounts of text. This lets the model better understand the context in which a named entity appears, improving recognition accuracy. Even if the model never saw a particular named entity during training, it can often infer the entity's category from context, which means it can handle new, unseen entities without retraining. Beyond that, we can optimize the model for a specific task through fine-tuning. This transfer-learning approach makes NER across different domains and tasks far more efficient.
This article summarizes lessons from fine-tuning large language models for the NER problem. We will use Personally Identifiable Information (PII) as the running example for walking through NER fine-tuning with an LLM.
Personally Identifiable Information (PII)
Personally Identifiable Information (PII) refers to data that can be used to identify, contact, or locate an individual. This information, used alone or combined with other data, makes it possible to single out a specific person. PII typically includes, but is not limited to, the following:
- Full name
- Email address
- National ID number
- Driver's license number
- Social Security number
- Bank account number
- Date of birth
- Address
Leaking this information can lead to identity theft, privacy violations, and similar harms, so protecting PII is critical to safeguarding personal privacy and security.
The HIPAA Privacy Rules (Health Insurance Portability and Accountability Act Privacy Rules) are a set of regulations designed to protect the privacy and security of healthcare information. Issued by the US federal government, they apply to healthcare providers, health plans, healthcare payers, and other organizations and individuals that exchange medical information with these entities.
"Safe Harbor method" 是指在HIPAA(Health Insurance Portability and Accountability Act)隱私規定中的一種數據安全標準。這個方法允許醫療保健機構和其他涉及醫療信息的實體在某些條件下共享個人健康信息,而不會被認為違反HIPAA的隱私規定。
Under the Safe Harbor method, shared personal health information must be de-identified so that it can no longer be traced to a specific individual. HIPAA specifies a set of identifiers, including but not limited to the following:
- Medical record numbers
- Social Security numbers
- Driver's license numbers
- Credit card numbers
If these identifiers are removed, or the personal health information is otherwise made impossible to link to a specific individual, the information is considered to meet the Safe Harbor standard.
This is where NER comes in: we can fine-tune an LLM to generate a well-structured string containing the detected PII entities, which can then be anonymized to keep personal health information safe.
Fine-Tuning a Large Language Model
Fine-tuning an LLM for this task poses two main challenges:
- The fine-tuned LLM must not hallucinate named entities; detections should come from a controlled set of entity labels.
- The fine-tuned LLM must produce well-structured output. The output should not contain extraneous information (for example, explanations of why certain entities were detected), because extra output tokens raise the inference cost of every request and cannot be consumed by downstream tasks.
So, let's start by formatting the training data.
A typical NER input and output format looks like this:
# INPUT
test_example = "My name is John Doe and I can be contacted at 111-222-3334"

# GROUND TRUTH NER DETECTIONS
test_detections = [
    {
        "entity_type": "PERSON",
        "entity_value": "John Doe",
        "start_position": 11,
        "end_position": 19,
    },
    {
        "entity_type": "PHONE_NUMBER",
        "entity_value": "111-222-3334",
        "start_position": 46,
        "end_position": 58,
    },
]
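Once detections exist in this format, the anonymization step motivated in the PII section above reduces to replacing the detected character spans with placeholders. Below is a minimal sketch; redact_pii is a hypothetical helper (not part of the original pipeline) that assumes the detection format shown:

def redact_pii(text: str, detections: list) -> str:
    # Replace spans from right to left so that the start/end indices of
    # earlier detections remain valid while the string is being edited.
    for det in sorted(detections, key=lambda d: d["start_position"], reverse=True):
        placeholder = f"[{det['entity_type']}]"
        text = text[: det["start_position"]] + placeholder + text[det["end_position"]:]
    return text

# redact_pii(test_example, test_detections)
# -> "My name is [PERSON] and I can be contacted at [PHONE_NUMBER]"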
The output data can be formatted in several ways. For typical NER datasets, the BIO format is widely adopted.
The BIO format is a common annotation scheme for Named Entity Recognition (NER) tasks, used to mark named entities in text. It consists of three kinds of tags: B, I, and O.
- B (Beginning): marks the first token of a named entity.
- I (Inside): marks a token inside a named entity.
- O (Outside): marks a token that is not part of any named entity.
## BIO Tags for sentence - Alex is going to Los Angeles in California
Alex        B-PER
is          O
going       O
to          O
Los         B-LOC
Angeles     I-LOC
in          O
California  B-LOC
The BIO format is very prescriptive. It requires understanding which tokens are "inside" (I) a given entity label versus "outside" (O) of it, and it also requires marking the token where each entity begins. All of this adds unnecessary complexity when writing a task-description prompt for an LLM.
We can try the following two output formats instead:
# JSON encoded string with NER detections
llm_output_str = "[{\"entity_type\": \"PERSON\",\"entity_value\": \"John Doe\",\"start_position\": 11,\"end_position\": 19,},{\"entity_type\": \"PHONE_NUMBER\",\"entity_value\": \"111-222-3334\",\"start_position\": 46,\"end_position\": 58}]"
This string looks exactly like regular NER data. No additional post-processing of the LLM output is needed, and we can read the model's output directly with json.loads(llm_output_str).
However, we have to guarantee that the model emits a valid JSON-encoded string, and the model also has to track where each string sits in the input, which is somewhat difficult for it. A defensive parsing sketch follows below.
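A sketch of how the JSON output might be consumed defensively, since a single malformed generation should not crash the pipeline (parse_llm_json_output is a hypothetical helper, not from the original code):

import json

def parse_llm_json_output(llm_output_str: str) -> list:
    # A model that emits invalid JSON yields an empty detection list
    # instead of raising an exception.
    try:
        detections = json.loads(llm_output_str)
    except json.JSONDecodeError:
        return []
    # Keep only well-formed detection dicts carrying the expected keys
    required = {"entity_type", "entity_value", "start_position", "end_position"}
    return [d for d in detections if isinstance(d, dict) and required <= d.keys()]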
Alternatively, we can have the model mark the NER labels directly in the text, for example:
llm_output_str = "My name is <PERSON>John Doe</PERSON> and I can be contacted at <PHONE_NUMBER>111-222-3334</PHONE_NUMBER>"
Having the model wrap the matching spans in <ENTITY_LABEL></ENTITY_LABEL> tags makes the results very convenient to inspect. For downstream code, however, we still have to post-process the LLM output to parse out the detected entities along with their start and end character indices, which adds code; a parsing sketch follows this paragraph. This approach also requires that no tokens in the output are hallucinated, and that every character, punctuation mark, and word-order choice of the input is preserved, which is again somewhat hard for an LLM.
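Here is a regex-based post-processing sketch that recovers entity types, values, and start/end character indices from the tagged string. It assumes tags never nest; parse_tagged_output is a hypothetical helper, not part of the original pipeline:

import re

TAG_PATTERN = re.compile(r"<(?P<label>[A-Z_]+)>(?P<value>.*?)</(?P=label)>")

def parse_tagged_output(tagged_str: str):
    # Walk through the tagged string, rebuilding the untagged text and
    # recording each entity's offsets relative to that untagged text.
    detections = []
    plain_text = ""   # reconstructed input without tags
    cursor = 0        # current position in tagged_str
    for m in TAG_PATTERN.finditer(tagged_str):
        plain_text += tagged_str[cursor:m.start()]  # text before the tag
        start = len(plain_text)
        plain_text += m.group("value")
        detections.append({
            "entity_type": m.group("label"),
            "entity_value": m.group("value"),
            "start_position": start,
            "end_position": len(plain_text),
        })
        cursor = m.end()
    plain_text += tagged_str[cursor:]
    return plain_text, detections

Running this over the llm_output_str above recovers the same start/end indices as the JSON format, so both representations are interchangeable for downstream consumers.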
So which one do we choose? At the end of this article, we will see which output-string format works better.
Prompt Design
Now that we have the input and output data formats, we need to design a prompt that describes the task to the LLM. Prompt design is extremely important: it directly affects the LLM's output.
The paper "QUANTIFYING LANGUAGE MODELS' SENSITIVITY TO SPURIOUS FEATURES IN PROMPT DESIGN or: How I learned to start worrying about prompt formatting" discusses how model performance shifts with the prompt, for anyone who wants to dig deeper.
For the task description, we use separate prompts for the JSON output and the tagged-string output. For the model generating JSON strings, we used the following task description:
You are given a user utterance that may contain Personal Identifiable
Information (PII). You are also given a list of entity types representing
personal identifiable information (PII). Your task is to detect and identify
all instances of the supplied PII entity types in the user utterance. Provide
a JSON output with keys: 'entity_type' (label of the detected entity),
'entity_value' (actual string value of the entity), 'start_position'
(start character index of the entity in the user utterance string), and
'end_position' (end character index of the entity in the user utterance string)
Ensure accuracy in identification of entities with correct start_position and
end_position character indices. Ensure that all entities are identified. Do
not perform false identifications.
For the tagged output string, I used the following task description:
You are given a user utterance that may contain Personal Identifiable
Information (PII). You are also given a list of entity types representing
Personal Identifiable Information (PII). Your task is to detect and identify
all instances of the supplied PII entity types in the user utterance.
The output must have the same content as the input. Only the tokens that match
the PII entities in the list should be enclosed within XML tags. The XML tag
comes from the PII entities described in the list below. For example, a name
of a person should be enclosed within <PERSON></PERSON> tags. Ensure that all
entities are identified. Do not perform false identifications.
The prompt also needs to include the list of entity types and their descriptions, to ensure the model only detects entities from the controlled label list. I chose the following template:
List Of Entities
PERSON: Name of a person
Rx_NUMBER: Number identifying a medical prescription
ORDER_NUMBER: Number identifying a retail order
EMAIL_ADDRESS: Email address
PHONE_NUMBER: Telephone or mobile number
DATE_TIME: Dates and Times
US_SSN: Social Security Number in the United States
We ran the following test against the prompt above:
## Few shot example input
"My name is John Doe and I can be contacted at 111-222-3334"
## Few shot example output
"My name is <PERSON>John Doe</PERSON> and I can be contacted at <PHONE_NUMBER>111-222-3334</PHONE_NUMBER>"
## Actual input
"My phone number is 222-333-4445 and my name is Ana Jones"
## Incorrect Model output - model rephrases the output to match closer to few shot example output
"My name is <PERSON>Ana Jones</PERSON> and my phone number is <PHONE_NUMBER>222-333-4445</PHONE_NUMBER>"
## What model should have generated
"My phone number is <PHONE_NUMBER>222-333-4445</PHONE_NUMBER> and my name is <PERSON>Ana Jones</PERSON>"
We can clearly see that the model generated spurious tokens in its output: all the entities were identified, but their positions changed, because the model rephrased the input to look more like the few-shot example output. We could add more few-shot examples to the prompt to force the model to learn the constraint, but that would increase the prompt's input token count. A cheap guard against this failure mode is sketched below.
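One such guard is to strip the tags from the generation and require the remainder to be identical to the input; rephrased outputs can then be rejected or retried. A hypothetical check, not part of the original pipeline:

import re

def output_preserves_input(input_str: str, tagged_output: str) -> bool:
    # Removing the XML tags must reproduce the input exactly;
    # otherwise the model rephrased or hallucinated tokens.
    untagged = re.sub(r"</?[A-Z_]+>", "", tagged_output)
    return untagged == input_str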
Adding Chain-of-Thought to the Prompt
Besides embedding few-shot examples in the conversation, we can also have the model restate the instructions in a concise way. This reinforces the model's understanding of the task and yields better, more consistently formatted output. I also had the model "explain" to me why, given the task description, the example input and output make sense. Consider the following prompt:
# First user message
usr_msg1 = """
You are given a user utterance that may contain Personal Identifiable
Information (PII). You are also given a list of entity types representing
Personal Identifiable Information (PII). Your task is to detect and identify
all instances of the supplied PII entity types in the user utterance.
The output must have the same content as the input. Only the tokens that match
the PII entities in the list should be enclosed within XML tags. The XML tag
comes from the PII entities described in the list below. For example, a name
of a person should be enclosed within <PERSON></PERSON> tags. Ensure that all
entities are identified. Do not perform false identifications.
List Of Entities
PERSON: Name of a person
Rx_NUMBER: Number identifying a medical prescription
ORDER_NUMBER: Number identifying a retail order
EMAIL_ADDRESS: Email address
PHONE_NUMBER: Telephone or mobile number
DATE_TIME: Dates and Times
US_SSN: Social Security Number in the United States
Are the instructions clear to you?
"""
# First assistant message which is a reply to usr_msg1
# I hardcode this msg once the model produced a reasonably good response.
# Note that the response comes from a non-fine-tuned version of the model.
# The idea is to sample a good 'thought' from the base version of the model.
asst_msg1 = """
Yes, the instructions are clear. I will identify and enclose within the
corresponding XML tags, all instances of the specified PII entity types in
the user utterance. For example, <PERSON><Name of a person></PERSON>,
<Rx_NUMBER><Number identifying a medical prescription></Rx_NUMBER>, etc.
leaving the rest of the user utterance unchanged.
"""
# Here I hardcode a few shot example as a user message
usr_msg2 = """
My name is John Doe, and can be contacted at 111-222-3334
"""
# Here I hardcode the appropriate response from the assistant as the
# correct output of the few shot example
asst_msg2 = """
My name is <PERSON>John Doe</PERSON>, and can be contacted at <PHONE_NUMBER>111-222-3334</PHONE_NUMBER>
"""
# Here I ask the assistant why its response to the last user message was the
# correct response
usr_msg3 = """
Give a brief explanation of why your answer is correct.
"""
# I hardcode the explanation from the model.
# Note that the explanation comes from a non-fine-tuned version of the model
asst_msg3 = """
I identified and enclosed within corresponding XML tags, all instances of
the specified PII entity types in the user utterance - a Person Name
"John Doe" within the <PERSON></PERSON> tag, and a Phone Number
"111-222-3334" within the <PHONE_NUMBER></PHONE_NUMBER> tag.
The rest of the user utterance was left unchanged as it did not contain any
other identified PII entities.
"""
# The rest of the user and assistant msgs are hardcoded so that the model is
# in a state where it expects another input from the user
usr_msg4 = """
Great! I am now going to give you another user utterance. Please
detect PII entities in it according to the previous instructions. Do
not include an explanation in your answer.
"""
asst_msg4 = """
Sure! Please give me the user utterance.
"""
# usr_msg5 would be the actual input string on which we want to detect the
# PII entities
A sample function for creating a complete training-data record looks like this:
import torch
from typing import List
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

def get_fine_tune_prompt_xml(
    rule_set: List[str],
    input_str: str,
    label_str: str,
    tokenizer: PreTrainedTokenizerBase,
) -> torch.Tensor:
    """
    Args:
        rule_set (List[str]): List of strings representing entity labels and
            their corresponding descriptions
        input_str (str): Actual input string on which detections need to be
            performed
        label_str (str): Expected output string corresponding to input_str
        tokenizer (PreTrainedTokenizerBase): A tokenizer corresponding to the
            model being fine-tuned

    Returns:
        torch.Tensor: Tensor of tokenized input ids
    """
    rule_str = "\n".join(rule_set)
    usr_msg1 = "You are given a user utterance that may contain Personal Identifiable Information (PII). " \
        "You are also given a list of entity types representing Personal Identifiable Information (PII). " \
        "Your task is to detect and identify all instances of the supplied PII entity types in the user utterance. " \
        "The output must have the same content as the input. Only the tokens that match the PII entities in the " \
        "list should be enclosed within XML tags. The XML tag comes from the PII entities described in the list below. " \
        "For example, a name of a person should be enclosed within <PERSON></PERSON> tags. " \
        "Ensure that all entities are identified. Do not perform false identifications." \
        f"""\n\nList Of Entities\n{rule_str}""" \
        "\n\n" \
        "Are the instructions clear to you?"
    asst_msg1 = "Yes, the instructions are clear. I will identify and enclose within the corresponding XML tags, " \
        "all instances of the specified PII entity types in the user utterance. For example, " \
        "<PERSON><Name of a person></PERSON>, <Rx_NUMBER><Number identifying a medical prescription></Rx_NUMBER>, etc. " \
        "leaving the rest of the user utterance unchanged."
    usr_msg2 = "My name is John Doe, and can be contacted at 111-222-3334"
    asst_msg2 = "My name is <PERSON>John Doe</PERSON>, and can be contacted at <PHONE_NUMBER>111-222-3334</PHONE_NUMBER>"
    usr_msg3 = "Give a brief explanation of why your answer is correct."
    asst_msg3 = "I identified and enclosed within corresponding XML tags, all instances of the specified PII " \
        "entity types in the user utterance - a Person Name \"John Doe\" within the <PERSON></PERSON> tag, and " \
        "a Phone Number \"111-222-3334\" within the <PHONE_NUMBER></PHONE_NUMBER> tag. The rest of the user " \
        "utterance was left unchanged as it did not contain any other identified PII entities."
    usr_msg4 = "Great! I am now going to give you another user utterance. Please detect PII entities in it " \
        "according to the previous instructions. Do not include an explanation in your answer."
    asst_msg4 = "Sure! Please give me the user utterance."
    messages = [
        {"role": "user", "content": usr_msg1},
        {"role": "assistant", "content": asst_msg1},
        {"role": "user", "content": usr_msg2},
        {"role": "assistant", "content": asst_msg2},
        {"role": "user", "content": usr_msg3},
        {"role": "assistant", "content": asst_msg3},
        {"role": "user", "content": usr_msg4},
        {"role": "assistant", "content": asst_msg4},
        {"role": "user", "content": input_str},
        {"role": "assistant", "content": label_str},
    ]
    encoded_input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
    return encoded_input_ids
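For illustration, the function might be invoked as follows; this call is hypothetical, with the tokenizer and rule_set assumed to mirror the model and entity-list template used in this article:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

rule_set = [
    "PERSON: Name of a person",
    "Rx_NUMBER: Number identifying a medical prescription",
    "ORDER_NUMBER: Number identifying a retail order",
    "EMAIL_ADDRESS: Email address",
    "PHONE_NUMBER: Telephone or mobile number",
    "DATE_TIME: Dates and Times",
    "US_SSN: Social Security Number in the United States",
]

input_ids = get_fine_tune_prompt_xml(
    rule_set=rule_set,
    input_str="Hi! is Dr. Danielle Boyd at the clinic",
    label_str="Hi! is Dr. <PERSON>Danielle Boyd</PERSON> at the clinic",
    tokenizer=tokenizer,
)
print(tokenizer.decode(input_ids[0]))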
The output of the function, decoded back to text, looks like this:
<s> [INST] You are given a user utterance that may contain Personal Identifiable Information (PII). You are also given a list of entity types representing Personal Identifiable Information (PII). Your task is to detect and identify all instances of the supplied PII entity types in the user utterance. The output must have the same content as the input. Only the tokens that match the PII entities in the list should be enclosed within XML tags. The XML tag comes from the PII entities described in the list below. For example, a name of a person should be enclosed within <PERSON></PERSON> tags. Ensure that all entities are identified. Do not perform false identifications.
List Of Entities
PERSON: Name of a person
Rx_NUMBER: Number identifying a medical prescription
ORDER_NUMBER: Number identifying a retail order
EMAIL_ADDRESS: Email address
PHONE_NUMBER: Telephone or mobile number
DATE_TIME: Dates and Times
US_SSN: Social Security Number in the United States
Are the instructions clear to you? [/INST]Yes, the instructions are clear. I will identify and enclose within the corresponding XML tags, all instances of the specified PII entity types in the user utterance. For example, <PERSON><Name of a person></PERSON>, <Rx_NUMBER><Number identifying a medical prescription></Rx_NUMBER>, etc. leaving the rest of the user utterance unchanged.</s> [INST] My name is John Doe, and can be contacted at 111-222-3334 [/INST]My name is <PERSON>John Doe</PERSON>, and can be contacted at <PHONE_NUMBER>111-222-3334</PHONE_NUMBER></s> [INST] Give a brief explanation of why your answer is correct. [/INST]I identified and enclosed within corresponding XML tags, all instances of the specified PII entity types in the user utterance - a Person Name "John Doe" within the <PERSON></PERSON> tag, and a Phone Number "111-222-3334" within the <PHONE_NUMBER></PHONE_NUMBER> tag. The rest of the user utterance was left unchanged as it did not contain any other identified PII entities.</s> [INST] Great! I am now going to give you another user utterance. Please detect PII entities in it according to the previous instructions. Do not include an explanation in your answer. [/INST]Sure! Please give me the user utterance.</s> [INST] Hi! is Dr. Danielle Boyd at the clinic [/INST]Hi! is Dr. <PERSON>Danielle Boyd</PERSON> at the clinic</s>
tokenizer.apply_chat_template() takes care of wrapping user messages in '[INST]' and '[/INST]' and appending '</s>' (the end-of-sequence token) to assistant messages. Note also that the tokenizer prepends the '<s>' (beginning-of-sequence) token to the start of the prompt. These small details have a huge impact on whether the model learns and converges effectively during fine-tuning.
Custom Loss
Autoregressive models (like most LLMs) are trained to correctly predict the "next token". Given the training-data sample we just created and this fine-tuning setup, the model would learn to predict the next token across every part of the text: the task description, the entity list, the few-shot example, the chain of thought hardcoded into the conversation history, and so on.
The model would then be learning the token distribution of the task description on top of learning to predict the correct detections, which makes the learning task noisier than necessary. Our primary goal in fine-tuning the LLM is to generate well-structured and correct detections for a given input string, so we should compute the loss only over the tokens of the output string. For our example training record, that means the model should compute the loss only over the following tokens:
Hi! is Dr. <PERSON>Danielle Boyd</PERSON> at the clinic</s>
This encourages the model to "forget" all the preceding tokens, simply "attend" to the ones that matter, and generate the correct output string. We can implement this with HuggingFace's DataCollator API.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

@dataclass
class CustomDataCollatorWithPadding:
    """
    Data collator that will dynamically pad the inputs received.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        labels = batch["input_ids"].clone()
        # Set loss mask for all pad tokens
        labels[labels == self.tokenizer.pad_token_id] = -100
        # Compute loss mask for appropriate tokens only
        for i in range(batch["input_ids"].shape[0]):
            # Decode the training input
            # (slicing from [1:] is important because the tokenizer adds a bos token)
            text_content = self.tokenizer.decode(batch["input_ids"][i][1:])
            # Extract substrings for prompt text in the training input.
            # The training input ends at the last user msg ending in [/INST]
            prompt_gen_boundary = text_content.rfind("[/INST]") + len("[/INST]")
            prompt_text = text_content[:prompt_gen_boundary]
            # print(f"""PROMPT TEXT:\n{prompt_text}""")
            # Retokenize the prompt text only
            prompt_text_tokenized = self.tokenizer(
                prompt_text,
                return_overflowing_tokens=False,
                return_length=False,
            )
            # Compute index where prompt text ends in the training input
            prompt_tok_idx = len(prompt_text_tokenized["input_ids"])
            # Set loss mask for all tokens in prompt text
            labels[i][range(prompt_tok_idx)] = -100
            # print("================DEBUGGING INFORMATION===============")
            # for idx, tok in enumerate(labels[i]):
            #     token_id = batch['input_ids'][i][idx]
            #     decoded_token_id = self.tokenizer.decode(batch['input_ids'][i][idx])
            #     print(f"""TOKID: {token_id} | LABEL: {tok} || DECODED: {decoded_token_id}""")
        batch["labels"] = labels
        return batch
The CustomDataCollatorWithPadding class can then be passed to SFTTrainer as shown below.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=eval_dataset["val"],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
    # Using custom data collator inside SFTTrainer
    data_collator=CustomDataCollatorWithPadding(
        tokenizer=tokenizer,
        padding="longest",
        max_length=max_seq_length,
        return_tensors="pt",
    ),
)
By setting those label token ids to -100, we zero out the loss at those token positions. As a result, none of the tokens from the beginning-of-sequence token (<s>) through the end of the last user msg (ending in [/INST]) contribute to the loss; the short illustration below shows why -100 in particular works.
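The -100 value is not arbitrary: it is the default ignore_index of PyTorch's cross-entropy loss, which HuggingFace causal-LM models use internally when computing the loss from labels. A minimal illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(5, 32000)                     # 5 positions, vocab size 32000
labels = torch.tensor([-100, -100, -100, 42, 7])   # first 3 positions masked

# F.cross_entropy defaults to ignore_index=-100, so only the last two
# positions (the output-string tokens) contribute to the loss.
loss = F.cross_entropy(logits, labels)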
Results
I fine-tuned the mistralai/Mistral-7B-Instruct-v0.2 model with this setup, using roughly 800 training samples, about 400 test samples, and about 400 validation samples.
The model was trained for 3 epochs and achieved quite high precision/recall/F1 (above 96%) on the test set; a sketch of span-level scoring follows below.
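For reference, span-level precision/recall/F1 can be computed by treating each (entity_type, start, end) triple as one prediction. The exact-match scoring sketch below is illustrative and not necessarily the evaluation script used for the numbers above:

def span_prf1(gold: list, pred: list):
    # Exact-match scoring over (entity_type, start, end) triples.
    to_set = lambda dets: {
        (d["entity_type"], d["start_position"], d["end_position"]) for d in dets
    }
    gold_set, pred_set = to_set(gold), to_set(pred)
    tp = len(gold_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(gold_set) if gold_set else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1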
One notable result: the inline-tagging approach outperformed generating JSON-encoded output. The JSON was well-formed, but as discussed earlier, the model did poorly at predicting correct 'start_position' and 'end_position' character indices.
I did not check whether the model would also handle the BIO output format well; personally, I suspect it would not do as well.
We added a custom loss mask. Does that help the model generalize better to unseen entities? That was not tested either.
How would performance change if the 7B model were swapped for a larger one, such as 13B or 34B? Would the training and inference costs justify the performance gains? These are open questions worth pursuing. If you are interested in NER, feel free to explore them yourself; I will also share my findings once I have results.