ipsumdominum / pytorch-simple-transformer
A simple transformer implementation without difficult syntax and extra bells and whistles.
The loss calculation is deceptive: it lets the model fit very quickly and produce what looks like very good translation. Unfortunately this is not real, and the model is not able to spell out the German sentence by itself. It can only complete an i-length German sentence if you give it the first (i-1) tokens, which means it cannot generate the whole sentence from a start-of-sentence tag.
It seems reasonable enough to use this loss to train the network, but unreasonable to use it to assess the network's translation ability, though I have yet to train this network to its full capacity.
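For comparison, the standard teacher-forcing objective computes the loss in a single forward pass over the shifted target sequence, so the model is never scored on a position whose token it has already seen; this only stays honest if the decoder applies a causal mask. A minimal sketch, assuming logits of shape (batch, seq, vocab) and a hypothetical padding index of 0:

import torch
import torch.nn as nn

german = item["german"]                       # (batch, seq) token ids
decoder_input = german[:, :-1]                # feed tokens 0 .. n-2
decoder_target = german[:, 1:]                # predict tokens 1 .. n-1

model.encode(item["english"][:, 1:-1])        # encode the source once
logits = model(decoder_input)                 # (batch, seq-1, vocab)

criterion = nn.CrossEntropyLoss(ignore_index=0)   # 0 assumed to be <pad>
loss = criterion(logits.reshape(-1, logits.size(-1)),
                 decoder_target.reshape(-1))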
###################
## Original leave-last-token-out decoder.
## Not sure what the exact error in this calculation is,
## but maybe the model sees the masked token directly?
##
# Output German, one token at a time.
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    # Feed the first i+1 ground-truth tokens; collect the model's output.
    out = model(item["german"][:, :i+1])
    all_outs = torch.cat((all_outs, out), dim=1)
# ###################
# My variation of the leave-last-token-out decoder, used at training.
# output_vocab_size = german_vocab_len
g = item["german"].shape
x = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    xx = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)
    # Predict from the tokens revealed so far (positions < i).
    out = model(x)
    # Only after predicting, reveal the ground-truth token at position i.
    xx[:, i:i+1] = item["german"][:, i:i+1]
    x = x + xx
    all_outs = torch.cat((all_outs, out), dim=1)
# ###################
# My variation of a beam-search decoder (beam width 1, i.e. greedy argmax).
model.encode(item["english"][:, 1:-1])
g = item["german"].shape
x = torch.zeros([g[0], g[1]], dtype=torch.long).to(device)
all_outs = torch.tensor([], requires_grad=True).to(device)
for i in range(item["german"].shape[1] - 1):
    out = model(x)
    # Feed the model's own best guess back in at position i.
    x[:, i:i+1] = out.argmax(axis=-1)
    all_outs = torch.cat((all_outs, out), dim=1)
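The loops above always decode to a fixed length taken from the ground-truth German tensor. A real inference-time decoder should instead start from the start-of-sentence tag and stop at end-of-sentence. A sketch, assuming hypothetical SOS_IDX / EOS_IDX ids and a model that returns (batch, seq, vocab) logits:

import torch

SOS_IDX, EOS_IDX = 1, 2     # hypothetical; the real ids depend on the vocab
MAX_LEN = 50

model.encode(item["english"][:, 1:-1])
x = torch.full((1, 1), SOS_IDX, dtype=torch.long).to(device)
for _ in range(MAX_LEN):
    out = model(x)                            # (1, len, vocab) logits
    next_token = out[:, -1:].argmax(dim=-1)   # greedy pick at the last position
    x = torch.cat((x, next_token), dim=1)
    if next_token.item() == EOS_IDX:
        break
# x now holds <sos> ... <eos> (or MAX_LEN tokens if <eos> never appears).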
I found this glitch while fiddling with the attention layer at its core: zeroing out the attention weights did no harm to the performance of a last-token-only model. In sub_layers.py:

attention_weights = F.softmax(attention_weights, dim=2)
attention_weights = attention_weights * 0.  ## Try this!
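A plausible explanation for why this goes unnoticed (my assumption, based on the usual transformer wiring rather than on this repo's code): the attention sub-layer sits inside a residual connection, so zeroed weights make its output zero and the layer collapses to roughly the identity, which is all a last-token-only objective needs. Illustrative sketch:

import torch
import torch.nn as nn

# Names here are illustrative, not taken from sub_layers.py.
class ResidualAttentionSketch(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        # With the attention weights forced to zero, attn_out is zero and
        # this reduces to norm(x): the residual path alone still carries
        # each token's own representation forward.
        return self.norm(x + attn_out)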
Dear devs,
I find this repo simple and smooth to run. However, I am confused about why you used the same embedding for input and output. Specifically, in def forward and in def encode, both use the same reference to self.embedding. This looks weird, doesn't it? Shouldn't the source language use a different embedding from the destination language?
class TransformerTranslator(nn.Module):
    def __init__(self, embed_dim, num_blocks, num_heads, vocab_size, CUDA=False):
        super(TransformerTranslator, self).__init__()
        self.embedding = Embeddings(vocab_size, embed_dim, CUDA=CUDA)
        self.encoder = Encoder(embed_dim, num_heads, num_blocks, CUDA=CUDA)
        self.decoder = Decoder(embed_dim, num_heads, num_blocks, vocab_size, CUDA=CUDA)
        self.encoded = False
        self.device = torch.device('cuda:0' if CUDA else 'cpu')

    def encode(self, input_sequence):
        # Source and target share self.embedding (the point of this issue).
        embedding = self.embedding(input_sequence).to(self.device)
        self.encode_out = self.encoder(embedding)
        self.encoded = True

    def forward(self, output_sequence):
        if (self.encoded == False):
            print("ERROR::TransformerTranslator:: MUST ENCODE FIRST.")
            return output_sequence
        else:
            embedding = self.embedding(output_sequence)
            return self.decoder(self.encode_out, embedding)
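For what it's worth, sharing one embedding is a deliberate choice in the original Transformer paper, but only because source and target there share a joint vocabulary. With separate English and German vocabularies, one might split it like this (a sketch reusing the repo's Embeddings / Encoder / Decoder signatures, with hypothetical src_vocab_size / tgt_vocab_size arguments):

# Hypothetical variant with separate source/target embeddings.
class TransformerTranslatorSplit(nn.Module):
    def __init__(self, embed_dim, num_blocks, num_heads,
                 src_vocab_size, tgt_vocab_size, CUDA=False):
        super().__init__()
        self.src_embedding = Embeddings(src_vocab_size, embed_dim, CUDA=CUDA)
        self.tgt_embedding = Embeddings(tgt_vocab_size, embed_dim, CUDA=CUDA)
        self.encoder = Encoder(embed_dim, num_heads, num_blocks, CUDA=CUDA)
        self.decoder = Decoder(embed_dim, num_heads, num_blocks, tgt_vocab_size, CUDA=CUDA)
        self.encoded = False
        self.device = torch.device('cuda:0' if CUDA else 'cpu')

    def encode(self, input_sequence):
        # English goes through the source embedding ...
        embedding = self.src_embedding(input_sequence).to(self.device)
        self.encode_out = self.encoder(embedding)
        self.encoded = True

    def forward(self, output_sequence):
        # ... and German through its own target embedding.
        assert self.encoded, "must call encode() first"
        return self.decoder(self.encode_out, self.tgt_embedding(output_sequence))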
Setting LOAD=50 does not reproduce the old loss from checkpoint.pkl.
Without a decoding method one cannot actually use the trained network to translate... Greedy decoding requires a very good network.
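Beam search is the usual remedy when greedy decoding is too brittle: it keeps the k highest-scoring prefixes instead of committing to the argmax at every step. A generic sketch, assuming a hypothetical step_log_probs(prefix) helper that returns per-token log-probabilities for the next position (for this repo it would wrap model.encode / model(...)):

import torch

def beam_search(step_log_probs, sos_idx, eos_idx, beam_width=5, max_len=50):
    # Each beam entry is (token_list, cumulative_log_prob).
    beams = [([sos_idx], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            log_probs = step_log_probs(tokens)        # shape: (vocab,)
            top_lp, top_ix = log_probs.topk(beam_width)
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((tokens + [ix], score + lp))
        # Keep the beam_width best prefixes overall.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == eos_idx:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)
    return max(finished, key=lambda c: c[1])[0]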