
Trying English→Japanese machine translation based on the PyTorch tutorial

Among the tutorials officially provided by PyTorch, there is one on neural machine translation
using a Seq2Seq network with attention.

Translation with a Sequence to Sequence Network and Attention — PyTorch Tutorials 0.4.1 documentation

In this post, I take that excellent tutorial as a starting point
and try English→Japanese machine translation.

Note: the official tutorial itself is excellent and very instructive.
This article is a concise adaptation of it, modified where appropriate, written for my own future reference.

For the English–Japanese parallel corpus, I used the one published by odashi.

github.com
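For reference, the loading code below assumes that data/en-ja.txt holds one tokenized English sentence and its Japanese translation per line, separated by a tab (that layout is an assumption read off the loading code in section 1, not the corpus's original file structure). A minimal check of that assumption, using the first line of the file, would be:

with open('data/en-ja.txt', encoding='utf-8') as f:
    first = f.readline().rstrip('\n')
en, ja = first.split('\t')   # expects exactly one tab per line
print(en)  # e.g. "i can 't tell who will arrive first ."
print(ja)  # e.g. "誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。"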


0. Imports

import warnings
warnings.filterwarnings('ignore')
import re
import random
import unicodedata
import string

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

import numpy as np
from nltk import bleu_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
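Since the training pairs are sampled later with random.choice, fixing the random seeds makes runs reproducible. This is my own addition, not part of the tutorial; a minimal sketch:

SEED = 42  # any fixed value
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)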

1. Loading the data

lines = open('data/en-ja.txt', encoding='utf-8').read().strip().split('\n')
test_ratio = 0.1
train_lines, test_lines = lines[:int(len(lines)*(1-test_ratio))], lines[int(len(lines)*(1-test_ratio)):]

pairs = [l.split('\t') for l in lines]
training_pairs = [l.split('\t') for l in train_lines]
test_pairs = [l.split('\t') for l in test_lines]

input_sentences = [s[0] for s in pairs]
output_sentences = [s[1] for s in pairs]
input_sentences[:5]

["i can 't tell who will arrive first .",
'many animals have been destroyed by men .',
"i 'm in the tennis club .",
'emi looks happy .',
'please bear this fact in mind .']

output_sentences[:5]

['誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。',
'多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。',
'私 は テニス 部員 で す 。',
'エミ は 幸せ そう に 見え ま す 。',
'この 事実 を 心 に 留め て お い て 下さ い 。']

2. Converting words to IDs

SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "<s>", 1: "</s>"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
input_lang = Lang("en")
output_lang = Lang("ja")
for pair in pairs:
    input_lang.addSentence(pair[0])
    output_lang.addSentence(pair[1])
print("翻訳元コーパス({})の語彙数:{}".format(input_lang.name,input_lang.n_words))
print("翻訳先コーパス({})の語彙数:{}".format(output_lang.name,output_lang.n_words))

Source corpus (en) vocabulary size: 6636
Target corpus (ja) vocabulary size: 8776

input_sentences_wordlengths = [len(s.split()) for s in input_sentences]
output_sentences_wordlengths = [len(s.split()) for s in output_sentences]
print("翻訳元コーパス({})の最小文章長:{},最大文章長:{}".format(input_lang.name,min(input_sentences_wordlengths),max(input_sentences_wordlengths)))  
print("翻訳先コーパス({})の最小文章長:{},最大文章長:{}".format(output_lang.name,min(output_sentences_wordlengths),max(output_sentences_wordlengths)))

Source corpus (en) min sentence length: 4, max sentence length: 16
Target corpus (ja) min sentence length: 4, max sentence length: 16

MIN_LENGTH = 4 + 2  # <s>,</s>
MAX_LENGTH = 16 + 2 # <s>,</s>
for i in list(input_lang.index2word.keys())[:5]:
    print("{}:{}".format(i,input_lang.index2word[i]))
print("...")

0:<s>
1:</s>
2:i
3:can
4:'t
...

for i in list(output_lang.index2word.keys())[:5]:
    print("{}:{}".format(i,output_lang.index2word[i]))
print("...")

0:<s>
1:</s>
2:誰
3:が
4:一番
...


3. Seq2Seq model (with attention)

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.2, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

4. Training

hidden_size = 512
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.2).to(device)
encoder

EncoderRNN(
(embedding): Embedding(6636, 512)
(gru): GRU(512, 512)
)

decoder

AttnDecoderRNN(
(embedding): Embedding(8776, 512)
(attn): Linear(in_features=1024, out_features=18, bias=True)
(attn_combine): Linear(in_features=1024, out_features=512, bias=True)
(dropout): Dropout(p=0.2)
(gru): GRU(512, 512)
(out): Linear(in_features=512, out_features=8776, bias=True)
)
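As a quick sanity check (my own addition, not in the tutorial), running a single decoder step on dummy tensors confirms that the attention layer outputs one weight per padded encoder position, i.e. MAX_LENGTH = 18 weights, plus a log-probability for each word of the Japanese vocabulary:

with torch.no_grad():
    dummy_input = torch.tensor([[SOS_token]], device=device)                     # a single target token
    dummy_hidden = decoder.initHidden()                                          # (1, 1, 512)
    dummy_encoder_outputs = torch.zeros(MAX_LENGTH, hidden_size, device=device)  # padded encoder states
    out, hid, attn = decoder(dummy_input, dummy_hidden, dummy_encoder_outputs)
    print(out.shape, attn.shape)  # torch.Size([1, 8776]) torch.Size([1, 18])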

def tensorsFromPair(pair):
    input_sentence = pair[0]
    input_indexes = [input_lang.word2index[word] for word in input_sentence.split(' ')]
    input_indexes.append(EOS_token)
    input_tensor = torch.tensor(input_indexes, dtype=torch.long, device=device).view(-1, 1)
    
    output_sentence = pair[1]
    output_indexes = [output_lang.word2index[word] for word in output_sentence.split(' ')]
    output_indexes.append(EOS_token)
    output_tensor = torch.tensor(output_indexes, dtype=torch.long, device=device).view(-1, 1)

    return (input_tensor, output_tensor)
n_iters = 50000
learning_rate=0.01

encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
training_tensor_pairs = [tensorsFromPair(random.choice(training_pairs)) for i in range(n_iters)] # randomly sample n_iters training pairs (with replacement)
criterion = nn.NLLLoss()
print('< input tensor example >')
print(training_tensor_pairs[0][0])
print()
print('< output tensor example >')
print(training_tensor_pairs[0][1])

< input tensor example >
tensor([[ 2],
[600],
[257],
[ 42],
[141],
[258],
[ 10],
[ 1]], device='cuda:0')
 
< output tensor example >
tensor([[248],
[ 30],
[177],
[ 47],
[ 13],
[ 31],
[ 16],
[ 1]], device='cuda:0')

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH, teacher_forcing_ratio = 1.0):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden # use the encoder's final hidden state
    decoder_outputs = []

    # Teacher forcing
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    if use_teacher_forcing:
        # Use teacher forcing
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs.append(decoder_output)
            loss += criterion(decoder_output, target_tensor[di])
            
            decoder_input = target_tensor[di]
    else:
        # Without teacher forcing
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs.append(decoder_output)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length, decoder_outputs
def calc_bleu(target_tensor, decoder_tensor):
    refs = [[output_lang.index2word[t] for t in target_tensor.cpu().numpy()]]
    hyps = [output_lang.index2word[t] for t in decoder_tensor.cpu().numpy()]
    return 100 * bleu_score.sentence_bleu(refs, hyps)
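One caveat I noticed: nltk's sentence_bleu uses up to 4-grams by default, so very short or poorly matching hypotheses can score exactly zero and emit warnings (silenced here by the filterwarnings call above). If that becomes a problem, a smoothed variant could be used instead; a minimal sketch:

from nltk.translate.bleu_score import SmoothingFunction

def calc_bleu_smoothed(target_tensor, decoder_tensor):
    # Same as calc_bleu above, but with method1 smoothing applied.
    refs = [[output_lang.index2word[int(t)] for t in target_tensor.cpu().numpy()]]
    hyps = [output_lang.index2word[int(t)] for t in decoder_tensor.cpu().numpy()]
    return 100 * bleu_score.sentence_bleu(refs, hyps,
                                          smoothing_function=SmoothingFunction().method1)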
print_every=1000
plot_every=200

print_loss_total = 0
print_bleu_total = 0
plot_loss_total = 0
plot_bleu_total = 0
loss_history = []
bleu_history = []
for iter in range(1, n_iters + 1):
    training_tensor_pair = training_tensor_pairs[iter-1]
    input_tensor = training_tensor_pair[0]
    target_tensor = training_tensor_pair[1]

    loss, decoder_outputs = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
    decoder_tensor = torch.tensor([i.topk(1)[1] for i in decoder_outputs])
    target_tensor = target_tensor.view(-1,target_tensor.shape[0])[0]
    
    bleu = calc_bleu(target_tensor, decoder_tensor)
    
    print_loss_total += loss
    print_bleu_total += bleu
    
    plot_loss_total += loss
    plot_bleu_total += bleu
    
    if iter % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_bleu_avg = print_bleu_total / print_every
        print_loss_total = 0
        print_bleu_total = 0
        print('iter=%6d progress=%3d%% : loss= %.4f, bleu= %3f' % (iter, iter / n_iters * 100, print_loss_avg, print_bleu_avg))
        
    if iter % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_bleu_avg = plot_bleu_total / plot_every
        loss_history.append(plot_loss_avg)
        bleu_history.append(plot_bleu_avg)
        plot_loss_total = 0
        plot_bleu_total = 0

iter= 1000 progress= 2% : loss= 4.3269, bleu= 1.871565
iter= 2000 progress= 4% : loss= 3.6614, bleu= 4.311060
iter= 3000 progress= 6% : loss= 3.4342, bleu= 5.337739
iter= 4000 progress= 8% : loss= 3.2607, bleu= 7.349442
iter= 5000 progress= 10% : loss= 3.1763, bleu= 8.106195
iter= 6000 progress= 12% : loss= 3.0579, bleu= 8.004884
iter= 7000 progress= 14% : loss= 2.9973, bleu= 9.615947
iter= 8000 progress= 16% : loss= 2.8690, bleu= 10.995813
iter= 9000 progress= 18% : loss= 2.8181, bleu= 12.018116
iter= 10000 progress= 20% : loss= 2.7526, bleu= 12.019394
iter= 11000 progress= 22% : loss= 2.7032, bleu= 12.164422
iter= 12000 progress= 24% : loss= 2.5519, bleu= 14.942523
iter= 13000 progress= 26% : loss= 2.6454, bleu= 14.096671
iter= 14000 progress= 28% : loss= 2.5556, bleu= 14.323362
iter= 15000 progress= 30% : loss= 2.5182, bleu= 14.667089
iter= 16000 progress= 32% : loss= 2.4339, bleu= 15.752523
iter= 17000 progress= 34% : loss= 2.4273, bleu= 15.552970
iter= 18000 progress= 36% : loss= 2.3719, bleu= 17.833917
iter= 19000 progress= 38% : loss= 2.3144, bleu= 17.715252
iter= 20000 progress= 40% : loss= 2.3099, bleu= 18.025124
iter= 21000 progress= 42% : loss= 2.2757, bleu= 18.679683
iter= 22000 progress= 44% : loss= 2.2545, bleu= 19.556337
iter= 23000 progress= 46% : loss= 2.2239, bleu= 20.163076
iter= 24000 progress= 48% : loss= 2.2223, bleu= 18.929656
iter= 25000 progress= 50% : loss= 2.1114, bleu= 21.771670
iter= 26000 progress= 52% : loss= 2.1453, bleu= 19.853710
iter= 27000 progress= 54% : loss= 2.1209, bleu= 21.419971
iter= 28000 progress= 56% : loss= 2.0911, bleu= 21.514201
iter= 29000 progress= 57% : loss= 2.0828, bleu= 21.203588
iter= 30000 progress= 60% : loss= 2.1051, bleu= 21.578546
iter= 31000 progress= 62% : loss= 2.1040, bleu= 21.865763
iter= 32000 progress= 64% : loss= 2.0355, bleu= 21.806510
iter= 33000 progress= 66% : loss= 2.0578, bleu= 22.580487
iter= 34000 progress= 68% : loss= 2.0244, bleu= 22.418764
iter= 35000 progress= 70% : loss= 1.9706, bleu= 24.380544
iter= 36000 progress= 72% : loss= 1.9508, bleu= 25.039053
iter= 37000 progress= 74% : loss= 1.9714, bleu= 23.014639
iter= 38000 progress= 76% : loss= 1.9420, bleu= 24.265087
iter= 39000 progress= 78% : loss= 1.9473, bleu= 23.563979
iter= 40000 progress= 80% : loss= 1.9383, bleu= 23.739913
iter= 41000 progress= 82% : loss= 1.9311, bleu= 24.746721
iter= 42000 progress= 84% : loss= 1.9502, bleu= 23.907421
iter= 43000 progress= 86% : loss= 1.8771, bleu= 25.822816
iter= 44000 progress= 88% : loss= 1.8734, bleu= 25.263204
iter= 45000 progress= 90% : loss= 1.8831, bleu= 24.645873
iter= 46000 progress= 92% : loss= 1.8579, bleu= 25.760848
iter= 47000 progress= 94% : loss= 1.8771, bleu= 24.935368
iter= 48000 progress= 96% : loss= 1.8772, bleu= 24.196793
iter= 49000 progress= 98% : loss= 1.9398, bleu= 23.801817
iter= 50000 progress=100% : loss= 1.8306, bleu= 25.320334
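Not part of the original flow, but at this point the trained weights can be saved with torch.save so that the prediction step below can be rerun later without retraining (the file names are arbitrary):

torch.save(encoder.state_dict(), 'encoder_enja.pt')
torch.save(decoder.state_dict(), 'decoder_enja.pt')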

x_loss = np.array(range(len(loss_history)))
x_bleu = np.array(range(len(bleu_history)))
fig, (axL, axR) = plt.subplots(ncols=2, figsize=(16,4))

axL.plot(x_loss*plot_every, loss_history, linewidth=1)
axL.set_title('loss')
axL.set_xlabel('iter')
axL.set_ylabel('loss')
axL.grid(True)

axR.plot(x_bleu*plot_every, bleu_history, linewidth=1)
axR.set_title('BLEU')
axR.set_xlabel('iter')
axR.set_ylabel('bleu')
axR.grid(True)

fig.show()

[Figure: training loss (left) and BLEU score (right) plotted against iterations]


5. Predict

def tensorFromSentence(lang, sentence):
    indexes = [lang.word2index[word] for word in sentence.split(' ')]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
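Note that tensorFromSentence raises a KeyError for any word missing from the vocabulary built in section 2 (the vocabulary was built over both the training and test splits, so predictRandomly below is safe, but arbitrary input sentences are not). A minimal guard, assuming unknown words are simply dropped since there is no <unk> token, might look like:

def tensorFromSentenceSafe(lang, sentence):
    # Hypothetical variant: silently skip words that never appeared in the corpus.
    indexes = [lang.word2index[w] for w in sentence.split(' ') if w in lang.word2index]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)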
def predict(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                     encoder_hidden)
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

        decoder_hidden = encoder_hidden

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('</s>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]
def predictRandomly(encoder, decoder, n=10):
    for i in range(n):
        pair = random.choice(test_pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = predict(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')
predictRandomly(encoder, decoder)

> all the students study english .
= その 学生 たち は 全員 英語 を 勉強 し て い ま す 。
< 学生 の 勉強 は 英語 の 勉強 を する こと に し て い る 。 </s>

> please don 't be sad any more .
= これ 以上 悲し ま な い で 。
< もう これ 以上 は いけ な い 。 </s>

> my father gave a nice watch to me .
= 私 の 父 は 私 に 素敵 な 時計 を くれ た 。
< 父 は 私 に 時計 を くれ た 。 </s>

> translate this book into english .
= この 本 を 英語 に し なさ い 。
< この 本 を あなた は この 本 を 書 い た 。 </s>

> he is unmarried .
= 彼 は 結婚 し て な い で す 。
< 彼 は 頭 が い い 。 </s>

> are you satisfied with the result ?
= あなた は その 結果 に 満足 し て い ま す か 。
< あなた は その 結果 に 賛成 で す か 。 </s>

> you never know what will happen tomorrow .
= 明日 何 が 起こ る か なんて だれ も わか ら な い 。
< 明日 は 何 を する こと が わか ら な い 。 </s>

> my feet went to sleep and i could not stand up .
= 足 が しびれ て 立て な かっ た 。
< 私 の 計画 は 、 あまり 気 に は な い 。 </s>

> i invited jane to dinner .
= 私 は 夕食 に ジェーン を 招待 し た 。
< ジェーン と ジェーン は 私 に 会 っ た 。 </s>

> remember to post the letter .
= 忘れ ず に その 手紙 を 投函 し て くださ い 。
< 手紙 を 書 き なさ い 。 </s>
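The original tutorial also visualizes the attention weights that predict returns; a minimal sketch along those lines (the example sentence is taken from the corpus shown earlier):

import matplotlib.ticker as ticker

def show_attention(input_sentence, output_words, attentions):
    # Rows: generated Japanese tokens, columns: input English positions (padded to MAX_LENGTH).
    fig, ax = plt.subplots()
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)
    ax.set_xticklabels([''] + input_sentence.split(' ') + ['</s>'], rotation=90)
    ax.set_yticklabels([''] + output_words)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()

sentence = "i 'm in the tennis club ."
output_words, attentions = predict(encoder, decoder, sentence)
show_attention(sentence, output_words, attentions)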
 
A notebook version is also available:
https://gist.github.com/hightensan/23d791cefcfa1d12fe27fb1549ebd7eb
 
 
Please feel free to take a look. m(_ _)m