packing-unpacking-pytorch-minimal-tutorial's Introduction

Minimal tutorial on packing and unpacking sequences in pytorch.

This is a fork of @Tushar-N's gist. I have added comments and extra diagrams that should (hopefully) make it easier to understand. The forked gist is here, but I prefer Markdown for tutorials! :)

We want to run an LSTM on a batch of 3 character sequences: ['long_str', 'tiny', 'medium']. Here are the steps. You will probably be interested only in the steps marked with *.

  • Step 1: Construct Vocabulary
  • Step 2: Load indexed data (list of instances, where each instance is list of character indices)
  • Step 3: Make Model
  • Step 4: * Pad instances with 0s till max length sequence
  • Step 5: * Sort instances by sequence length in descending order
  • Step 6: * Embed the instances
  • Step 7: * Call pack_padded_sequence with embedded instances and sequence lengths
  • Step 8: * Forward with LSTM
  • Step 9: * Call pad_packed_sequence if required / or just pick last hidden vector
  • * Summary of Shape Transformations

We want to run an LSTM on a batch of the following 3 character sequences:

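import torch
from torch import LongTensor
from torch.autograd import Variable  # deprecated since PyTorch 0.4, but still works
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6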

Step 1: Construct Vocabulary

# make sure <pad> idx is 0
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))
# => ['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

Step 2: Load indexed data (list of instances, where each instance is list of character indices)

vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
# vectorized_seqs => [[6, 9, 8, 4, 1, 11, 12, 10],
#                     [12, 5, 8, 14],
#                     [7, 3, 2, 5, 13, 7]]
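A side note on the lookup: `vocab.index` rescans the list for every character, which is fine for 15 symbols but slow for a real vocabulary. A minimal sketch of the usual dictionary-based alternative (the name `char_to_idx` is my own, not from the original gist):

```python
seqs = ['long_str', 'tiny', 'medium']
vocab = ['<pad>'] + sorted(set(char for seq in seqs for char in seq))

# Hypothetical helper: maps each character to its index in vocab.
char_to_idx = {char: idx for idx, char in enumerate(vocab)}

vectorized_seqs = [[char_to_idx[char] for char in seq] for seq in seqs]
print(vectorized_seqs)
```

It produces the same indices as the list-based version above.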

Step 3: Make Model

embed = Embedding(len(vocab), 4) # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True) # input_dim = 4, hidden_dim = 5
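The `Embedding` above treats index 0 like any other token, so the `<pad>` rows you will see in Step 6 are non-zero. That is harmless here because packing drops the pad positions before the LSTM sees them. If you would rather keep the pad vector at zero anyway, `Embedding` accepts a `padding_idx` argument. A small sketch, assuming the 15-entry vocabulary from Step 1:

```python
import torch
from torch.nn import Embedding

# 15 = len(vocab) in this tutorial; padding_idx=0 pins the <pad> row to zeros.
embed = Embedding(15, 4, padding_idx=0)

pad_vector = embed(torch.tensor([0]))
print(pad_vector)

# The <pad> row also receives no gradient, so it stays at zero during training.
embed(torch.tensor([0, 3])).sum().backward()
print(embed.weight.grad[0])
```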

Step 4: Pad instances with 0s till max length sequence

# get the length of each seq in your batch
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
# seq_lengths => [ 8, 4,  6]
# batch_sum_seq_len: 8 + 4 + 6 = 18
# max_seq_len: 8

seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
# seq_tensor => [[0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]
#                [0 0 0 0 0 0 0 0]]

for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)
# seq_tensor => [[ 6  9  8  4  1 11 12 10]          # long_str
#                [12  5  8 14  0  0  0  0]          # tiny
#                [ 7  3  2  5 13  7  0  0]]         # medium
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
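The manual loop above makes the padding explicit, which is the point of this tutorial. For reference, PyTorch also ships `pad_sequence` in `torch.nn.utils.rnn`, which does the same thing in one call. A minimal sketch, using the indices from Step 2:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

vectorized_seqs = [[6, 9, 8, 4, 1, 11, 12, 10],
                   [12, 5, 8, 14],
                   [7, 3, 2, 5, 13, 7]]

# pad_sequence takes a list of tensors and fills the short ones with padding_value.
padded = pad_sequence([torch.tensor(seq) for seq in vectorized_seqs],
                      batch_first=True, padding_value=0)
print(padded)
print(padded.shape)
```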

Step 5: Sort instances by sequence length in descending order

seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
# seq_tensor => [[ 6  9  8  4  1 11 12 10]           # long_str
#                [ 7  3  2  5 13  7  0  0]           # medium
#                [12  5  8 14  0  0  0  0]]          # tiny
# seq_tensor.shape : (batch_size X max_seq_len) = (3 X 8)
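Sorting changes the order of the batch, so you usually need a way back to the original order afterwards. Below is a sketch of two options: keeping the inverse permutation yourself, or (in more recent PyTorch releases) passing `enforce_sorted=False` so that `pack_padded_sequence` sorts internally. The random `embedded` tensor is only a stand-in for the real embedded batch, and `unperm_idx` is a name I made up.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

seq_lengths = torch.tensor([8, 4, 6])   # original order: long_str, tiny, medium
embedded = torch.randn(3, 8, 4)         # stand-in for the embedded, padded batch

# Route 1 (this tutorial): sort by hand and keep the inverse permutation.
sorted_lengths, perm_idx = seq_lengths.sort(0, descending=True)
unperm_idx = perm_idx.argsort()         # undoes the sort
restored = embedded[perm_idx][unperm_idx]

# Route 2: let pack_padded_sequence do the sorting internally.
packed = pack_padded_sequence(embedded, seq_lengths, batch_first=True,
                              enforce_sorted=False)

print(torch.equal(restored, embedded))
print(packed.batch_sizes.tolist())
print(packed.sorted_indices.tolist())
```

With route 2, the `PackedSequence` remembers the permutation, and `pad_packed_sequence` returns the batch in its original order.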

Step 6: Embed the instances

embedded_seq_tensor = embed(seq_tensor)
# embedded_seq_tensor =>
#                       [[[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                         [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                         [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                         [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                         [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r

#                        [[ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                         [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                         [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]    <pad>

#                        [[ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                         [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                         [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                         [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]     <pad>
#                         [ 0.2691206  -0.43435425  0.87935454 -2.2269666 ]]]   <pad>
# embedded_seq_tensor.shape : (batch_size X max_seq_len X embedding_dim) = (3 X 8 X 4)

Step 7: Call pack_padded_sequence with embedded instances and sequence lengths

packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
# packed_input (PackedSequence is a NamedTuple; the two attributes of interest here are data and batch_sizes)
# packed_input.data =>
#                         [[-0.77578706 -1.8080667  -1.1168439   1.1059115 ]     l
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.23622951  2.0361056   0.15435742 -0.04513785]     o
#                          [-0.7303265  -0.857339    0.58913064 -1.1068314 ]     e
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.48159844 -1.4886451   0.92639893  0.76906884]     d
#                          [-0.6000342   1.1732816   0.19938554 -1.5976517 ]     n
#                          [ 0.40524676  0.98665565 -0.08621677 -1.1728264 ]     g
#                          [ 0.27616557 -1.224429   -1.342848   -0.7495876 ]     i
#                          [-1.284392    0.68294704  1.4064184  -0.42879772]     y
#                          [-1.6334635  -0.6100042   1.7509955  -1.931793  ]     _
#                          [ 0.01795524 -0.59048957 -0.53800726 -0.6611691 ]     u
#                          [-0.6470658  -0.6266589  -1.7463604   1.2675372 ]     s
#                          [ 0.16031227 -0.08209462 -0.16297023  0.48121014]     m
#                          [ 0.64004815  0.45813003  0.3476034  -0.03451729]     t
#                          [-0.22739866 -0.45782727 -0.6643252   0.25129375]]    r
# packed_input.data.shape : (batch_sum_seq_len X embedding_dim) = (18 X 4)
# packed_input.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1]
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])
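To see where `batch_sizes` and the row order of the packed data come from, here is a pure-Python sketch that reproduces both from the sorted strings. It is only an illustration of the layout, not what PyTorch runs internally.

```python
seqs = ['long_str', 'medium', 'tiny']   # already sorted by length, descending
max_len = len(seqs[0])

# batch_sizes[t] = number of sequences that still have a character at time step t
batch_sizes = [sum(len(seq) > t for seq in seqs) for t in range(max_len)]

# Packed order: go time step by time step, taking only the sequences still active.
packed_order = [seq[t] for t in range(max_len) for seq in seqs if len(seq) > t]

print(batch_sizes)            # [3, 3, 3, 3, 2, 2, 1, 1]
print(''.join(packed_order))  # lmtoeindngiy_usmtr
```

In other words, the visualization above is read column by column, not row by row.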

Step 8: Forward with LSTM

packed_output, (ht, ct) = lstm(packed_input)
# packed_output (PackedSequence is a NamedTuple; the two attributes of interest here are data and batch_sizes)
# packed_output.data =>
#                          [[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                           [ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                           [ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                           [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                           [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                           [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                           [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                           [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                           [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                           [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                           [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                           [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                           [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                           [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                           [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                           [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                           [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                           [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r
# packed_output.data.shape : (batch_sum_seq_len X hidden_dim) = (18 X 5)

# packed_output.batch_sizes => [ 3,  3,  3,  3,  2,  2,  1,  1] (same as packed_input.batch_sizes)
# visualization :
# l  o  n  g  _  s  t  r   #(long_str)
# m  e  d  i  u  m         #(medium)
# t  i  n  y               #(tiny)
# 3  3  3  3  2  2  1  1   (sum = 18 [batch_sum_seq_len])
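Why bother with packing at all? A rough sketch of the difference: if you feed the padded tensor straight into the LSTM, it keeps stepping through the `<pad>` positions, so the final hidden state of the shorter sequences is computed from padding. The random `embedded` tensor below is an assumption standing in for `embedded_seq_tensor`, pad rows included.

```python
import torch
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)
seq_lengths = torch.tensor([8, 6, 4])   # long_str, medium, tiny (sorted)
embedded = torch.randn(3, 8, 4)         # stand-in for embedded_seq_tensor

# Packed: the LSTM stops at each sequence's true length.
_, (ht_packed, _) = lstm(pack_padded_sequence(embedded, seq_lengths, batch_first=True))

# Padded: the LSTM also steps through the <pad> positions.
_, (ht_padded, _) = lstm(embedded)

# The longest sequence has no padding, so both routes agree for it.
same_for_longest = torch.allclose(ht_packed[-1, 0], ht_padded[-1, 0], atol=1e-5)
# The shortest sequence has 4 pad steps, which change its final hidden state.
same_for_shortest = torch.allclose(ht_packed[-1, 2], ht_padded[-1, 2], atol=1e-5)
print(same_for_longest, same_for_shortest)
```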

Step 9: Call pad_packed_sequence if required / or just pick last hidden vector

# unpack your output if required
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# output =>
#                          [[[-0.00947162  0.07743231  0.20343193  0.29611713  0.07992904]   l
#                            [ 0.20994528  0.17932937  0.17748171  0.05025435  0.15717036]   o
#                            [ 0.09527554  0.14521319  0.1923058  -0.05925677  0.18633027]   n
#                            [ 0.14620507  0.07822411  0.2849248  -0.22616537  0.15480657]   g
#                            [ 0.01368293  0.15872964  0.03759198 -0.13403234  0.23890573]   _
#                            [ 0.00737647  0.17101538  0.28344846  0.18878219  0.20339936]   s
#                            [ 0.17885767  0.12713005  0.28287745  0.05562563  0.10871304]   t
#                            [ 0.09486895  0.12772645  0.34048414  0.25930756  0.12044918]]  r

#                           [[ 0.08596145  0.09205993  0.20892891  0.21788561  0.00624391]   m
#                            [ 0.01364102  0.11060348  0.14704391  0.24145307  0.12879576]   e
#                            [ 0.09872741  0.13324396  0.19446367  0.4307988  -0.05149471]   d
#                            [ 0.00884941  0.05762182  0.30557525  0.373712    0.08834908]   i
#                            [ 0.00377969  0.05943518  0.2961751   0.35107893  0.15148178]   u
#                            [ 0.0864429   0.11173367  0.3158251   0.37537992  0.11876849]   m
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]  <pad>

#                           [[ 0.16861682  0.07807446  0.18812777 -0.01148055 -0.01091915]   t
#                            [ 0.02610307  0.00965587  0.31438383  0.246354    0.08276576]   i
#                            [ 0.03895474  0.08449443  0.18839942  0.02205326  0.23149511]   n
#                            [ 0.12460691  0.21189159  0.04823487  0.06384943  0.28563985]   y
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]   <pad>
#                            [ 0.          0.          0.          0.          0.        ]]] <pad>
# output.shape : ( batch_size X max_seq_len X hidden_dim) = (3 X 8 X 5)

# Or, if you just want the final hidden state of each sequence, take it from ht:
# ht[-1] : (batch_size X hidden_dim) = (3 X 5)
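For the "just pick the last hidden vector" route, `ht` from Step 8 already holds what you need: with packed input, `ht[-1]` is the hidden state at the last real (non-pad) time step of each sequence. A small sketch that checks this against the unpacked output; the random `embedded` tensor is a stand-in and `last_outputs` is my own name:

```python
import torch
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)
seq_lengths = torch.tensor([8, 6, 4])   # sorted, as after Step 5
embedded = torch.randn(3, 8, 4)         # stand-in for embedded_seq_tensor

packed_input = pack_padded_sequence(embedded, seq_lengths, batch_first=True)
packed_output, (ht, ct) = lstm(packed_input)
output, _ = pad_packed_sequence(packed_output, batch_first=True)

# Pick the last valid time step of each sequence out of the padded output.
last_outputs = output[torch.arange(len(seq_lengths)), seq_lengths - 1]

print(ht[-1].shape)                          # torch.Size([3, 5])
print(torch.allclose(ht[-1], last_outputs))  # True
```

Note that the rows of `ht[-1]` follow the sorted batch order (long_str, medium, tiny), not the original one.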

Summary of Shape Transformations

# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)
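The shape transformations above can be checked with a condensed run of the whole pipeline. This is a sketch rather than the original gist: it uses `torch.tensor` instead of `LongTensor` and `Variable`, and the `shapes` dictionary is just my way of collecting the results.

```python
import torch
from torch.nn import Embedding, LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

seqs = ['long_str', 'tiny', 'medium']
vocab = ['<pad>'] + sorted(set(char for seq in seqs for char in seq))
vectorized_seqs = [[vocab.index(char) for char in seq] for seq in seqs]

embed = Embedding(len(vocab), 4)
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)

# Pad with 0s up to the longest sequence.
seq_lengths = torch.tensor([len(seq) for seq in vectorized_seqs])
seq_tensor = torch.zeros(len(vectorized_seqs), int(seq_lengths.max()), dtype=torch.long)
for idx, seq in enumerate(vectorized_seqs):
    seq_tensor[idx, :len(seq)] = torch.tensor(seq)

# Sort by length (descending), embed, pack, run the LSTM, unpack.
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
embedded = embed(seq_tensor)
packed_input = pack_padded_sequence(embedded, seq_lengths, batch_first=True)
packed_output, (ht, ct) = lstm(packed_input)
output, _ = pad_packed_sequence(packed_output, batch_first=True)

shapes = {
    'embedded': tuple(embedded.shape),                   # (3, 8, 4)
    'packed_input.data': tuple(packed_input.data.shape),    # (18, 4)
    'packed_output.data': tuple(packed_output.data.shape),  # (18, 5)
    'output': tuple(output.shape),                       # (3, 8, 5)
}
for name, shape in shapes.items():
    print(name, shape)
```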

packing-unpacking-pytorch-minimal-tutorial's Issues

Mistake in the packed output explanation

This is the best explanation I have come across till now and I believe it is just a step away from being perfect. There is a mistake in the explanation. The order of outputs in packed_output format should be l,m,t,o,e,i,n,d,n,g,i,y,_,u,s,m,t,r instead of l,o,n,g,_,s,t,r,m,e,d,i,u,m,t,i,n,y. This can be verified by looking at the corresponding outputs in unpacked output.

Dimension mistake in step 5 comment

Hi, thanks for your writing.
sequence tensor shape after sorting by the sequence lengths is [batch_size, max_seq_len] => [3, 8]
but in the comment it is [3, 15]

dimension mistake in comment

# : (batch_sum_seq_len X hidden_dim) = (33 X 5)

I think it should be 18 X 5 instead of 33 X 5, please confirm it.
