Translator | 朱先忠
Reviewer | 重樓
This article walks through the key parameters that control text generation in Transformer models, including temperature, top-k and top-p sampling, and the repetition penalty, and discusses how these parameters affect the quality of the generated text and how to tune them for different applications.
Transformer models are the standard for NLP tasks today. Almost every NLP task involves text generation, yet the generated text is not a direct output of the model. You usually want the model to help you produce text that is coherent and contextually relevant. While this depends in part on the quality of the model itself, the generation parameters also play a crucial role in the quality of the generated text.
In this article, let's explore the key parameters that control text generation in Transformer models. You will see how these parameters affect the quality of the generated text and how to tune them for different applications. Specifically, you will learn:
- The core parameters that control text generation in Transformer models
- Different decoding strategies
- How to control the creativity and coherence of the generated text
- How to fine-tune generation parameters for specific applications
Let's get started!
Overview
This article is divided into seven parts; they are:
- Core text generation parameters
- Experimenting with temperature
- Top-K and Top-P sampling
- Controlling repetition
- Greedy decoding and sampling
- Parameters for specific applications
- Beam search and multiple sequence generation
Core Text Generation Parameters
Let's take the GPT-2 model as an example. It is a small Transformer model that does not require a lot of compute, yet it can still generate decent text. A simple example of generating text with GPT-2 looks like this:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Create the model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Tokenize the input prompt into a sequence of token IDs
prompt = "Artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate the output as a sequence of token IDs
output = model.generate(
    **inputs,
    max_length=50,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

# Convert the token IDs back into a text string
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Prompt: {prompt}")
print("Generated Text:")
print(generated_text)
If you run this code, you may see output like the following:
Prompt: Artificial intelligence is
Generated Text:
Artificial intelligence is used in the production of technology, the delivery of
which is determined by technological change. For example, an autonomous car can
change its steering wheel to help avoid driving traffic. In the case of artificial
intelligence, this can change what consumers
In this example, you provided a prompt of only three words, and the model generated a long passage of text. The text is not produced in a single shot; the model is called repeatedly in an iterative process.
You can see the many parameters used in the generate() function. The first one you used is max_length, which controls the length of the generated text, measured in tokens. Typically, the model uses the prompt as context and generates one token at a time; the newly generated token is then appended to the prompt to generate the next one. So the longer you want the generated text to be, the more time it takes to generate. Note that the unit here is tokens, not words, because GPT-2 uses a subword tokenizer: a token may be only a subword unit rather than a full word.
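To make the token-versus-word distinction concrete, here is a minimal sketch (the example word is arbitrary) showing how the GPT-2 tokenizer splits a single word into several subword tokens:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# One English word may map to several subword tokens
word = "internationalization"
token_ids = tokenizer(word)["input_ids"]
print(len(token_ids))  # number of tokens, typically more than one
print(tokenizer.convert_ids_to_tokens(token_ids))  # the subword pieces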
However, the model does not directly emit any particular token. Instead, it produces "logits": a vector with one score per entry in the vocabulary, which is converted into a probability distribution over all possible next tokens. Given this distribution, you can pick the token with the highest probability (when you set do_sample=False) or any other token with non-zero probability (when you set do_sample=True). That is what all the other parameters are for.
The temperature parameter reshapes the probability distribution. A lower temperature emphasizes the most likely tokens, while a higher temperature narrows the gap between likely and unlikely tokens. The default temperature is 1.0, and it must be a positive value. The top_k parameter then keeps only the k most likely tokens rather than the entire vocabulary, and the probabilities are recomputed to sum to 1. If top_p is also set, this set of k tokens is further filtered down to the top tokens that together account for a total probability of p. The next token is then sampled from this final set, a process known as nucleus sampling.
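If you want to see the arithmetic behind these parameters, the following self-contained sketch applies temperature scaling, top-k filtering, and top-p filtering to a toy logits vector and samples from the result. It is an illustration only, not the transformers library's internal implementation, and the numbers are made up:

import torch

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):
    # Temperature: divide the logits before the softmax
    probs = torch.softmax(logits / temperature, dim=-1)

    # Top-k: keep only the k most probable tokens (0 means no top-k filtering)
    if top_k > 0:
        top_probs, top_idx = torch.topk(probs, top_k)
    else:
        top_probs, top_idx = torch.sort(probs, descending=True)

    # Top-p (nucleus): keep the smallest set of top tokens whose total probability reaches p
    cumulative = torch.cumsum(top_probs, dim=-1)
    keep = (cumulative - top_probs) < top_p  # always keeps the most probable token
    top_probs, top_idx = top_probs[keep], top_idx[keep]

    # Renormalize the remaining probabilities and sample the next token
    top_probs = top_probs / top_probs.sum()
    choice = torch.multinomial(top_probs, num_samples=1)
    return top_idx[choice].item()

# Toy "vocabulary" of six tokens
logits = torch.tensor([2.0, 1.5, 0.3, 0.1, -1.0, -2.0])
print(sample_next_token(logits, temperature=0.7, top_k=4, top_p=0.9))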
Keep in mind that you are generating a sequence of tokens one at a time. At every step the same preceding tokens are seen again, and the model may end up producing the same tokens repeatedly within the sequence. That is usually not what you want, so you may want to lower the probability of those tokens when they come up again. That is what the repetition_penalty parameter above is for.
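A common way to implement such a penalty (this is a rough sketch of the idea, not the library's exact code) is to divide the positive logits of already-generated tokens by the penalty factor and multiply their negative logits by it, making those tokens less likely to be picked again:

import torch

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    # Lower the scores of tokens that have already been generated
    for token_id in set(generated_ids):
        if logits[token_id] > 0:
            logits[token_id] = logits[token_id] / penalty
        else:
            logits[token_id] = logits[token_id] * penalty
    return logits

logits = torch.tensor([2.0, 0.5, -0.3, 1.0])
print(apply_repetition_penalty(logits, generated_ids=[0, 2], penalty=1.5))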
Experimenting with Temperature
Now that you know what each parameter does, let's see how the output changes when you adjust some of them.
The temperature parameter has a significant effect on the creativity and randomness of the generated text. You can see its effect in the following example:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different temperature values
temperatures = [0.2, 0.5, 1.0, 1.5]
print(f"Prompt: {prompt}")
for temp in temperatures:
    print()
    print(f"Temperature: {temp}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=temp,
        top_k=50,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
If you run this code, you may see output like the following:
Prompt: The future of artificial intelligence is
Temperature: 0.2
Generated Text:
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future
Temperature: 0.5
Generated Text:
The future of artificial intelligence is uncertain.
"There is a lot of work to be done on this," said Eric Schmitt, a professor
of computer science and engineering at the University of California, Berkeley.
"We're looking for a way to make AI more like computers. We need to take a step
back and look at how we think about it and how we interact with it."
Schmitt said he's confident that artificial intelligence will eventually be
able to do more than
Temperature: 1.0
Generated Text:
The future of artificial intelligence is not yet clear, however."
"Is the process that we are trying to do through computer vision and the ability to
look at a person at multiple points without any loss of intelligence due to not
seeing a person at multiple points?" asked Richard. "I also think the people who
are doing this research are extremely interesting to me due to being able to see
humans at a range of different points in time. In particular, they've shown how
to do a pretty complex
Temperature: 1.5
Generated Text:
The future of artificial intelligence is an era to remember as much as Google in
search results, particularly ones not supported by much else for some years -- and
it might look like the search giant is now just as good without artificial
intelligence. [Graphic image from Shutterstock]
With a low temperature (e.g., 0.2), the text becomes more focused and deterministic, often sticking to common phrasing and conventional ideas. You can also see that, because the probability mass concentrates on a few tokens, the text keeps repeating the same sentence, which limits its diversity. This problem can be addressed with the repetition penalty parameter, covered in a later section.
With a moderate temperature (e.g., 0.5 to 1.0), the text strikes a good balance between coherence and creativity. The generated text may not be factually grounded, but the language reads naturally.
With a high temperature (e.g., 1.5), the text becomes more random and creative, but it can also lose coherence and sometimes logic. The language may be hard to follow, as in the example above.
Choosing the right temperature depends on your application. If you are building a code-completion or writing assistant, a lower temperature is usually better. For creative writing or brainstorming, a higher temperature produces more varied and interesting results.
Top-K and Top-P Sampling
The nucleus sampling parameters control how much flexibility the model has when choosing the next token. Should you adjust the top_k parameter or the top_p parameter? Let's look at their effect with an example:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The best way to learn programming is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different top_k values
top_k_values = [5, 20, 50]
print(f"Prompt: {prompt}")
for top_k in top_k_values:
    print()
    print(f"Top-K = {top_k}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=top_k,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)

# Generate text with different top_p values
top_p_values = [0.5, 0.7, 0.9]
for top_p in top_p_values:
    print()
    print(f"Top-P = {top_p}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=0,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
If you run this code, you may see output like the following:
Prompt: The best way to learn programming is
Top-K = 5
Generated Text:
The best way to learn programming is to be able to learn the basics in a very short
amount of time, and then learn to use them effectively and quickly.
If you want to be a successful programmer in this way, you should learn to use the
techniques in the above video to learn the basics of programming.
If you want to learn to code more effectively, you can also get more experienced
programmers by doing the following:
Learning to Code
Learning to code is very
Top-K = 20
Generated Text:
The best way to learn programming is to learn it.
In order to get started with Ruby you're going to have to make a few mistakes, some
of them can be fairly obvious.
First of all, you're going to have to write a function that takes in a value. What
this means is that you're going to make a new instance of the Ruby function. You can
read more about this in Part 1 of this course, or just try it out from the REPL.
Top-K = 50
Generated Text:
The best way to learn programming is to become familiar with the language and the
software. One of the first and most common forms of programming is to create,
modify, and distribute code.
However, there are very few programming libraries that can provide us with all
that we need.
The following sample programming program uses some of the above, but does not show
the best way to learn programming. It was written in Java and in C or C++.
The original source code is
Top-P = 0.5
Generated Text:
The best way to learn programming is to be able to create a tool for you. That's
what I do.
That's why I'm here today.
I'm here to talk about the basics of programming, and I'm going to tell you how to
learn programming.
I'm here to talk about learning programming.
It's easy to forget that you don't have to know how to program. It's easy to forget
that you don't have to know how
Top-P = 0.7
Generated Text:
The best way to learn programming is to practice programming. Learn the principles
of programming by observing and performing exercises.
I used to work in a world of knowledge which included all sorts of things, and was
able to catch up on them and understand them from their perspective. For instance, I
learned to sit up straight and do five squats. Then, I would have to practice some
type of overhead training. I would try to learn the best technique and add that to
my repertoire.
What
Top-P = 0.9
Generated Text:
The best way to learn programming is to become a good hacker. Don't use any
programming tools. Just a regular dot-com user, an occasional coding learner, and
stick with it.
— Victoria E. Nichols
With a small k value, such as 5, the model has fewer options to choose from, which leads to more predictable text. At the extreme, with k=1, the model always picks the single most probable token; this is greedy decoding, and it usually produces poor output. With a larger k value, such as 50, the model has more options to choose from, producing more diverse text.
Similarly, for the top_p parameter, a small p means the model chooses from a smaller set of high-probability tokens, producing more focused text. With a larger p, such as 0.9, the selection is broader and the text can be more varied. However, the number of options available at each step is not fixed; it depends on the probability distribution the model predicts. When the model is very confident about the next token (for example, because of grammatical constraints), only a very small set of tokens is allowed. This adaptive behavior is also why top-p sampling is often preferred over top-k sampling.
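To see this adaptive behavior in isolation, the following sketch (using made-up probabilities rather than real model output) counts how many top tokens are needed to reach p = 0.9 under a peaked distribution versus a flat one:

import torch

def nucleus_size(probs, p=0.9):
    # Number of top tokens needed for the cumulative probability to reach p
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    return int((cumulative < p).sum().item()) + 1

peaked = torch.tensor([0.82, 0.09, 0.04, 0.03, 0.01, 0.01])  # model is confident
flat = torch.tensor([0.25, 0.20, 0.20, 0.15, 0.12, 0.08])    # model is uncertain
print(nucleus_size(peaked))  # small nucleus, e.g. 2 tokens
print(nucleus_size(flat))    # larger nucleus, e.g. 5 tokens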
Controlling Repetition
Repetition is a common problem in text generation. The repetition_penalty parameter helps address it by penalizing tokens that have already appeared in the generated text. Let's see how it works:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "Once upon a time, there was a"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different repetition penalties
penalties = [1.0, 1.2, 1.5, 2.0]
print(f"Prompt: {prompt}")
for penalty in penalties:
    print()
    print(f"Repetition penalty: {penalty}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=0.3,
        top_k=50,
        top_p=1.0,
        repetition_penalty=penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
If you run this code, you may see output like the following:
Prompt: Once upon a time, there was a
Repetition penalty: 1.0
Generated Text:
Once upon a time, there was a great deal of confusion about what was going on. The
first thing that came to mind was the fact that the government had already been in
place for a long time, and that the government had been in place for a long time.
And it was clear that the government had been in place for a long time. And it was
clear that the government had been in place for a long time. And it was clear that
the government had been in place for a long
Repetition penalty: 1.2
Generated Text:
Once upon a time, there was a great deal of talk about the possibility that this
would be an opportunity for us to see more and better things in our lives. We had
been talking on Facebook all day long with people who were interested in what we
could do next or how they might help others find their own way out."
"We've always wanted to make sure everyone has access," he continued; "but it's not
like you can just go into your room at night looking around without seeing
Repetition penalty: 1.5
Generated Text:
Once upon a time, there was a man who had been called to the service of God. He
came and said: "I am an apostle from Jerusalem." And he answered him with great joy,
saying that it is not possible for me now in this life without having received
Jesus Christ as our Lord; but I will be saved through Him alone because my Father
has sent Me into all things by His Holy Spirit (John 1).
The Christian Church teaches us how much more than any other religion can
Repetition penalty: 2.0
Generated Text:
Once upon a time, there was a man who had been sent to the city of Nausicaa by his
father. The king's son and brother were killed in battle at that place; but when
he returned with them they found him dead on their way back from war-time.[1]
The King gave orders for an expedition against this strange creature called "the
Gorgon," which came out into space during one night after it attacked Earth[2]. It
is said that these creatures
In the code above, the temperature is set to 0.3 to emphasize the effect of the repetition penalty. With a low penalty (e.g., 1.0, which means no penalty), you can see the model repeating the same phrase over and over; the model easily gets stuck in a loop when the other settings restrict the candidate tokens to a small subset. With a high penalty (e.g., 2.0 or above), the model strongly avoids repetition, which can sometimes make the text less natural. A moderate penalty (e.g., 1.2 to 1.5) is usually a good compromise that keeps the text coherent.
After all, the parameters you set in the generate() function are there to keep the text natural and fluent. You may need to experiment with them to find the combination that works best for your specific application. Note that these parameters may also depend on the model you use, since each model may predict tokens with a different distribution.
Greedy Decoding and Sampling
The do_sample parameter controls whether the model uses sampling (choosing tokens according to their probability) or greedy decoding (always choosing the most probable token). Let's compare the two approaches:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The secret to happiness is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with greedy decoding vs. sampling
print(f"Prompt: {prompt}\n")

print("Greedy Decoding (do_sample=False):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
print()

print("Sampling (do_sample=True):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
Try running this code several times and observe the output. You will notice that the greedy decoding output is always the same, while the sampling output differs every time. For a fixed prompt, greedy decoding is deterministic: the model produces the probability distribution, the most probable token is selected, and no randomness is involved. The output is therefore more likely to be repetitive and unhelpful.
The sampling output is stochastic because output tokens are drawn from the probability distribution the model predicts. This randomness lets the model generate more diverse and creative text, while the output remains coherent as long as the other generation parameters are set appropriately. With sampling, you can also set num_return_sequences to a number greater than 1 to generate multiple sequences in parallel for the same prompt. This parameter is meaningless for greedy decoding.
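For instance, the minimal sketch below (using the same prompt as above) draws three different sampled continuations in a single call:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The secret to happiness is", return_tensors="pt")

# Ask for three sampled continuations of the same prompt in one call
outputs = model.generate(
    **inputs,
    max_length=50,
    num_return_sequences=3,  # only meaningful when do_sample=True
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
    print(f"Sample {idx + 1}:")
    print(tokenizer.decode(output, skip_special_tokens=True))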
Parameters for Specific Applications
What parameter values should you set for a specific application? There is no definitive answer; you will certainly need some experimentation to find the best combination. But you can use the following recommendations as a starting point (a small code sketch with these presets follows the list):
- Factual generation:
A. Use a lower temperature (0.2 to 0.4) for more deterministic output
B. Use a moderate top_p (0.8 to 0.9) to filter out unlikely tokens
C. Use a higher repetition_penalty (1.2 to 1.5) to avoid repetitive statements
- Creative writing:
A. Use a higher temperature (1.0 to 1.3) for more creative and varied output
B. Use a higher top_p (0.9 to 0.95) to allow more possibilities
C. Use a lower repetition_penalty (1.0 to 1.1) to allow some stylistic repetition
- Code generation:
A. Use a lower temperature (0.1 to 0.3) for more precise and correct code
B. Use a lower top_p (0.7 to 0.8) to focus on the most likely tokens
C. Use a higher repetition_penalty (1.3 to 1.5) to avoid redundant code
- Dialogue generation:
A. Use a moderate temperature (0.6 to 0.8) for natural but focused responses
B. Use a moderate top_p (0.9) for a good balance between creativity and coherence
C. Use a moderate repetition_penalty (1.2) to avoid repeated phrases
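As a convenience, the sketch below collects these recommendations into reusable keyword-argument presets for generate(). The preset values simply restate the list above, and the prompt and dictionary name are arbitrary choices for illustration:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Artificial intelligence is", return_tensors="pt")

# Presets mirroring the recommendations above
generation_presets = {
    "factual": {"temperature": 0.3, "top_p": 0.85, "repetition_penalty": 1.3},
    "creative": {"temperature": 1.2, "top_p": 0.95, "repetition_penalty": 1.05},
    "code": {"temperature": 0.2, "top_p": 0.75, "repetition_penalty": 1.4},
    "dialogue": {"temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.2},
}

# Example: generate with the "factual" preset
output = model.generate(
    **inputs,
    max_length=100,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    **generation_presets["factual"],
)
print(tokenizer.decode(output[0], skip_special_tokens=True))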
Keep in mind that a language model is not a perfect oracle; it can make mistakes. These parameters help you match the generation process to the intended output style, but they do not guarantee correctness. The output you get may contain errors.
Beam Search and Multiple Sequence Generation
In the examples above, generation is autoregressive: it is an iterative process that produces one token at a time.
Since each step produces a token by sampling, nothing prevents you from sampling several tokens at each step. If you do, you will generate multiple output sequences for a single input prompt. In theory, if you sample k tokens at each step and the generated length is n, you end up with k^n sequences. That number can be enormous, so you will probably want to limit it to just a few.
The first way to generate multiple sequences is to set num_return_sequences to a number k. You sample k tokens at the first step and then complete a sequence for each of them. This essentially duplicates the prompt k times in the generation.
The second way is to use beam search, a more sophisticated method for generating multiple sequences. It keeps track of the most promising sequences and explores them in parallel. Rather than generating k^n sequences and overwhelming memory, it keeps only the best sequences at each step: every token-generation step temporarily expands this set, which is then pruned back to the best candidates.
To use beam search, set num_beams to a number k. At each step, each of the k sequences is extended by one more token, producing k^2 candidate sequences, and the best k of them are kept for the next step. You can also set early_stopping=True so that generation stops when the end of a sequence is reached, and you should set num_return_sequences to limit how many sequences are finally returned.
Sequences are usually selected based on the cumulative probability of their tokens, but you can also bias the selection with other criteria, such as adding a length penalty or disallowing repeated n-grams. Here is an example of using beam search:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The key to successful machine learning is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with beam search
print(f"Prompt: {prompt}\n")
outputs = model.generate(
    **inputs,
    num_beams=5,              # number of beams to use
    early_stopping=True,      # stop when all beams are finished
    no_repeat_ngram_size=2,   # avoid repeated n-grams
    num_return_sequences=3,   # return multiple sequences
    max_length=100,
    temperature=1.5,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
    generated_text = tokenizer.decode(output, skip_special_tokens=True)
    print(f"Generated Text ({idx+1}):")
    print(generated_text)
You can add more generation parameters (such as length_penalty) to control the generation process. The example above sets a higher temperature to highlight the behavior of beam search. If you run this code, you may see:
Prompt: The key to successful machine learning is
Generated Text (1):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them to create
Generated Text (2):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them and what
Generated Text (3):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.
So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how they work. You can use
The number of output sequences is still controlled by num_return_sequences, but the sequences are produced by the beam search algorithm. It is hard to tell from the output alone whether beam search was used. One sign is that beam search outputs are not as diverse as those obtained by simply setting num_return_sequences, because many more candidate sequences are generated and the ones with the highest cumulative probability are selected. This filtering does reduce the diversity of the output.
Further Reading
Below are some additional resources that you may find useful:
Summary
In this article, you learned how the many parameters of the generate() function control the generation process. You can adjust them so that the output matches the style your application expects. Specifically, you learned:
- How to use temperature to control the output probability distribution
- How to use top-k and top-p to control the diversity of the output
- How to control the output with the repetition penalty, beam search, and greedy decoding
By understanding and tuning these parameters, you can optimize text generation for different applications, from factual writing to creative storytelling, code generation, and dialogue systems.
About the Translator
朱先忠, community editor at 51CTO, expert blogger and lecturer on 51CTO, computer science teacher at a university in Weifang, and a veteran of the freelance programming world.
Original title: Understanding Text Generation Parameters in Transformers, by Muhammad Asad Iqbal Khan