更新時間:2022-04-26 來源:黑馬程序員 瀏覽量:
以一個符合語言規(guī)律的序列為輸入,模型將利用序列間關系等特征,輸出一個在所有詞匯上的概率分布.這樣的模型稱為語言模型。
# 語言模型的訓練語料一般來自于文章,對應的源文本和目標文本形如: src1 = "I can do" tgt1 = "can do it" src2 = "can do it", tgt2 = "do it <eos>"
1, 根據(jù)語言模型的定義,可以在它的基礎上完成機器翻譯,文本生成等任務,因為我們通過最后輸出的概率分布來預測下一個詞匯是什么.
2, 語言模型可以判斷輸入的序列是否為一句完整的話,因為我們可以根據(jù)輸出的概率分布查看最大概率是否落在句子結束符上,來判斷完整性.
3, 語言模型本身的訓練目標是預測下一個詞,因為它的特征提取部分會抽象很多語言序列之間的關系,這些關系可能同樣對其他語言類任務有效果.因此可以作為預訓練模型進行遷移學習.
整個案例的實現(xiàn)可分為以下五個步驟
第一步: 導入必備的工具包
第二步: 導入wikiText-2數(shù)據(jù)集并作基本處理
第三步: 構建用于模型輸入的批次化數(shù)據(jù)
第四步: 構建訓練和評估函數(shù)
第五步: 進行訓練和評估(包括驗證以及測試)
pytorch版本必須使用1.3.1, python版本使用3.6.x
pip install torch==1.3.1
# 數(shù)學計算工具包math import math # torch以及torch.nn, torch.nn.functional import torch import torch.nn as nn import torch.nn.functional as F # torch中經(jīng)典文本數(shù)據(jù)集有關的工具包 # 具體詳情參考下方torchtext介紹 import torchtext # torchtext中的數(shù)據(jù)處理工具, get_tokenizer用于英文分詞 from torchtext.data.utils import get_tokenizer # 已經(jīng)構建完成的TransformerModel from pyitcast.transformer import TransformerModel
torchtext:它是torch工具中處理NLP問題的常用數(shù)據(jù)處理包.
torchtext的重要功能:對文本數(shù)據(jù)進行處理, 比如文本語料加載, 文本迭代器構建等.
包含很多經(jīng)典文本語料的預加載方法. 其中包括的語料有:用于情感分析的SST和IMDB, 用于問題分類的TREC, 用于及其翻譯的 WMT14, IWSLT,以及用于語言模型任務wikiText-2, WikiText103, PennTreebank.
我們這里使用wikiText-2來訓練語言模型, 下面有關該數(shù)據(jù)集的相關詳情:
wikiText-2數(shù)據(jù)集的體量中等, 訓練集共有600篇短文, 共208萬左右的詞匯, 33278個不重復詞匯, OoV(有多少正常英文詞匯不在該數(shù)據(jù)集中的占比)為2.6%,數(shù)據(jù)集中的短文都是維基百科中對一些概念的介紹和描述.
# 創(chuàng)建語料域, 語料域是存放語料的數(shù)據(jù)結構, # 它的四個參數(shù)代表給存放語料(或稱作文本)施加的作用. # 分別為 tokenize,使用get_tokenizer("basic_english")獲得一個分割器對象, # 分割方式按照文本為基礎英文進行分割. # init_token為給文本施加的起始符 <sos>給文本施加的終止符<eos>, # 最后一個lower為True, 存放的文本字母全部小寫. TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"), init_token='<sos>', eos_token='<eos>', lower=True) # 最終獲得一個Field對象. # <torchtext.data.field.Field object at 0x7fc42a02e7f0> # 然后使用torchtext的數(shù)據(jù)集方法導入WikiText2數(shù)據(jù), # 并切分為對應訓練文本, 驗證文本,測試文本, 并對這些文本施加剛剛創(chuàng)建的語料域. train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT) # 我們可以通過examples[0].text取出文本對象進行查看. # >>> test_txt.examples[0].text[:10] # ['<eos>', '=', 'robert', '<unk>', '=', '<eos>', '<eos>', 'robert', '<unk>', 'is'] # 將訓練集文本數(shù)據(jù)構建一個vocab對象, # 這樣可以使用vocab對象的stoi方法統(tǒng)計文本共包含的不重復詞匯總數(shù). TEXT.build_vocab(train_txt) # 然后選擇設備cuda或者cpu device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
該案例的所有代碼都將實現(xiàn)在一個transformer_lm.py文件中.
批次化過程的第一個函數(shù)batchify代碼分析:
def batchify(data, bsz): """batchify函數(shù)用于將文本數(shù)據(jù)映射成連續(xù)數(shù)字, 并轉換成指定的樣式, 指定的樣式可參考下圖. 它有兩個輸入?yún)?shù), data就是我們之前得到的文本數(shù)據(jù)(train_txt, val_txt, test_txt), bsz是就是batch_size, 每次模型更新參數(shù)的數(shù)據(jù)量""" # 使用TEXT的numericalize方法將單詞映射成對應的連續(xù)數(shù)字. data = TEXT.numericalize([data.examples[0].text]) # >>> data # tensor([[ 3], # [ 12], # [3852], # ..., # [ 6], # [ 3], # [ 3]]) # 接著用數(shù)據(jù)詞匯總數(shù)除以bsz, # 取整數(shù)得到一個nbatch代表需要多少次batch后能夠遍歷完所有數(shù)據(jù) nbatch = data.size(0) // bsz # 之后使用narrow方法對不規(guī)整的剩余數(shù)據(jù)進行刪除, # 第一個參數(shù)是代表橫軸刪除還是縱軸刪除, 0為橫軸,1為縱軸 # 第二個和第三個參數(shù)代表保留開始軸到結束軸的數(shù)值.類似于切片 # 可參考下方演示示例進行更深理解. data = data.narrow(0, 0, nbatch * bsz) # >>> data # tensor([[ 3], # [ 12], # [3852], # ..., # [ 78], # [ 299], # [ 36]]) # 后面不能形成bsz個的一組數(shù)據(jù)被刪除 # 接著我們使用view方法對data進行矩陣變換, 使其成為如下樣式: # tensor([[ 3, 25, 1849, ..., 5, 65, 30], # [ 12, 66, 13, ..., 35, 2438, 4064], # [ 3852, 13667, 2962, ..., 902, 33, 20], # ..., # [ 154, 7, 10, ..., 5, 1076, 78], # [ 25, 4, 4135, ..., 4, 56, 299], # [ 6, 57, 385, ..., 3168, 737, 36]]) # 因為會做轉置操作, 因此這個矩陣的形狀是[None, bsz], # 如果輸入是訓練數(shù)據(jù)的話,形狀為[104335, 20], 可以通過打印data.shape獲得. # 也就是data的列數(shù)是等于bsz的值的. data = data.view(bsz, -1).t().contiguous() # 最后將數(shù)據(jù)分配在指定的設備上. return data.to(device)
batchify的樣式轉化圖:
注:大寫字母A,B,C ... 代表句子中的每個單詞.
torch.narrow演示:
>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> x.narrow(0, 0, 2) tensor([[ 1, 2, 3], [ 4, 5, 6]]) >>> x.narrow(1, 1, 2) tensor([[ 2, 3], [ 5, 6], [ 8, 9]])
接下來我們將使用batchify來處理訓練數(shù)據(jù),驗證數(shù)據(jù)以及測試數(shù)據(jù).
# 訓練數(shù)據(jù)的batch size batch_size = 20 # 驗證和測試數(shù)據(jù)(統(tǒng)稱為評估數(shù)據(jù))的batch size eval_batch_size = 10 # 獲得train_data, val_data, test_data train_data = batchify(train_txt, batch_size) val_data = batchify(val_txt, eval_batch_size) test_data = batchify(test_txt, eval_batch_size)
上面的分割批次并沒有進行源數(shù)據(jù)與目標數(shù)據(jù)的處理, 接下來我們將根據(jù)語言模型訓練的語料規(guī)定來構建源數(shù)據(jù)與目標數(shù)據(jù).
語言模型訓練的語料規(guī)定:
如果源數(shù)據(jù)為句子ABCD, ABCD代表句子中的詞匯或符號, 則它的目標數(shù)據(jù)為BCDE, BCDE分別代表ABCD的下一個詞匯.
如圖所示,我們這里的句子序列是豎著的, 而且我們發(fā)現(xiàn)如果用一個批次處理完所有數(shù)據(jù), 以訓練數(shù)據(jù)為例, 每個句子長度高達104335, 這明顯是不科學的, 因此我們在這里要限定每個批次中的句子長度允許的最大值bptt.
批次化過程的第二個函數(shù)get_batch代碼分析:
# 令子長度允許的最大值bptt為35 bptt = 35 def get_batch(source, i): """用于獲得每個批次合理大小的源數(shù)據(jù)和目標數(shù)據(jù). 參數(shù)source是通過batchify得到的train_data/val_data/test_data. i是具體的批次次數(shù). """ # 首先我們確定句子長度, 它將是在bptt和len(source) - 1 - i中最小值 # 實質上, 前面的批次中都會是bptt的值, 只不過最后一個批次中, 句子長度 # 可能不夠bptt的35個, 因此會變?yōu)閘en(source) - 1 - i的值. seq_len = min(bptt, len(source) - 1 - i) # 語言模型訓練的源數(shù)據(jù)的第i批數(shù)據(jù)將是batchify的結果的切片[i:i+seq_len] data = source[i:i+seq_len] # 根據(jù)語言模型訓練的語料規(guī)定, 它的目標數(shù)據(jù)是源數(shù)據(jù)向后移動一位 # 因為最后目標數(shù)據(jù)的切片會越界, 因此使用view(-1)來保證形狀正常. target = source[i+1:i+1+seq_len].view(-1) return data, target
輸入實例:
# 以測試集數(shù)據(jù)為例 source = test_data i = 1
輸出效果:
data = tensor([[ 12, 1053, 355, 134, 37, 7, 4, 0, 835, 9834], [ 635, 8, 5, 5, 421, 4, 88, 8, 573, 2511], [ 0, 58, 8, 8, 6, 692, 544, 0, 212, 5], [ 12, 0, 105, 26, 3, 5, 6, 0, 4, 56], [ 3, 16074, 21254, 320, 3, 262, 16, 6, 1087, 89], [ 3, 751, 3866, 10, 12, 31, 246, 238, 79, 49], [ 635, 943, 78, 36, 12, 475, 66, 10, 4, 924], [ 0, 2358, 52, 4, 12, 4, 5, 0, 19831, 21], [ 26, 38, 54, 40, 1589, 3729, 1014, 5, 8, 4], [ 33, 17597, 33, 1661, 15, 7, 5, 0, 4, 170], [ 335, 268, 117, 0, 0, 4, 3144, 1557, 0, 160], [ 106, 4, 4706, 2245, 12, 1074, 13, 2105, 5, 29], [ 5, 16074, 10, 1087, 12, 137, 251, 13238, 8, 4], [ 394, 746, 4, 9, 12, 6032, 4, 2190, 303, 12651], [ 8, 616, 2107, 4, 3, 4, 425, 0, 10, 510], [ 1339, 112, 23, 335, 3, 22251, 1162, 9, 11, 9], [ 1212, 468, 6, 820, 9, 7, 1231, 4202, 2866, 382], [ 6, 24, 104, 6, 4, 4, 7, 10, 9, 588], [ 31, 190, 0, 0, 230, 267, 4, 273, 278, 6], [ 34, 25, 47, 26, 1864, 6, 694, 0, 2112, 3], [ 11, 6, 52, 798, 8, 69, 20, 31, 63, 9], [ 1800, 25, 2141, 2442, 117, 31, 196, 7290, 4, 298], [ 15, 171, 15, 17, 1712, 13, 217, 59, 736, 5], [ 4210, 191, 142, 14, 5251, 939, 59, 38, 10055, 25132], [ 302, 23, 11718, 11, 11, 599, 382, 317, 8, 13], [ 16, 1564, 9, 4808, 6, 0, 6, 6, 4, 4], [ 4, 7, 39, 7, 3934, 5, 9, 3, 8047, 557], [ 394, 0, 10715, 3580, 8682, 31, 242, 0, 10055, 170], [ 96, 6, 144, 3403, 4, 13, 1014, 14, 6, 2395], [ 4, 3, 13729, 14, 40, 0, 5, 18, 676, 3267], [ 1031, 3, 0, 628, 1589, 22, 10916, 10969, 5, 22548], [ 9, 12, 6, 84, 15, 49, 3144, 7, 102, 15], [ 916, 12, 4, 203, 0, 273, 303, 333, 4318, 0], [ 6, 12, 0, 4842, 5, 17, 4, 47, 4138, 2072], [ 38, 237, 5, 50, 35, 27, 18530, 244, 20, 6]]) target = tensor([ 635, 8, 5, 5, 421, 4, 88, 8, 573, 2511, 0, 58, 8, 8, 6, 692, 544, 0, 212, 5, 12, 0, 105, 26, 3, 5, 6, 0, 4, 56, 3, 16074, 21254, 320, 3, 262, 16, 6, 1087, 89, 3, 751, 3866, 10, 12, 31, 246, 238, 79, 49, 635, 943, 78, 36, 12, 475, 66, 10, 4, 924, 0, 2358, 52, 4, 12, 4, 5, 0, 19831, 21, 26, 38, 54, 40, 1589, 3729, 1014, 5, 8, 4, 33, 17597, 33, 1661, 15, 7, 5, 0, 4, 170, 335, 268, 117, 0, 0, 4, 3144, 1557, 0, 160, 106, 4, 4706, 2245, 12, 1074, 13, 2105, 5, 29, 5, 16074, 10, 1087, 12, 137, 251, 13238, 8, 4, 394, 746, 4, 9, 12, 6032, 4, 2190, 303, 12651, 8, 616, 2107, 4, 3, 4, 425, 0, 10, 510, 1339, 112, 23, 335, 3, 22251, 1162, 9, 11, 9, 1212, 468, 6, 820, 9, 7, 1231, 4202, 2866, 382, 6, 24, 104, 6, 4, 4, 7, 10, 9, 588, 31, 190, 0, 0, 230, 267, 4, 273, 278, 6, 34, 25, 47, 26, 1864, 6, 694, 0, 2112, 3, 11, 6, 52, 798, 8, 69, 20, 31, 63, 9, 1800, 25, 2141, 2442, 117, 31, 196, 7290, 4, 298, 15, 171, 15, 17, 1712, 13, 217, 59, 736, 5, 4210, 191, 142, 14, 5251, 939, 59, 38, 10055, 25132, 302, 23, 11718, 11, 11, 599, 382, 317, 8, 13, 16, 1564, 9, 4808, 6, 0, 6, 6, 4, 4, 4, 7, 39, 7, 3934, 5, 9, 3, 8047, 557, 394, 0, 10715, 3580, 8682, 31, 242, 0, 10055, 170, 96, 6, 144, 3403, 4, 13, 1014, 14, 6, 2395, 4, 3, 13729, 14, 40, 0, 5, 18, 676, 3267, 1031, 3, 0, 628, 1589, 22, 10916, 10969, 5, 22548, 9, 12, 6, 84, 15, 49, 3144, 7, 102, 15, 916, 12, 4, 203, 0, 273, 303, 333, 4318, 0, 6, 12, 0, 4842, 5, 17, 4, 47, 4138, 2072, 38, 237, 5, 50, 35, 27, 18530, 244, 20, 6, 13, 1083, 35, 1990, 653, 13, 10, 11, 1538, 56])
設置模型超參數(shù)和初始化模型
# 通過TEXT.vocab.stoi方法獲得不重復詞匯總數(shù) ntokens = len(TEXT.vocab.stoi) # 詞嵌入大小為200 emsize = 200 # 前饋全連接層的節(jié)點數(shù) nhid = 200 # 編碼器層的數(shù)量 nlayers = 2 # 多頭注意力機制的頭數(shù) nhead = 2 # 置0比率 dropout = 0.2 # 將參數(shù)輸入到TransformerModel中 model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device) # 模型初始化后, 接下來進行損失函數(shù)和優(yōu)化方法的選擇. # 關于損失函數(shù), 我們使用nn自帶的交叉熵損失 criterion = nn.CrossEntropyLoss() # 學習率初始值定為5.0 lr = 5.0 # 優(yōu)化器選擇torch自帶的SGD隨機梯度下降方法, 并把lr傳入其中 optimizer = torch.optim.SGD(model.parameters(), lr=lr) # 定義學習率調整方法, 使用torch自帶的lr_scheduler, 將優(yōu)化器傳入其中. scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
模型訓練代碼分析:
# 導入時間工具包 import time def train(): """訓練函數(shù)""" # 模型開啟訓練模式 model.train() # 定義初始損失為0 total_loss = 0. # 獲得當前時間 start_time = time.time() # 開始遍歷批次數(shù)據(jù) for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)): # 通過get_batch獲得源數(shù)據(jù)和目標數(shù)據(jù) data, targets = get_batch(train_data, i) # 設置優(yōu)化器初始梯度為0梯度 optimizer.zero_grad() # 將數(shù)據(jù)裝入model得到輸出 output = model(data) # 將輸出和目標數(shù)據(jù)傳入損失函數(shù)對象 loss = criterion(output.view(-1, ntokens), targets) # 損失進行反向傳播以獲得總的損失 loss.backward() # 使用nn自帶的clip_grad_norm_方法進行梯度規(guī)范化, 防止出現(xiàn)梯度消失或爆炸 torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # 模型參數(shù)進行更新 optimizer.step() # 將每層的損失相加獲得總的損失 total_loss += loss.item() # 日志打印間隔定為200 log_interval = 200 # 如果batch是200的倍數(shù)且大于0,則打印相關日志 if batch % log_interval == 0 and batch > 0: # 平均損失為總損失除以log_interval cur_loss = total_loss / log_interval # 需要的時間為當前時間減去開始時間 elapsed = time.time() - start_time # 打印輪數(shù), 當前批次和總批次, 當前學習率, 訓練速度(每豪秒處理多少批次), # 平均損失, 以及困惑度, 困惑度是衡量語言模型的重要指標, 它的計算方法就是 # 對交叉熵平均損失取自然對數(shù)的底數(shù). print('| epoch {:3d} | {:5d}/{:5d} batches | ' 'lr {:02.2f} | ms/batch {:5.2f} | ' 'loss {:5.2f} | ppl {:8.2f}'.format( epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0], elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss))) # 每個批次結束后, 總損失歸0 total_loss = 0 # 開始時間取當前時間 start_time = time.time()
模型評估代碼分析:
def evaluate(eval_model, data_source): """評估函數(shù), 評估階段包括驗證和測試, 它的兩個參數(shù)eval_model為每輪訓練產(chǎn)生的模型 data_source代表驗證或測試數(shù)據(jù)集""" # 模型開啟評估模式 eval_model.eval() # 總損失歸0 total_loss = 0 # 因為評估模式模型參數(shù)不變, 因此反向傳播不需要求導, 以加快計算 with torch.no_grad(): # 與訓練過程相同, 但是因為過程不需要打印信息, 因此不需要batch數(shù) for i in range(0, data_source.size(0) - 1, bptt): # 首先還是通過通過get_batch獲得驗證數(shù)據(jù)集的源數(shù)據(jù)和目標數(shù)據(jù) data, targets = get_batch(data_source, i) # 通過eval_model獲得輸出 output = eval_model(data) # 對輸出形狀扁平化, 變?yōu)槿吭~匯的概率分布 output_flat = output.view(-1, ntokens) # 獲得評估過程的總損失 total_loss += criterion(output_flat, targets).item() # 計算平均損失 cur_loss = total_loss / ((data_source.size(0) - 1) / bptt) # 返回平均損失 return cur_loss
模型的訓練與驗證代碼分析:
# 首先初始化最佳驗證損失,初始值為無窮大 import copy best_val_loss = float("inf") # 定義訓練輪數(shù) epochs = 3 # 定義最佳模型變量, 初始值為None best_model = None # 使用for循環(huán)遍歷輪數(shù) for epoch in range(1, epochs + 1): # 首先獲得輪數(shù)開始時間 epoch_start_time = time.time() # 調用訓練函數(shù) train() # 該輪訓練后我們的模型參數(shù)已經(jīng)發(fā)生了變化 # 將模型和評估數(shù)據(jù)傳入到評估函數(shù)中 val_loss = evaluate(model, val_data) # 之后打印每輪的評估日志,分別有輪數(shù),耗時,驗證損失以及驗證困惑度 print('-' * 89) print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))) print('-' * 89) # 我們將比較哪一輪損失最小,賦值給best_val_loss, # 并取該損失下的模型為best_model if val_loss < best_val_loss: best_val_loss = val_loss # 使用深拷貝,拷貝最優(yōu)模型 best_model = copy.deepcopy(model) # 每輪都會對優(yōu)化方法的學習率做調整 scheduler.step()
輸出效果:
| epoch 1 | 200/ 2981 batches | lr 5.00 | ms/batch 30.03 | loss 7.68 | ppl 2158.52 | epoch 1 | 400/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss 5.26 | ppl 193.39 | epoch 1 | 600/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss 4.07 | ppl 58.44 | epoch 1 | 800/ 2981 batches | lr 5.00 | ms/batch 28.88 | loss 3.41 | ppl 30.26 | epoch 1 | 1000/ 2981 batches | lr 5.00 | ms/batch 28.89 | loss 2.98 | ppl 19.72 | epoch 1 | 1200/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss 2.79 | ppl 16.30 | epoch 1 | 1400/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss 2.67 | ppl 14.38 | epoch 1 | 1600/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss 2.58 | ppl 13.19 | epoch 1 | 1800/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss 2.43 | ppl 11.32 | epoch 1 | 2000/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss 2.39 | ppl 10.93 | epoch 1 | 2200/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss 2.33 | ppl 10.24 | epoch 1 | 2400/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss 2.36 | ppl 10.59 | epoch 1 | 2600/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss 2.33 | ppl 10.31 | epoch 1 | 2800/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss 2.26 | ppl 9.54 ----------------------------------------------------------------------------------------- | end of epoch 1 | time: 90.01s | valid loss 1.32 | valid ppl 3.73 ----------------------------------------------------------------------------------------- | epoch 2 | 200/ 2981 batches | lr 4.75 | ms/batch 29.08 | loss 2.18 | ppl 8.83 | epoch 2 | 400/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 2.11 | ppl 8.24 | epoch 2 | 600/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 1.98 | ppl 7.23 | epoch 2 | 800/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 2.00 | ppl 7.39 | epoch 2 | 1000/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss 1.94 | ppl 6.96 | epoch 2 | 1200/ 2981 batches | lr 4.75 | ms/batch 28.92 | loss 1.97 | ppl 7.15 | epoch 2 | 1400/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss 1.98 | ppl 7.28 | epoch 2 | 1600/ 2981 batches | lr 4.75 | ms/batch 28.92 | loss 1.97 | ppl 7.16 | epoch 2 | 1800/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 1.92 | ppl 6.84 | epoch 2 | 2000/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 1.96 | ppl 7.11 | epoch 2 | 2200/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss 1.92 | ppl 6.80 | epoch 2 | 2400/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss 1.94 | ppl 6.93 | epoch 2 | 2600/ 2981 batches | lr 4.75 | ms/batch 28.76 | loss 1.91 | ppl 6.76 | epoch 2 | 2800/ 2981 batches | lr 4.75 | ms/batch 28.75 | loss 1.89 | ppl 6.64 ----------------------------------------------------------------------------------------- | end of epoch 2 | time: 89.71s | valid loss 1.01 | valid ppl 2.74 ----------------------------------------------------------------------------------------- | epoch 3 | 200/ 2981 batches | lr 4.51 | ms/batch 28.88 | loss 1.78 | ppl 5.96 | epoch 3 | 400/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.89 | ppl 6.59 | epoch 3 | 600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.72 | ppl 5.58 | epoch 3 | 800/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.73 | ppl 5.63 | epoch 3 | 1000/ 2981 batches | lr 4.51 | ms/batch 28.73 | loss 1.65 | ppl 5.22 | epoch 3 | 1200/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss 1.69 | ppl 5.40 | epoch 3 | 1400/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss 1.73 | ppl 5.66 | epoch 3 | 1600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.75 | ppl 5.73 | epoch 3 | 1800/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss 1.67 | ppl 5.33 | epoch 3 | 2000/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss 1.69 | ppl 5.41 | epoch 3 | 2200/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss 1.66 | ppl 5.26 | epoch 3 | 2400/ 2981 batches | lr 4.51 | ms/batch 28.76 | loss 1.69 | ppl 5.43 | epoch 3 | 2600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.71 | ppl 5.55 | epoch 3 | 2800/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss 1.72 | ppl 5.58 ----------------------------------------------------------------------------------------- | end of epoch 3 | time: 89.26s | valid loss 0.85 | valid ppl 2.33
模型測試代碼分析:
# 我們?nèi)匀皇褂胑valuate函數(shù),這次它的參數(shù)是best_model以及測試數(shù)據(jù) test_loss = evaluate(best_model, test_data) # 打印測試日志,包括測試損失和測試困惑度 print('=' * 89) print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( test_loss, math.exp(test_loss))) print('=' * 89)
輸出效果:
========================================================================================= | End of training | test loss 0.83 | test ppl 2.30 =========================================================================================