Transition-based dependency parsing

In this example, we implement a simplified version of the transition-based dependency parser described in the paper by Kiperwasser and Goldberg (2016).

In [1]:
import torch
from torch import nn
import time
import torchtext
import numpy as np

import random
import sys

from collections import defaultdict, Counter

import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina' 
# In matplotlib 3.6 and later, this style has been renamed to 'seaborn-v0_8'.
plt.style.use('seaborn')

Reading the data

Let's first discuss the data and the way that it's formatted. We will use treebanks from the Universal Dependencies project.

We use the training and development sections of the English dataset (the English Web Treebank, EWT), which can be downloaded from the UD repository. The files use the CoNLL-U format, which is documented on the UD website. Here is an example of a sentence in this format.

1       Now     now     ADV     RB      _                                 4       advmod  4:advmod
2       ,       ,       PUNCT   ,       _                                 4       punct   4:punct
3       people  people  NOUN    NNS     Number=Plur                       4       nsubj   4:nsubj
4       wonder  wonder  VERB    VBP     Mood=Ind|Tense=Pres|VerbForm=Fin  0       root    0:root
5       if      if      SCONJ   IN      _                                 9       mark    9:mark
6       Google  Google  PROPN   NNP     Number=Sing                       9       nsubj   9:nsubj
7       can     can     AUX     MD      VerbForm=Fin                      9       aux     9:aux
8       even    even    ADV     RB      _                                 9       advmod  9:advmod
9       survive survive VERB    VB      VerbForm=Inf                      4       ccomp   4:ccomp
10      .       .       PUNCT   .       _                                 4       punct   4:punct

To create torchtext Example objects, we extract the columns containing the word forms, the part-of-speech tags, the head positions and the edge labels (columns 2, 5, 7 and 8, respectively; column 5 holds the language-specific tags such as RB and NNS). The first two are used as inputs, and the last two are the outputs predicted by the parser.

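This code is almost identical to the code used last week, with one crucial difference: after reading each tree, we call the static oracle (defined below) to compute the correct sequence of parsing actions. If the oracle fails because the tree is ill-formed or non-projective, the sentence is dropped from the training set; in the validation set it is kept with an empty action sequence, so that it is still parsed and evaluated.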

In [2]:
def read_data(corpus_file, datafields, transition_system, is_validation):
    with open(corpus_file, encoding='utf-8') as f:
        examples = []
        words = []
        postags = []
        heads = [-1]
        labels = ['<none>']
        n_failed = 0
        for line in f:
            if line[0] == '#': # Skip comments.
                continue    
            line = line.strip()
            if not line:
                # Blank line for the end of a sentence.
                labels.append('<none>')
                try:
                    actions, action_labels = transition_system.static_oracle(heads, labels)
                    examples.append(torchtext.data.Example.fromlist([words, postags, heads, labels, 
                                                                     actions, action_labels], datafields))
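                except Exception:
                    # The oracle failed (invalid or non-projective tree): drop the sentence from the
                    # training data, but keep it with empty action sequences for validation.
                    if is_validation: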
                        examples.append(torchtext.data.Example.fromlist([words, postags, heads, labels,
                                                                         [], []], datafields))
                    n_failed += 1
                words = []
                postags = []
                heads = [-1]
                labels = ['<none>']
            else:
                columns = line.split('\t')
                # Skip dummy tokens used in ellipsis constructions, and multiword tokens.
                if '.' in columns[0] or '-' in columns[0]:
                    continue
                words.append(columns[1])
                postags.append(columns[4])
                heads.append(int(columns[6]))
                labels.append(columns[7])
        print(f'Read {len(examples)} sentences, {n_failed} ill-formed.')
        return torchtext.data.Dataset(examples, datafields)
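To try the column selection of read_data in isolation, here is a minimal sketch that turns the CoNLL-U lines of one sentence into the four lists used by the parser. The helper name conllu_columns is hypothetical and not part of the notebook; it follows read_data's choices (skip comments, multiword tokens such as `3-4` and empty nodes such as `8.1`, and put a dummy root with head -1 at position 0), but it leaves out torchtext, the oracle, and the trailing `'<none>'` label that read_data appends for the end-of-sentence token.

```python
def conllu_columns(lines):
    """Extract word forms, XPOS tags, heads and labels from one CoNLL-U sentence.

    Hypothetical helper mirroring read_data: position 0 is a dummy root whose
    head is -1 and whose label is '<none>'.
    """
    words, postags = [], []
    heads, labels = [-1], ['<none>']
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue  # blank line or comment
        columns = line.split('\t')
        # Skip multiword tokens ("3-4") and empty nodes ("8.1").
        if '.' in columns[0] or '-' in columns[0]:
            continue
        words.append(columns[1])       # column 2: word form
        postags.append(columns[4])     # column 5: language-specific tag (XPOS)
        heads.append(int(columns[6]))  # column 7: head position
        labels.append(columns[7])      # column 8: dependency label
    return words, postags, heads, labels


rows = [
    ['1', 'Now', 'now', 'ADV', 'RB', '_', '4', 'advmod', '4:advmod', '_'],
    ['2', ',', ',', 'PUNCT', ',', '_', '4', 'punct', '4:punct', '_'],
    ['3', 'people', 'people', 'NOUN', 'NNS', 'Number=Plur', '4', 'nsubj', '4:nsubj', '_'],
    ['4', 'wonder', 'wonder', 'VERB', 'VBP', 'Mood=Ind|Tense=Pres|VerbForm=Fin', '0', 'root', '0:root', '_'],
]
lines = ['# text = Now, people wonder'] + ['\t'.join(r) for r in rows]
words, postags, heads, labels = conllu_columns(lines)
print(words)  # ['Now', ',', 'people', 'wonder']
print(heads)  # [-1, 4, 4, 4, 0]
```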

Defining the transition system

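The following class defines the transition system that the parser will use, the arc-hybrid system. It has three transitions: shift (sh) moves the first buffer token onto the stack, left-arc (la) makes the first buffer token the head of the stack top and pops the stack, and right-arc (ra) makes the second stack item the head of the stack top and pops the stack; an additional end action closes the sequence. The most important piece here is the static oracle, which computes the correct sequence of actions that generates a tree from the training set. The configurations for a batch of sentences (that is, the stacks, buffers, and generated edges) are tracked by a separate class, ArcHybridState, defined further below.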
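A plain-Python sketch of the shift, left-arc and right-arc transitions for a single sentence may help before reading the batched code. The function apply_action and the dictionary-based configuration are hypothetical illustrations, not part of the notebook. Replaying the oracle output for "She lives in Gothenburg" (the action sequence printed for the oracle example further down) rebuilds the gold heads.

```python
def apply_action(config, action, label=None):
    """Apply one arc-hybrid transition to a single-sentence configuration, in place."""
    stack, heads, labels = config['stack'], config['heads'], config['labels']
    if action == 'sh':
        # Shift: move the first buffer token onto the stack.
        stack.append(config['buf_pos'])
        config['buf_pos'] += 1
    elif action == 'la':
        # Left-arc: the first buffer token becomes the head of the stack top, which is popped.
        dep = stack.pop()
        heads[dep] = config['buf_pos']
        labels[dep] = label
    elif action == 'ra':
        # Right-arc: the item below the stack top becomes its head; the top is popped.
        dep = stack.pop()
        heads[dep] = stack[-1]
        labels[dep] = label
    # Any other action ('end') leaves the configuration unchanged.


actions = ['sh', 'la', 'sh', 'sh', 'la', 'sh', 'ra', 'ra', 'end']
action_labels = ['<none>', 'subj', '<none>', '<none>', 'prep', '<none>', 'loc', 'root', '<none>']
n_tokens = 5  # dummy root + "She lives in Gothenburg"
config = {'stack': [0], 'buf_pos': 1, 'heads': [-1] * n_tokens, 'labels': [None] * n_tokens}
for a, l in zip(actions, action_labels):
    apply_action(config, a, l)
print(config['heads'])   # [-1, 2, 0, 4, 2]
print(config['labels'])  # [None, 'subj', 'root', 'prep', 'loc']
```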

In [3]:
class ArcHybridSystem:
    
    def __init__(self, n_stack_features=3):
        self.n_stack_features = n_stack_features
    
    def set_vocab(self, action_voc):
        self.la_enc = action_voc.stoi['la']
        self.ra_enc = action_voc.stoi['ra']
        self.sh_enc = action_voc.stoi['sh']
    
    def extract_action_features(self, actions):
        # This method finds the sentence positions ("features") considered while carrying
        # out the given sequence of actions.
        # At each step, the Kiperwasser and Goldberg parser considers 3 tokens in the stack and 1 in the buffer.
        action_features = []
        stack = [0]*self.n_stack_features
        buf_pos = 1
        for a in actions:
            action_features.append(stack[-self.n_stack_features:] + [buf_pos])
            if a == self.la_enc:
                stack.pop()
            elif a == self.ra_enc:
                stack.pop()
            elif a == self.sh_enc:
                stack.append(buf_pos)
                buf_pos += 1
            else:
                # dummy or end action
                pass
        return action_features
            
    def extract_action_features_batch(self, batch):
        return [self.extract_action_features(actions) for actions in batch]
        
    def static_oracle(self, gold_heads, gold_labels):
        # This method finds the sequence of actions required to generate a given tree.
        # It will return the sequence of actions, and a corresponding list of edge labels.
        # This method will throw an exception if the input is not a valid tree, or if the
        # tree is *non-projective*: that is, if it can't be drawn without crossing edges.
        
        n_tokens = len(gold_heads)
        
        stack = [0]
        buf_pos = 1
        
        gold_edges = defaultdict(list)
        for i, h in enumerate(gold_heads):
            if i > 0:  # skip the dummy root at position 0, whose head is -1
                gold_edges[h].append(i)
    
        edges = defaultdict(list)
        
        actions = []
        action_labels = []
        
        while True:
            top = stack[-1]
            top_is_finished = len(edges[top]) == len(gold_edges[top])

            if buf_pos == n_tokens and len(stack) == 1:
                actions.append('end')
                action_labels.append('<none>')
                break
            
            elif top_is_finished and top in gold_edges[buf_pos]:
                actions.append('la')
                action_labels.append(gold_labels[top])
                edges[buf_pos].append(top)
                stack.pop()
            elif top_is_finished and len(stack) > 1 and top in gold_edges[stack[-2]]:
                actions.append('ra')
                action_labels.append(gold_labels[top])
                edges[stack[-2]].append(top)
                stack.pop()
            elif buf_pos < n_tokens:
                actions.append('sh')
                action_labels.append('<none>')
                stack.append(buf_pos)
                buf_pos += 1                
            else:
                # ill-formed graph
                # non-tree or non-projective
                raise Exception('ill-formed tree')

        # sanity-check
        heads = [-1]*(n_tokens)
        for h, deps in edges.items():
            for d in deps:
                heads[d] = h

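        if gold_heads != heads:
            raise Exception('the oracle action sequence does not reproduce the gold tree')
            
        if len(actions) != len(action_labels):
            raise Exception('the numbers of actions and action labels differ')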
            
        return actions, action_labels
    
    def init_parsing(self, words):
        return ArcHybridState(self, words)    
    

Here is a small example of how we can call the static oracle to get the action sequence for a tree.

In [4]:
# For instance, "She lives in Gothenburg"

example_edges = [-1, 2, 0, 4, 2]
example_labels = [None, 'subj', 'root', 'prep', 'loc']

system = ArcHybridSystem()

system.static_oracle(example_edges, example_labels)
Out[4]:
(['sh', 'la', 'sh', 'sh', 'la', 'sh', 'ra', 'ra', 'end'],
 ['<none>',
  'subj',
  '<none>',
  '<none>',
  'prep',
  '<none>',
  'loc',
  'root',
  '<none>'])
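The oracle raises an exception when the tree is non-projective, because the arc-hybrid system can only build trees whose edges can be drawn above the sentence without crossing; read_data counts such sentences as ill-formed. A minimal sketch of a projectivity test, using the same head-list encoding as the oracle (the name is_projective is hypothetical, not part of the notebook), compares every pair of edges, including the edge from the dummy root at position 0:

```python
from itertools import combinations


def is_projective(heads):
    """True if no two edges cross; heads[0] == -1 marks the dummy root."""
    # Each edge as a (left, right) span over sentence positions.
    arcs = [(min(h, dep), max(h, dep)) for dep, h in enumerate(heads) if dep > 0]
    for (a, b), (c, d) in combinations(arcs, 2):
        # Two edges cross if one starts strictly inside the other and ends strictly outside it.
        if a < c < b < d or c < a < d < b:
            return False
    return True


print(is_projective([-1, 2, 0, 4, 2]))  # True: the "She lives in Gothenburg" tree
print(is_projective([-1, 3, 0, 2, 1]))  # False: e.g. edges (0, 2) and (1, 3) cross
```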

The following class keeps track of the parser configurations for a batch of sentences. This is only used while running the parser, not while training it.
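While parsing, only transitions whose preconditions hold may be selected: the class below computes the rows where an action is not applicable, and the classifier's predict method sets the scores of those actions to minus infinity before taking the argmax. The same logic for a single sentence can be sketched as follows; legal_actions and choose are hypothetical names, and like the notebook the sketch never blocks the end action.

```python
def legal_actions(stack_size, buf_pos, eos_pos):
    """Return the arc-hybrid actions applicable in one configuration.

    stack_size counts the stack items including the dummy root at the bottom;
    buf_pos == eos_pos means only the end-of-sentence token is left in the buffer.
    """
    legal = {'end'}  # as in the notebook, the end action is never masked
    stack_has_only_root = stack_size == 1
    buffer_is_empty = buf_pos == eos_pos
    if not buffer_is_empty:
        legal.add('sh')
    if not stack_has_only_root:
        legal.add('ra')
        if not buffer_is_empty:
            # Left-arc needs a buffer token as head and must not pop the root.
            legal.add('la')
    return legal


def choose(scores, legal):
    """Greedy choice: the highest-scoring action among the legal ones."""
    masked = {a: (s if a in legal else float('-inf')) for a, s in scores.items()}
    return max(masked, key=masked.get)


scores = {'la': 2.0, 'ra': 1.0, 'sh': 0.5, 'end': 0.1}
# Initial configuration of a 3-word sentence: BOS/root at 0, words at 1-3, EOS at 4.
print(choose(scores, legal_actions(stack_size=1, buf_pos=1, eos_pos=4)))  # sh
```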

In [5]:
class ArcHybridState:
    
    def __init__(self, system, words):
        # Initializes parsing for a batch of sentences. This will create stacks just containing dummy root
        # tokens, buffer pointers pointing to the start of the sentences, and empty trees.
        n_sentences, n_words = words.shape
        self.n_stack_features = system.n_stack_features
        self.la_enc = system.la_enc
        self.ra_enc = system.ra_enc
        self.sh_enc = system.sh_enc
        
        self.n_words = n_words
        self.heads = torch.zeros_like(words)
        self.labels = torch.zeros_like(words)
        
        self.stacks = torch.zeros(size=(n_sentences, n_words+self.n_stack_features),
                                  device=words.device, dtype=torch.long)
        self.sp = torch.full((n_sentences,), self.n_stack_features-1, dtype=torch.long, device=words.device)
        self.bp = torch.full((n_sentences,), 1, dtype=torch.long, device=words.device)
        
        self.rows = torch.arange(n_sentences, dtype=torch.long, device=words.device)
        self.ranges = torch.arange(-self.n_stack_features+1, 1, 
                                   dtype=torch.long, device=words.device).view(-1, 1)

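    def check_conditions(self):
        # Checks for each sentence whether the shift, right-arc or left-arc actions are applicable.
        # Returns tensors with the rows where the respective actions are NOT applicable.
        stack_has_only_root = self.sp == self.n_stack_features-1
        buffer_is_empty = self.bp == self.n_words-1
        # Right-arc needs at least two items on the stack.
        no_ra_rows = torch.where(stack_has_only_root)[0]
        # Shift needs a non-empty buffer.
        no_sh_rows = torch.where(buffer_is_empty)[0]
        # Left-arc needs a non-root stack top (the root must never become a dependent)
        # and a non-empty buffer, so it is blocked if EITHER condition fails.
        no_la_rows = torch.where(stack_has_only_root | buffer_is_empty)[0]
        return no_la_rows, no_ra_rows, no_sh_rows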
        
    def extract_features(self):
        # Extracts the relevant sentence positions ("features") from the stacks and buffers
        # that we'll use to predict the next action for each sentence in the batch.
        # At each step, the Kiperwasser and Goldberg parser considers 3 tokens in the stack and 1 in the buffer.
        stack_ix = self.sp + self.ranges
        return torch.cat([self.stacks[self.rows, stack_ix].t(), self.bp.view(-1, 1)], dim=1)
        
    def update(self, actions, action_labels):
        # For a given batch of actions (and corresponding labels), update all the parser configurations.
        
        la_rows = torch.where(actions == self.la_enc)[0] 
        ra_rows = torch.where(actions == self.ra_enc)[0] 
        sh_rows = torch.where(actions == self.sh_enc)[0]

        spm1 = self.sp-1

        # Left arc:
        la_positions = self.stacks[la_rows, self.sp[la_rows]]
        # Edge from top-of-buffer to top-of-stack.
        self.heads[la_rows, la_positions] = self.bp[la_rows]
        self.labels[la_rows, la_positions] = action_labels[la_rows]
        # Pop an item from the stack.
        self.sp[la_rows] -= 1

        # Right arc:
        ra_positions = self.stacks[ra_rows, self.sp[ra_rows]]
        # Edge from second item in the stack to the first item.
        self.heads[ra_rows, ra_positions] = self.stacks[ra_rows, spm1[ra_rows]]
        self.labels[ra_rows, ra_positions] = action_labels[ra_rows]
        # Pop an item from the stack.
        self.sp[ra_rows] -= 1

        # Shift:
        # Put the first item in the buffer on top of the stack.
        self.sp[sh_rows] += 1
        self.stacks[sh_rows, self.sp[sh_rows]] = self.bp[sh_rows]
        # Move the buffer pointer one step forward.
        self.bp[sh_rows] += 1

Defining the classifier that selects actions and labels

We'll now define the neural network used in the parser by Kiperwasser and Goldberg (2016). The sentence encoder concatenates the word and part-of-speech-tag embeddings of each token and runs them through a 3-layer BiLSTM. For each parser configuration, the BiLSTM outputs at the top three stack positions and at the first buffer position are concatenated and fed into two classifiers, one predicting the next action and one predicting the edge label. The classifiers are simple feedforward neural networks with one hidden layer and ReLU activations.
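The input to both classifiers can be illustrated with a toy, plain-Python sketch (classifier_input is a hypothetical name; plain lists stand in for the BiLSTM output tensor). The stack is padded with root positions, the three topmost positions are taken from bottom to top, the buffer position is appended, and the four vectors are concatenated. With the settings used later in the notebook (rnn_size=256, three stack features), this gives 2 · 256 · (1 + 3) = 2048 input dimensions, matching mlp_input_size in the code.

```python
def classifier_input(encoded, stack, buf_pos, n_stack_features=3):
    """Return the feature positions and the concatenated vector fed to the MLPs.

    encoded[p] is the BiLSTM output (a list of floats) at sentence position p.
    stack lists positions from bottom to top, starting with the root (0); shorter
    stacks are padded with root positions, as in the notebook.
    """
    padded = [0] * n_stack_features + stack
    positions = padded[-n_stack_features:] + [buf_pos]
    features = []
    for p in positions:
        features.extend(encoded[p])
    return positions, features


rnn_size = 2  # toy size: each BiLSTM output then has 2 * rnn_size = 4 dimensions
encoded = [[float(10 * p + k) for k in range(2 * rnn_size)] for p in range(6)]
positions, x = classifier_input(encoded, stack=[0, 2], buf_pos=3)
print(positions)  # [0, 0, 2, 3]
print(len(x))     # 16, i.e. 2 * rnn_size * (1 + 3)
```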


In [6]:
class TransitionClassifier(nn.Module):
    
    def __init__(self, fields, word_emb_dim, pos_emb_dim,
                 rnn_size, rnn_depth, mlp_hidden_size, 
                 n_stack_features,
                 update_pretrained=False):
        super().__init__()
        
        word_field = fields[0][1]
        pos_field = fields[1][1]
        action_field = fields[4][1]
        label_field = fields[5][1]
        
        n_actions = len(action_field.vocab)
        n_labels = len(label_field.vocab)
        
        # Sentence encoder module.
        self.encoder = RNNEncoder(word_field, word_emb_dim, pos_field, pos_emb_dim, rnn_size, rnn_depth,
                                  update_pretrained)

        # MLPs for classifying actions and labels.
        mlp_input_size = 2*rnn_size*(1+n_stack_features)
        self.action_mlp = nn.Sequential(nn.Linear(mlp_input_size, mlp_hidden_size), 
                                        nn.ReLU(),
                                        nn.Linear(mlp_hidden_size, n_actions))        

        self.label_mlp = nn.Sequential(nn.Linear(mlp_input_size, mlp_hidden_size), 
                                       nn.ReLU(),
                                       nn.Linear(mlp_hidden_size, n_labels))        
        
        self.pad_id = action_field.vocab.stoi[action_field.pad_token]
        self.sh_id = action_field.vocab.stoi['sh']
        self.la_id = action_field.vocab.stoi['la']
        self.ra_id = action_field.vocab.stoi['ra']
        
        # Loss function that we will use during training.
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')

    def word_tag_dropout(self, words, postags, p_drop):
        # Randomly replace some of the positions in the word and postag tensors with a zero.
        # This solution is a bit hacky because we assume that zero corresponds to the "unknown" token.
        w_dropout_mask = (torch.rand(size=words.shape, device=words.device) > p_drop).long()
        p_dropout_mask = (torch.rand(size=words.shape, device=words.device) > p_drop).long()
        return words*w_dropout_mask, postags*p_dropout_mask
        
    def forward(self, words, postags, actions, action_labels, action_features):

        if self.training:
            # If we are training, apply the word/tag dropout to the word and tag tensors.
            words, postags = self.word_tag_dropout(words, postags, 0.25)
        
        n_sentences, n_actions = actions.shape
        
        encoded = self.encoder(words, postags)
        action_input = self.encode_action_features(encoded, action_features)
        
        action_output = self.action_mlp(action_input)
        label_output = self.label_mlp(action_input)

        return self.compute_loss(actions, action_output) + self.compute_loss(action_labels, label_output)
    
    def compute_loss(self, actions, action_output):
        actions = actions.flatten()
        pad_mask = (actions != self.pad_id).float()
        action_loss = self.loss(action_output, actions)
        return action_loss.dot(pad_mask) / pad_mask.sum()
        
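    def encode_action_features(self, encoded, action_features):
        # encoded: (n_sentences, n_words, 2*rnn_size); action_features: (n_sentences, n_actions, n_features).
        n_sentences, n_actions, _ = action_features.shape
        row_ix = torch.arange(n_sentences, device=action_features.device, dtype=torch.long)
        # row_ix broadcasts against a_ix of shape (n_features, n_actions, n_sentences), so the
        # lookup gives (n_features, n_actions, n_sentences, 2*rnn_size) before the transpose.
        a_ix = action_features.transpose(0, 2)
        out = encoded[row_ix, a_ix].transpose(0, 2)
        # Concatenate the n_features vectors belonging to each action.
        return out.reshape(n_sentences*n_actions, -1)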
        
    def predict(self, encoded, action_features, no_la_rows, no_ra_rows, no_sh_rows):
        n_sentences, n_features = action_features.shape
        action_features = action_features.view(n_sentences, 1, n_features)
        action_input = self.encode_action_features(encoded, action_features)
        
        action_output = self.action_mlp(action_input)        
        label_output = self.label_mlp(action_input)
        
        action_output[no_sh_rows, self.sh_id] = -np.inf
        action_output[no_la_rows, self.la_id] = -np.inf
        action_output[no_ra_rows, self.ra_id] = -np.inf
        
        return action_output.argmax(dim=1), label_output.argmax(dim=1)

And here is the sentence encoder. This is a straightforward application of techniques we've seen in the past, with the small twist that we use embeddings not only for the words but also for the part-of-speech tags.

In [7]:
class RNNEncoder(nn.Module):

    def __init__(self, word_field, word_emb_dim, pos_field, pos_emb_dim, rnn_size, rnn_depth, update_pretrained):
        super().__init__()
        
        self.word_embedding = nn.Embedding(len(word_field.vocab), word_emb_dim)
        if word_field.vocab.vectors is not None:
            self.word_embedding.weight = nn.Parameter(word_field.vocab.vectors, 
                                                       requires_grad=update_pretrained)
        self.pos_embedding = nn.Embedding(len(pos_field.vocab), pos_emb_dim)
        self.rnn = nn.LSTM(input_size=word_emb_dim+pos_emb_dim, hidden_size=rnn_size, batch_first=True,
                          bidirectional=True, num_layers=rnn_depth)
        
    def forward(self, words, postags):
        word_emb = self.word_embedding(words)
        pos_emb = self.pos_embedding(postags)
        word_pos_emb = torch.cat([word_emb, pos_emb], dim=2)

        rnn_out, _ = self.rnn(word_pos_emb)

        return rnn_out

Training the full system

As usual, we define a class whose train method loads the dataset, creates a model, runs the training loop, and prints some diagnostics at the end. This is similar to our previous examples, so we'll leave it without comment. The main difference is the additional overhead of computing the action sequences and extracting the relevant sentence positions for each action.

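While training, we print the unlabeled attachment score (UAS) and labeled attachment score (LAS) on the validation set: the UAS is the fraction of tokens that are assigned the correct head, and the LAS is the fraction that are assigned both the correct head and the correct label. With this English dataset, the UAS usually reaches about 0.88, and the LAS is a bit lower; in the run shown below, the scores after 25 epochs are 0.887 UAS and 0.866 LAS.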

At the end of the class, there are some auxiliary methods that call the action classifier step by step and update the parser configurations according to the selected actions.
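The counting done by the evaluate method below can be sketched for a single sentence in plain Python (attachment_scores is a hypothetical name, not part of the notebook). As in evaluate, positions whose gold head is -1 (the dummy root, the end-of-sentence token, and padding) are ignored, and a token counts toward the LAS only if both its head and its label are correct.

```python
def attachment_scores(gold_heads, gold_labels, pred_heads, pred_labels):
    """Return (UAS, LAS) over all positions whose gold head is not -1."""
    n_tokens = n_corr_u = n_corr_l = 0
    for gh, gl, ph, pl in zip(gold_heads, gold_labels, pred_heads, pred_labels):
        if gh == -1:
            continue  # dummy root, end-of-sentence or padding position
        n_tokens += 1
        if gh == ph:
            n_corr_u += 1
            if gl == pl:
                n_corr_l += 1
    return n_corr_u / n_tokens, n_corr_l / n_tokens


gold_heads = [-1, 2, 0, 4, 2, -1]
gold_labels = ['<none>', 'subj', 'root', 'prep', 'loc', '<none>']
pred_heads = [0, 2, 0, 2, 2, 0]  # token 3 gets the wrong head
pred_labels = ['<none>', 'subj', 'root', 'prep', 'obj', '<none>']  # token 4 gets the wrong label
uas, las = attachment_scores(gold_heads, gold_labels, pred_heads, pred_labels)
print(uas, las)  # 0.75 0.5
```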

In [8]:
class DependencyParser:
    
    def __init__(self, lower=False):
        bos = '<bos>'
        eos = '<eos>'
        none = '<none>'
        
        self.WORD = torchtext.data.Field(init_token=bos, eos_token=eos, sequential=True, 
                                         lower=lower, batch_first=True)
        self.POS = torchtext.data.Field(init_token=bos, eos_token=eos, sequential=True, 
                                        batch_first=True)
        self.HEAD = torchtext.data.Field(pad_token=-1, eos_token=-1, use_vocab=False, 
                                         sequential=True, batch_first=True)
        self.LABEL = torchtext.data.Field(pad_token=none, sequential=True,
                                          unk_token=None, batch_first=True)
        self.ACTION = torchtext.data.Field(pad_token=none, unk_token=None, sequential=True, batch_first=True)
                
        self.fields = [('words', self.WORD), ('postags', self.POS),
                       ('heads', self.HEAD), ('labels', self.LABEL),
                       ('actions', self.ACTION), ('action_labels', self.LABEL)]
        
        self.transition_system = ArcHybridSystem(n_stack_features=3)

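        # Use the GPU if one is available; training on the CPU is much slower.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'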

        
    def train(self):
        
        torch.manual_seed(1234)
        random.seed(1234)
        
        # Read training and validation data according to the predefined split.
        train_examples = read_data('data/en_ewt-ud-train.conllu', self.fields, self.transition_system, False)
        val_examples = read_data('data/en_ewt-ud-dev.conllu', self.fields, self.transition_system, True)
                
        self.POS.build_vocab(train_examples)
        self.ACTION.build_vocab(train_examples)
        self.LABEL.build_vocab(train_examples)

        self.transition_system.set_vocab(self.ACTION.vocab)
        
        # Load the pre-trained word embeddings that come with the torchtext library.
        use_pretrained = True
        if use_pretrained:
            print('We are using pre-trained word embeddings.')
            self.WORD.build_vocab(train_examples, vectors="glove.840B.300d")
        else:  
            print('We are training word embeddings from scratch.')
            self.WORD.build_vocab(train_examples, max_size=10000)
        
        self.model = TransitionClassifier(self.fields, word_emb_dim=300, pos_emb_dim=32, 
                                      rnn_size=256, rnn_depth=3, mlp_hidden_size=256,
                                      n_stack_features=self.transition_system.n_stack_features,
                                      update_pretrained=False)
    
        self.model.to(self.device)
    
        train_iterator = torchtext.data.BucketIterator(
            train_examples,
            device=self.device,
            batch_size=64,
            sort_key=lambda x: len(x.words),
            repeat=False,
            train=True,
            sort=True)
        
        val_iterator = torchtext.data.BucketIterator(
            val_examples,
            device=self.device,
            batch_size=512,
            sort_key=lambda x: len(x.words),
            repeat=False,
            train=True,
            sort=True)

        train_batches = list(train_iterator)
        val_batches = list(val_iterator)
        
        train_action_features = []
        for batch in train_batches:
            batch_actions = batch.actions.cpu().numpy()
            action_features = self.transition_system.extract_action_features_batch(batch_actions)
            features_tensor = torch.as_tensor(action_features, device=self.device)
            train_action_features.append(features_tensor)
            
        val_action_features = []
        for batch in val_batches:
            batch_actions = batch.actions.cpu().numpy()
            action_features = self.transition_system.extract_action_features_batch(batch_actions)
            features_tensor = torch.as_tensor(action_features, device=self.device)
            val_action_features.append(features_tensor)
        
        train_batches = list(zip(train_batches, train_action_features))
        val_batches = list(zip(val_batches, val_action_features))
        
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0025, weight_decay=1e-5)

        history = defaultdict(list)    
        
        n_epochs = 25
        
        for i in range(1, n_epochs + 1):

            t0 = time.time()

            stats = Counter()
            
            random.shuffle(train_batches)
            
            self.model.train()
            for batch, batch_action_features in train_batches:
                
                loss = self.model(batch.words, batch.postags, batch.actions, batch.action_labels, batch_action_features)
                optimizer.zero_grad()            
                loss.backward()
                optimizer.step()
                stats['train_loss'] += loss.item()

                if self.device == 'cpu':
                    print('.', end='')
                    sys.stdout.flush()
                
            if self.device == 'cpu':
                print()

            train_loss = stats['train_loss'] / len(train_batches)
            history['train_loss'].append(train_loss)

            t1 = time.time()

            self.model.eval()
            with torch.no_grad():
                for batch, batch_action_features in val_batches:
                    loss = self.model(batch.words, batch.postags, batch.actions, batch.action_labels, batch_action_features)
                    stats['val_loss'] += loss.item()
                    predicted_heads, predicted_labels = self.parse_batch(batch)
                    n_tokens, n_corr_u, n_corr_l  = self.evaluate(batch.heads, batch.labels, predicted_heads, predicted_labels)
                    stats['val_n_tokens'] += n_tokens
                    stats['val_n_corr_u'] += n_corr_u
                    stats['val_n_corr_l'] += n_corr_l
            
            t2 = time.time()
            
            val_loss = stats['val_loss'] / len(val_batches)
            history['val_loss'].append(val_loss)
            uas = stats['val_n_corr_u']/stats['val_n_tokens']
            las = stats['val_n_corr_l']/stats['val_n_tokens']
            history['uas'].append(uas)
            history['las'].append(las)
            
            print(f'Epoch {i:2}: train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, UAS: {uas:.4f}, LAS: {las:.4f}, time: {t2-t0:.4f}')
            
        plt.plot(history['train_loss'])
        plt.plot(history['val_loss'])
        plt.plot(history['uas'])        
        plt.legend(['training loss', 'validation loss', 'UAS'])            

    def evaluate(self, gold_heads, gold_labels, predicted_heads, predicted_labels):
        # Computes the relevant counters for computing the LAS and UAS.
        pad_mask = (gold_heads != -1)
        head_ok = (gold_heads == predicted_heads)
        label_ok = (gold_labels == predicted_labels)
        
        n_corr_u = (pad_mask & head_ok).sum().item()
        n_corr_l = (pad_mask & head_ok & label_ok).sum().item()
        n_tokens = pad_mask.sum().item()
        
        return n_tokens, n_corr_u, n_corr_l
        
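    def sanity_check(self, actions):
        # Debugging helper (not called during normal parsing): reports the configurations
        # where a shift was selected although the buffer is already empty.
        n_sent, n_words, rnn_dim = self.state.encoded.shape
        shift_bug = (self.state.bp >= n_words-1) & (actions == self.model.sh_id)
        if shift_bug.sum() > 0:
            print(self.state.bp)
            print(actions)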
                
    def step(self):
        # Carries out one parsing step: calling the action classifier, and then updating the configurations.
        action_features = self.state.extract_features()
        no_la, no_ra, no_sh = self.state.check_conditions()
        actions, labels = self.model.predict(self.state.encoded, action_features, 
                                             no_la, no_ra, no_sh)
        self.state.update(actions, labels)
        return actions, labels

    def steps(self):
        # Carries out parsing actions until all parsers have reached the end state.
        end_action_id = self.ACTION.vocab.stoi['end']
        while True:
            actions, labels = self.step()
            if (actions != end_action_id).sum().item() == 0:
                return
    
    def init_batch(self, batch):
        # Initializes the parser for a batch of sentences.
        self.state = self.transition_system.init_parsing(batch.words)
        self.state.encoded = self.model.encoder(batch.words, batch.postags)

    def parse_batch(self, batch):
        # Initializes and parses a batch of sentences, then returns the resulting trees.
        self.init_batch(batch)
        self.steps()
        return self.state.heads, self.state.labels
        
    def init(self, sentences):
        # Auxiliary method that is used when the input consists of word and part-of-speech strings.
        examples = []
        for tagged_words in sentences:
            words = [w for w, _ in tagged_words]
            tags = [t for _, t in tagged_words]
            examples.append(torchtext.data.Example.fromlist([words, tags, [], [], [], []], self.fields))
        dataset = torchtext.data.Dataset(examples, self.fields)
        
        iterator = torchtext.data.Iterator(
            dataset,
            device=self.device,
            batch_size=len(examples),
            repeat=False,
            train=False,
            sort=False)

        self.model.eval()
        with torch.no_grad():
            for batch in iterator:
                self.init_batch(batch)
                
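                # Visualization hack: the heads are initialized to 0 (the root); making each token
                # its own head instead means that unattached tokens are drawn with self-loops.
                _, n_heads = self.state.heads.shape
                self.state.heads += torch.arange(n_heads, dtype=torch.long, device=self.device)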
        
    def get_heads_and_labels(self):
        heads = self.state.heads.cpu().numpy()
        label_enc = self.state.labels.cpu().numpy()
        labels = [[self.LABEL.vocab.itos[l] for l in row] for row in label_enc]
        return heads, labels
        
    def parse(self, sentences):
        self.init(sentences)
        self.steps()
        return self.get_heads_and_labels()
                    
parser = DependencyParser()
parser.train()
Read 11881 sentences, 662 ill-formed.
Read 2002 sentences, 55 ill-formed.
We are using pre-trained word embeddings.
Epoch  1: train loss: 2.2305, val loss: 0.6554, UAS: 0.6908, LAS: 0.6224, time: 4.1919
Epoch  2: train loss: 0.6278, val loss: 0.4797, UAS: 0.7739, LAS: 0.7341, time: 4.1456
Epoch  3: train loss: 0.4554, val loss: 0.2751, UAS: 0.8283, LAS: 0.8003, time: 4.1579
Epoch  4: train loss: 0.3891, val loss: 0.2500, UAS: 0.8371, LAS: 0.8114, time: 4.1197
Epoch  5: train loss: 0.3388, val loss: 0.2379, UAS: 0.8454, LAS: 0.8208, time: 4.1388
Epoch  6: train loss: 0.3112, val loss: 0.2272, UAS: 0.8525, LAS: 0.8279, time: 4.1524
Epoch  7: train loss: 0.2935, val loss: 0.2163, UAS: 0.8545, LAS: 0.8322, time: 4.2315
Epoch  8: train loss: 0.2708, val loss: 0.2212, UAS: 0.8661, LAS: 0.8427, time: 4.1539
Epoch  9: train loss: 0.2707, val loss: 0.2044, UAS: 0.8722, LAS: 0.8510, time: 4.3478
Epoch 10: train loss: 0.2703, val loss: 0.2134, UAS: 0.8595, LAS: 0.8374, time: 4.3487
Epoch 11: train loss: 0.2510, val loss: 0.2194, UAS: 0.8618, LAS: 0.8392, time: 4.1293
Epoch 12: train loss: 0.2405, val loss: 0.1976, UAS: 0.8756, LAS: 0.8557, time: 4.2061
Epoch 13: train loss: 0.2193, val loss: 0.1990, UAS: 0.8740, LAS: 0.8531, time: 4.2627
Epoch 14: train loss: 0.2147, val loss: 0.2012, UAS: 0.8780, LAS: 0.8566, time: 4.3424
Epoch 15: train loss: 0.2152, val loss: 0.2019, UAS: 0.8738, LAS: 0.8524, time: 4.3898
Epoch 16: train loss: 0.2044, val loss: 0.2087, UAS: 0.8788, LAS: 0.8554, time: 4.5325
Epoch 17: train loss: 0.1889, val loss: 0.2021, UAS: 0.8770, LAS: 0.8566, time: 4.5338
Epoch 18: train loss: 0.1891, val loss: 0.2081, UAS: 0.8703, LAS: 0.8496, time: 4.9837
Epoch 19: train loss: 0.1944, val loss: 0.2062, UAS: 0.8789, LAS: 0.8582, time: 5.0521
Epoch 20: train loss: 0.1937, val loss: 0.2268, UAS: 0.8714, LAS: 0.8508, time: 4.8444
Epoch 21: train loss: 0.1779, val loss: 0.2170, UAS: 0.8797, LAS: 0.8583, time: 4.6722
Epoch 22: train loss: 0.1715, val loss: 0.2244, UAS: 0.8784, LAS: 0.8575, time: 4.7570
Epoch 23: train loss: 0.1724, val loss: 0.2097, UAS: 0.8836, LAS: 0.8632, time: 4.8230
Epoch 24: train loss: 0.1696, val loss: 0.2096, UAS: 0.8808, LAS: 0.8583, time: 4.7526
Epoch 25: train loss: 0.1665, val loss: 0.2131, UAS: 0.8871, LAS: 0.8661, time: 4.5509

Interactive demo

The following interactive demo requires that you have NLTK and graphviz installed, as in the notebook from last week.

In [9]:
import nltk

# Download the tokenizer and part-of-speech tagger models if you haven't done it before.
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')

import warnings
warnings.filterwarnings('ignore')

# Put the directory where graphviz's 'dot' executable is located first in the PATH.
# Adjust this directory to match your own installation.
import os
os.environ['PATH'] = '/opt/miniconda3/bin' + os.pathsep + os.environ['PATH']

The following class is used for the interactive demo.

In [19]:
class ParserDemo:
    
    def __init__(self, sentence):
        self.tokenized = nltk.word_tokenize(sentence)
        self.tagged = nltk.pos_tag(self.tokenized)    
        parser.init([self.tagged])
        self.show_state()
        
    def show_state(self):
        sp = parser.state.sp[0].item()
        stack = parser.state.stacks[0, :sp+1].cpu().numpy()
        print('Stack:', stack)
        print('Buffer position:', parser.state.bp[0].item())
        
    def step(self):        
        actions, action_labels = parser.step()
        actions = [ parser.ACTION.vocab.itos[a] for a in actions.cpu().numpy() ]
        action_labels = [ parser.LABEL.vocab.itos[l] for l in action_labels.cpu().numpy() ]
        if actions[0] not in ['la', 'ra']:
            print('Selected action:', actions[0])
        else:
            print(f'Selected action and label: {actions[0]} {action_labels[0]}')
        self.show_state()
            
    def draw_tree(self):
        heads, labels = parser.get_heads_and_labels()   
        nltk_str = '\n'.join(f'{w} _ {h} {l}' for (w, h, l) in zip(self.tokenized, heads[0][1:], labels[0][1:]))
        return nltk.DependencyGraph(nltk_str)

demo = ParserDemo('She lives in a house in my garden')
demo.draw_tree()
Stack: [0 0 0]
Buffer position: 1
Out[19]:
[Figure: dependency graph for "She lives in a house in my garden" before any arcs are assigned; each word is drawn with a self-loop labeled <none>]
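The demo reads `parser.state.stacks`, `parser.state.sp` and `parser.state.bp`. This suggests that the parser keeps the stacks of a whole batch in one fixed-size tensor and stores a separate stack pointer per sentence that points at the current top element. The state class itself is not shown in this section, so the NumPy sketch below is only a hypothetical illustration of that layout. `BatchedStack`, its methods and the `n_pad` parameter are assumptions. Pre-filling three slots with the root index 0 mimics the initial `Stack: [0 0 0]` printed above.

```python
import numpy as np


class BatchedStack:
    """Hypothetical sketch: one fixed-size stack per sentence plus a stack pointer per row."""

    def __init__(self, batch_size, max_len, n_pad=3, pad_value=0):
        # Each row starts with n_pad copies of the root index.
        self.stacks = np.full((batch_size, max_len + n_pad), pad_value, dtype=np.int64)
        # sp[b] is the index of the top element of stack b.
        self.sp = np.full(batch_size, n_pad - 1, dtype=np.int64)

    def push(self, values, mask):
        # Push values[b] onto stack b for every row b where mask[b] is True.
        rows = np.nonzero(mask)[0]
        self.sp[rows] += 1
        self.stacks[rows, self.sp[rows]] = values[rows]

    def pop(self, mask):
        # Pop by moving the pointer down; the old value stays in the array.
        rows = np.nonzero(mask)[0]
        self.sp[rows] -= 1

    def top(self, k=0):
        # Element k positions below the top of every stack (k=0 is the top).
        return self.stacks[np.arange(len(self.sp)), self.sp - k]

    def contents(self, b):
        # The live part of stack b, sliced the same way as in show_state.
        return self.stacks[b, : self.sp[b] + 1]


s = BatchedStack(batch_size=2, max_len=10)
s.push(np.array([1, 4]), np.array([True, True]))
s.push(np.array([2, 5]), np.array([True, False]))
print(s.contents(0))  # [0 0 0 1 2]
print(s.contents(1))  # [0 0 0 4]
```

Because each operation takes a boolean mask, different sentences in the batch can take different actions in the same step. A popped value stays in the array, but it falls outside the slice `[:sp+1]`.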

If you run the following cell repeatedly, you'll see how the parser selects actions and updates its configuration (the stack and the buffer position) and the arcs of the tree. Eventually, you'll get a complete tree, and the parser will reach its end state.

In [31]:
demo.step()
demo.draw_tree()
Selected action and label: la case
Stack: [0 0 0 2 5]
Buffer position: 8
Out[31]:
[Figure: partial dependency graph with the arcs lives -> She (nsubj), house -> in (case), house -> a (det), garden -> in (case) and garden -> my (nmod:poss); lives, house and garden are still drawn with self-loops labeled <none>]
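Kiperwasser and Goldberg (2016) use the arc-hybrid transition system, and the demo output above is consistent with it. Each `la` step attaches the stack top to the word at the buffer position and pops it from the stack. The pure-Python sketch below replays arc-hybrid transitions on a plain list and assumes that the simplified parser keeps that system. The function name `replay_arc_hybrid` and the action name `'sh'` for shift are hypothetical, because only `la` and `ra` appear in the demo output. The root padding of the demo's stack is left out.

```python
def replay_arc_hybrid(n_words, actions):
    """Replay arc-hybrid transitions for a sentence whose words are numbered 1..n_words.

    actions: (name, label) pairs, with name in {'sh', 'la', 'ra'} (hypothetical names).
    Returns (stack, bp, arcs), where arcs maps each dependent to a (head, label) pair.
    """
    stack = [0]  # 0 is the artificial root token
    bp = 1       # buffer position: the first word not yet shifted
    arcs = {}
    for name, label in actions:
        buffer_empty = bp > n_words
        if name == 'sh':
            # SHIFT: move the word at the buffer position onto the stack.
            if buffer_empty:
                raise ValueError('sh needs a non-empty buffer')
            stack.append(bp)
            bp += 1
        elif name == 'la':
            # LEFT-ARC: the word at the buffer position becomes the head of the stack top.
            if buffer_empty or len(stack) < 2:
                raise ValueError('la needs a buffer word and a non-root stack top')
            arcs[stack.pop()] = (bp, label)
        elif name == 'ra':
            # RIGHT-ARC: the item below the stack top becomes the head of the stack top.
            if len(stack) < 2:
                raise ValueError('ra needs at least two items on the stack')
            dependent = stack.pop()
            arcs[dependent] = (stack[-1], label)
        else:
            raise ValueError(f'unknown action {name!r}')
    return stack, bp, arcs


# One action sequence (reconstructed by hand, not taken from the parser) that leads to
# the state shown above for 'She lives in a house in my garden'.
actions = [('sh', None), ('la', 'nsubj'), ('sh', None), ('sh', None), ('sh', None),
           ('la', 'det'), ('la', 'case'), ('sh', None), ('sh', None), ('sh', None),
           ('la', 'nmod:poss'), ('la', 'case')]
stack, bp, arcs = replay_arc_hybrid(8, actions)
print(stack, bp)  # [0, 2, 5] 8
print(arcs)       # {1: (2, 'nsubj'), 4: (5, 'det'), 3: (5, 'case'), 7: (8, 'nmod:poss'), 6: (8, 'case')}
```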

Parsing example sentences

Here is an example of how we can tokenize and tag a sentence using NLTK:

In [12]:
nltk.pos_tag(nltk.word_tokenize('The big dog lives in its little house.'))
Out[12]:
[('The', 'DT'),
 ('big', 'JJ'),
 ('dog', 'NN'),
 ('lives', 'VBZ'),
 ('in', 'IN'),
 ('its', 'PRP$'),
 ('little', 'JJ'),
 ('house', 'NN'),
 ('.', '.')]

We define a utility function that calls NLTK to tokenize and part-of-speech-tag the sentence. It then calls the parser we trained above to find the head position and edge label for each word. Finally, it prints the position, word, tag, head position and label of each token as separate columns.

In [13]:
def parse_sentence(sentence):
    tokenized = nltk.word_tokenize(sentence)
    tagged = nltk.pos_tag(tokenized)    
    heads, labels = parser.parse([tagged])
    for i, ((word, tag), head, label) in enumerate(zip(tagged, heads[0][1:], labels[0][1:]), 1):
        print(f'{i:2} {word:10} {tag:4} {head} {label}')

Here is the result of parsing an example sentence.

In [14]:
parse_sentence('The big dog lives in its little house.')
 1 The        DT   3 det
 2 big        JJ   3 amod
 3 dog        NN   4 nsubj
 4 lives      VBZ  0 root
 5 in         IN   8 case
 6 its        PRP$ 8 nmod:poss
 7 little     JJ   8 amod
 8 house      NN   4 obl
 9 .          .    4 punct
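When parser output is used downstream, it can be useful to check the head column mechanically. Whether this parser can ever produce a malformed result depends on implementation details that are not shown here. The hypothetical helper below takes the heads of words 1..n, where 0 denotes the root as in the output above. It checks three things. Every head must be in range. Exactly one word must be attached to the root, as the UD guidelines require. Following the head links from any word must reach the root without a cycle.

```python
def is_well_formed_tree(heads):
    """Hypothetical check: heads[i - 1] is the head of word i, and 0 is the root."""
    n = len(heads)
    # Every head must be the root (0) or one of the words 1..n.
    if any(not 0 <= h <= n for h in heads):
        return False
    # Exactly one word may be attached to the root.
    if sum(h == 0 for h in heads) != 1:
        return False
    # From every word, following head links must reach the root without revisiting a node.
    for word in range(1, n + 1):
        seen = set()
        node = word
        while node != 0:
            if node in seen:
                return False  # cycle (including a word that is its own head)
            seen.add(node)
            node = heads[node - 1]
    return True


# Heads from the column output for 'The big dog lives in its little house.'
print(is_well_formed_tree([3, 3, 4, 0, 8, 8, 8, 4, 4]))  # True
print(is_well_formed_tree([2, 1, 0]))                    # False: words 1 and 2 form a cycle
```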

Drawing dependency trees

Parse trees are probably easier to understand when drawn than in the column-based format above. NLTK also includes functionality for drawing trees.

Drawing the trees in a notebook requires the installation of graphviz utilities. If you use pip to install packages, install this package; if you use conda, install this and this.

When running this example, I had to add the location of the dot utility to the PATH. This might or might not be necessary in your case.

In [15]:
def draw_sentence(sentence):
    tokenized = nltk.word_tokenize(sentence)
    tagged = nltk.pos_tag(tokenized)    
    heads, labels = parser.parse([tagged]) 
    nltk_str = '\n'.join(f'{w} _ {h} {l}' for (w, h, l) in zip(tokenized, heads[0][1:], labels[0][1:]))
    return nltk.DependencyGraph(nltk_str)
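`draw_sentence` builds one line per word with four whitespace-separated fields: the word, a `_` placeholder for the tag, the head position and the label. `nltk.DependencyGraph` reads four-field lines as word, tag, head and relation. The hypothetical helper `dependency_graph_lines` below is a small variation that puts the NLTK tag into the second field. It also checks that its three inputs have the same length, because `zip` would otherwise silently drop words.

```python
def dependency_graph_lines(tagged, heads, labels):
    """Hypothetical helper: build 'word tag head label' lines for nltk.DependencyGraph.

    tagged: (word, tag) pairs as returned by nltk.pos_tag.
    heads, labels: one entry per word, without the root entry at index 0.
    """
    if not len(tagged) == len(heads) == len(labels):
        raise ValueError('tagged, heads and labels must have the same length')
    return '\n'.join(f'{w} {t} {h} {l}' for (w, t), h, l in zip(tagged, heads, labels))


tagged = [('The', 'DT'), ('dog', 'NN'), ('barks', 'VBZ')]
print(dependency_graph_lines(tagged, [2, 3, 0], ['det', 'nsubj', 'root']))
# The DT 2 det
# dog NN 3 nsubj
# barks VBZ 0 root
```

In the notebook, the resulting string would be passed to `nltk.DependencyGraph` in the same way as `nltk_str` in `draw_sentence`.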

We can now draw a tree for the example we saw above.

In [16]:
draw_sentence('The big dog lives in its little house.')
Out[16]:
[Figure: dependency tree for "The big dog lives in its little house." with the arcs root -> lives (root), lives -> dog (nsubj), lives -> house (obl), lives -> . (punct), dog -> The (det), dog -> big (amod), house -> in (case), house -> its (nmod:poss) and house -> little (amod)]
In [17]:
draw_sentence('"You\'ve made a big mess," said my mother.')
Out[17]:
[Figure: dependency tree for the sentence above with the arcs root -> said (root), said -> `` (punct), said -> made (ccomp), said -> , (punct), said -> '' (punct), said -> mother (nsubj), said -> . (punct), made -> You (nsubj), made -> 've (aux), made -> mess (obj), mess -> a (det), mess -> big (amod) and mother -> my (nmod:poss)]