In this example, we implement a simplified version of the transition-based dependency parser described in the paper by Kiperwasser and Goldberg (2016).
import torch
from torch import nn
import time
import torchtext
import numpy as np
import random
import sys
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn')
Let's first discuss the data and the way that it's formatted. We will use treebanks from the Universal Dependencies project.
We use the training and development sections of the English (EWT) dataset. These files can be downloaded from the UD repository, and the CoNLL-U format used in the UD project is documented on the UD website. Here is an example of a sentence in this format.
1 Now now ADV RB _ 4 advmod 4:advmod
2 , , PUNCT , _ 4 punct 4:punct
3 people people NOUN NNS Number=Plur 4 nsubj 4:nsubj
4 wonder wonder VERB VBP Mood=Ind|Tense=Pres|VerbForm=Fin 0 root 0:root
5 if if SCONJ IN _ 9 mark 9:mark
6 Google Google PROPN NNP Number=Sing 9 nsubj 9:nsubj
7 can can AUX MD VerbForm=Fin 9 aux 9:aux
8 even even ADV RB _ 9 advmod 9:advmod
9 survive survive VERB VB VerbForm=Inf 4 ccomp 4:ccomp
10 . . PUNCT . _ 4 punct 4:punct
To create torchtext `Example` objects, we extract the columns corresponding to the word forms, part-of-speech tags, head positions and edge labels (columns 2, 5, 7 and 8, respectively). The first two will be used as inputs, and the last two are the outputs predicted by the parser.
This code is almost identical to the code used last week, with one crucial difference: after reading each tree, we call the static oracle (defined below) to compute the correct sequence of parsing actions.
def read_data(corpus_file, datafields, transition_system, is_validation):
    """Read a CoNLL-U treebank and convert each sentence into a torchtext Example.

    For every sentence, the word forms (column 2), part-of-speech tags
    (column 5), head positions (column 7) and edge labels (column 8) are
    extracted. Then `transition_system.static_oracle` is called to compute the
    gold action sequence and the label of each action.

    Args:
        corpus_file: path of the CoNLL-U file (read as UTF-8).
        datafields: list of (name, Field) pairs in the order
            words, postags, heads, labels, actions, action_labels.
        transition_system: object providing `static_oracle(heads, labels)`.
        is_validation: if True, sentences the oracle rejects (for example
            non-projective trees) are kept with empty action sequences, so they
            can still be evaluated. Otherwise they are dropped.

    Returns:
        A torchtext Dataset with one Example per non-empty sentence.
    """
    examples = []
    n_failed = 0

    def flush_sentence(words, postags, heads, labels):
        # Turn one collected sentence into an Example, or count it as ill-formed.
        nonlocal n_failed
        # Label for the end-of-sentence position. The leading '<none>' that was
        # already in `labels` belongs to the dummy root at position 0.
        labels = labels + ['<none>']
        try:
            actions, action_labels = transition_system.static_oracle(heads, labels)
        except Exception:
            # The oracle raises for ill-formed or non-projective trees.
            n_failed += 1
            if is_validation:
                examples.append(torchtext.data.Example.fromlist(
                    [words, postags, heads, labels, [], []], datafields))
            return
        examples.append(torchtext.data.Example.fromlist(
            [words, postags, heads, labels, actions, action_labels], datafields))

    # Position 0 is the dummy root: its head is -1 and its label is '<none>'.
    words, postags, heads, labels = [], [], [-1], ['<none>']
    with open(corpus_file, encoding='utf-8') as f:
        for line in f:
            if line.startswith('#'):  # Skip comments.
                continue
            line = line.strip()
            if not line:
                # A blank line ends a sentence. Repeated blank lines are ignored
                # instead of producing empty sentences.
                if words:
                    flush_sentence(words, postags, heads, labels)
                words, postags, heads, labels = [], [], [-1], ['<none>']
                continue
            columns = line.split('\t')
            # Skip dummy tokens used in ellipsis constructions, and multiword tokens.
            if '.' in columns[0] or '-' in columns[0]:
                continue
            words.append(columns[1])
            postags.append(columns[4])
            heads.append(int(columns[6]))
            labels.append(columns[7])

    # The final sentence is not necessarily followed by a blank line.
    if words:
        flush_sentence(words, postags, heads, labels)

    print(f'Read {len(examples)} sentences, {n_failed} ill-formed.')
    return torchtext.data.Dataset(examples, datafields)
The following classes define the transition system that the parser will use, the arc-hybrid system. The most important pieces here are the static oracle, which computes the correct sequence of actions that generates a given tree from the training set, and the code that keeps track of the configurations for a batch of sentences (that is, the stacks, buffers, and generated edges).
class ArcHybridSystem:
    """The arc-hybrid transition system used by the parser.

    Actions:
        'sh'  shift: push the first buffer item onto the stack.
        'la'  left-arc: the first buffer item becomes the head of the stack
              top, which is then popped.
        'ra'  right-arc: the second stack item becomes the head of the stack
              top, which is then popped.
        'end' final action once the tree is complete.

    Position 0 of every sentence is the dummy root token.
    """

    def __init__(self, n_stack_features=3):
        # Number of topmost stack positions used as classifier features.
        self.n_stack_features = n_stack_features

    def set_vocab(self, action_voc):
        """Store the integer encodings of 'la', 'ra' and 'sh' from the action vocabulary."""
        self.la_enc = action_voc.stoi['la']
        self.ra_enc = action_voc.stoi['ra']
        self.sh_enc = action_voc.stoi['sh']

    def extract_action_features(self, actions):
        """Return the sentence positions ("features") considered at each step.

        Replays the encoded action sequence `actions`. Before each action it
        records the `n_stack_features` topmost stack positions (deepest first)
        followed by the buffer position; the parser of Kiperwasser and Goldberg
        uses 3 stack tokens and 1 buffer token. The stack starts as
        `n_stack_features` zeros: the top one is the root, the rest are padding.
        Any action other than la/ra/sh (e.g. 'end' or padding) leaves the
        configuration unchanged.
        """
        n = self.n_stack_features
        action_features = []
        stack = [0] * n
        buf_pos = 1
        for a in actions:
            action_features.append(stack[-n:] + [buf_pos])
            if a == self.la_enc or a == self.ra_enc:
                # Both arc actions pop the stack top; only shift touches the buffer.
                stack.pop()
            elif a == self.sh_enc:
                stack.append(buf_pos)
                buf_pos += 1
        return action_features

    def extract_action_features_batch(self, batch):
        """Apply extract_action_features to every action sequence in `batch`."""
        return [self.extract_action_features(actions) for actions in batch]

    def static_oracle(self, gold_heads, gold_labels):
        """Find the sequence of actions that generates the tree `gold_heads`.

        Args:
            gold_heads: gold_heads[i] is the head of token i. Index 0 is the
                dummy root; its entry is expected to be -1.
            gold_labels: edge label of each position, indexed like gold_heads.

        Returns:
            (actions, action_labels): action names ('sh', 'la', 'ra', 'end')
            and, for each action, the label of the edge it creates ('<none>'
            for 'sh' and 'end').

        Raises:
            ValueError: if the input is not a valid tree, or if the tree is
                *non-projective*, i.e. cannot be drawn without crossing edges.
        """
        n_tokens = len(gold_heads)

        # Gold dependents of each head. The dummy root (position 0) has no head.
        gold_edges = defaultdict(list)
        for dep, head in enumerate(gold_heads):
            if dep > 0:
                gold_edges[head].append(dep)

        stack = [0]
        buf_pos = 1
        edges = defaultdict(list)
        actions = []
        action_labels = []
        while True:
            top = stack[-1]
            # A token may only be reduced once all its own dependents are attached.
            top_is_finished = len(edges[top]) == len(gold_edges[top])
            if buf_pos == n_tokens and len(stack) == 1:
                actions.append('end')
                action_labels.append('<none>')
                break
            if top_is_finished and top in gold_edges[buf_pos]:
                actions.append('la')
                action_labels.append(gold_labels[top])
                edges[buf_pos].append(top)
                stack.pop()
            elif top_is_finished and len(stack) > 1 and top in gold_edges[stack[-2]]:
                actions.append('ra')
                action_labels.append(gold_labels[top])
                edges[stack[-2]].append(top)
                stack.pop()
            elif buf_pos < n_tokens:
                actions.append('sh')
                action_labels.append('<none>')
                stack.append(buf_pos)
                buf_pos += 1
            else:
                # No action applies: the graph is not a tree, or is non-projective.
                raise ValueError('ill-formed tree')

        # Sanity check: the built edges must reproduce the gold tree exactly.
        heads = [-1] * n_tokens
        for head, deps in edges.items():
            for dep in deps:
                heads[dep] = head
        if list(gold_heads) != heads:
            raise ValueError('static oracle did not reproduce the gold tree')
        return actions, action_labels

    def init_parsing(self, words):
        """Create the batched parser state for the encoded sentences `words`."""
        return ArcHybridState(self, words)
Here is a small example of how we can call the static oracle to get the action sequence for a tree.
# For instance, "She lives in Gothenburg"
# Head of each position; position 0 is the dummy root, whose head is -1.
example_edges = [-1, 2, 0, 4, 2]
# Edge label of each position (the entry for the dummy root is unused).
example_labels = [None, 'subj', 'root', 'prep', 'loc']
system = ArcHybridSystem()
# Returns (actions, action_labels): the gold action sequence and the label of each action.
system.static_oracle(example_edges, example_labels)
The following class keeps track of the parser configurations for a batch of sentences. This is only used while running the parser, not while training it.
class ArcHybridState:
    """Parser configurations (stacks, buffer pointers, partial trees) for a batch.

    Only used while running the parser, not while training it. Row i of every
    tensor belongs to sentence i of the batch.
    """

    def __init__(self, system, words):
        """Initialize parsing for a batch of sentences.

        Each stack starts with just the dummy root, each buffer pointer points
        to the first word, and all trees are empty.

        Args:
            system: the transition system, providing `n_stack_features` and the
                action encodings `la_enc`, `ra_enc` and `sh_enc`.
            words: 2-D tensor (n_sentences, n_words) of encoded words. Only its
                shape, dtype and device are used here.
        """
        n_sentences, n_words = words.shape
        device = words.device
        self.n_stack_features = system.n_stack_features
        self.la_enc = system.la_enc
        self.ra_enc = system.ra_enc
        self.sh_enc = system.sh_enc
        self.n_words = n_words
        # Predicted head and label encoding of each position (0 until assigned).
        self.heads = torch.zeros_like(words)
        self.labels = torch.zeros_like(words)
        # Fixed-size stack storage. The bottom n_stack_features slots are zeros,
        # so the n_stack_features topmost slots can always be read. The initial
        # top slot (index n_stack_features-1) holds the root, position 0.
        self.stacks = torch.zeros(size=(n_sentences, n_words + self.n_stack_features),
                                  device=device, dtype=torch.long)
        # Index of the top of each stack, and position of the first buffer item.
        self.sp = torch.full((n_sentences,), self.n_stack_features - 1, dtype=torch.long, device=device)
        self.bp = torch.full((n_sentences,), 1, dtype=torch.long, device=device)
        self.rows = torch.arange(n_sentences, dtype=torch.long, device=device)
        # Column of offsets -(n-1)..0 relative to sp, selecting the topmost stack slots.
        self.ranges = torch.arange(-self.n_stack_features + 1, 1,
                                   dtype=torch.long, device=device).view(-1, 1)

    def check_conditions(self):
        """Find, for each action, the sentences where it is NOT applicable.

        - right-arc needs a non-root item on top of the stack (two stack items);
        - left-arc needs a non-root item on top of the stack, because the root
          must never become a dependent, and a buffer that is not at its last
          position;
        - shift needs a buffer that is not at its last position.

        Returns:
            (no_la_rows, no_ra_rows, no_sh_rows): 1-D tensors of row indices.
        """
        stack_has_only_root = self.sp == self.n_stack_features - 1
        buffer_at_end = self.bp == self.n_words - 1
        no_ra_rows = torch.where(stack_has_only_root)[0]
        no_sh_rows = torch.where(buffer_at_end)[0]
        # Either condition blocks a left-arc. Otherwise a left-arc could pop the
        # root and move the stack pointer below the root slot.
        no_la_rows = torch.where(stack_has_only_root | buffer_at_end)[0]
        return no_la_rows, no_ra_rows, no_sh_rows

    def extract_features(self):
        """Return the sentence positions ("features") used to predict the next action.

        The result has shape (n_sentences, n_stack_features + 1). Each row holds
        the n_stack_features topmost stack items (deepest first) followed by
        the buffer position; Kiperwasser and Goldberg use 3 stack tokens and 1
        buffer token.
        """
        stack_ix = self.sp + self.ranges  # (n_stack_features, n_sentences)
        stack_items = self.stacks[self.rows, stack_ix].t()
        return torch.cat([stack_items, self.bp.view(-1, 1)], dim=1)

    def update(self, actions, action_labels):
        """Apply one action (with its label) to every configuration in the batch.

        Rows whose action is not la/ra/sh (e.g. 'end') are left unchanged.
        """
        la_rows = torch.where(actions == self.la_enc)[0]
        ra_rows = torch.where(actions == self.ra_enc)[0]
        sh_rows = torch.where(actions == self.sh_enc)[0]
        # Index of the second stack item, taken before any pointer is modified.
        spm1 = self.sp - 1

        # Left arc: edge from the top of the buffer to the top of the stack.
        la_positions = self.stacks[la_rows, self.sp[la_rows]]
        self.heads[la_rows, la_positions] = self.bp[la_rows]
        self.labels[la_rows, la_positions] = action_labels[la_rows]
        # Pop an item from the stack.
        self.sp[la_rows] -= 1

        # Right arc: edge from the second item in the stack to the first item.
        ra_positions = self.stacks[ra_rows, self.sp[ra_rows]]
        self.heads[ra_rows, ra_positions] = self.stacks[ra_rows, spm1[ra_rows]]
        self.labels[ra_rows, ra_positions] = action_labels[ra_rows]
        # Pop an item from the stack.
        self.sp[ra_rows] -= 1

        # Shift: put the first item in the buffer on top of the stack ...
        self.sp[sh_rows] += 1
        self.stacks[sh_rows, self.sp[sh_rows]] = self.bp[sh_rows]
        # ... and move the buffer pointer one step forward.
        self.bp[sh_rows] += 1
We'll now define the neural network used in the parser by Kiperwasser and Goldberg (2016). The model consists of an encoder based on word and part-of-speech-tag embeddings, followed by a 3-layer BiLSTM. The BiLSTM outputs at the positions of the three topmost stack items and the first buffer item are concatenated and fed into two classifiers, which predict the next action and its edge label. The classifiers are simple feedforward neural networks with one hidden layer and ReLU activations.
class TransitionClassifier(nn.Module):
    """Neural model scoring parser actions and edge labels.

    A BiLSTM sentence encoder (RNNEncoder) runs over word and POS embeddings.
    The encoder states at the feature positions (stack items and the first
    buffer item) are concatenated and fed into two one-hidden-layer MLPs: one
    predicts the action, the other predicts the edge label.
    """

    def __init__(self, fields, word_emb_dim, pos_emb_dim,
                 rnn_size, rnn_depth, mlp_hidden_size,
                 n_stack_features,
                 update_pretrained=False):
        """
        Args:
            fields: list of (name, Field) pairs. Entries 0, 1, 4 and 5 are the
                word, POS, action and action-label fields.
            word_emb_dim, pos_emb_dim: embedding sizes.
            rnn_size, rnn_depth: LSTM hidden size and number of layers.
            mlp_hidden_size: hidden layer size of both classifiers.
            n_stack_features: number of stack positions used as features.
            update_pretrained: whether pre-trained word vectors are fine-tuned.
        """
        super().__init__()
        word_field = fields[0][1]
        pos_field = fields[1][1]
        action_field = fields[4][1]
        label_field = fields[5][1]
        n_actions = len(action_field.vocab)
        n_labels = len(label_field.vocab)

        # Sentence encoder module.
        self.encoder = RNNEncoder(word_field, word_emb_dim, pos_field, pos_emb_dim, rnn_size, rnn_depth,
                                  update_pretrained)

        # MLPs for classifying actions and labels. Each feature position contributes
        # one bidirectional encoder state of size 2*rnn_size.
        mlp_input_size = 2 * rnn_size * (1 + n_stack_features)
        self.action_mlp = nn.Sequential(nn.Linear(mlp_input_size, mlp_hidden_size),
                                        nn.ReLU(),
                                        nn.Linear(mlp_hidden_size, n_actions))
        self.label_mlp = nn.Sequential(nn.Linear(mlp_input_size, mlp_hidden_size),
                                       nn.ReLU(),
                                       nn.Linear(mlp_hidden_size, n_labels))

        self.pad_id = action_field.vocab.stoi[action_field.pad_token]
        self.sh_id = action_field.vocab.stoi['sh']
        self.la_id = action_field.vocab.stoi['la']
        self.ra_id = action_field.vocab.stoi['ra']

        # Loss function used during training; per-item losses are masked in compute_loss.
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')

    def word_tag_dropout(self, words, postags, p_drop):
        """Randomly replace a fraction p_drop of word and tag ids with zero.

        This is a bit hacky: it assumes id zero is the "unknown" token
        (NOTE(review): confirm against the vocabularies).
        """
        w_dropout_mask = (torch.rand(size=words.shape, device=words.device) > p_drop).long()
        p_dropout_mask = (torch.rand(size=words.shape, device=words.device) > p_drop).long()
        return words * w_dropout_mask, postags * p_dropout_mask

    def forward(self, words, postags, actions, action_labels, action_features):
        """Return the training loss: masked action loss plus masked label loss.

        Args:
            words, postags: (n_sentences, n_words) encoded tokens.
            actions, action_labels: (n_sentences, n_actions) gold ids.
            action_features: (n_sentences, n_actions, n_features) sentence positions.
        """
        if self.training:
            # Word/tag dropout is only applied while training.
            words, postags = self.word_tag_dropout(words, postags, 0.25)
        encoded = self.encoder(words, postags)
        action_input = self.encode_action_features(encoded, action_features)
        action_output = self.action_mlp(action_input)
        label_output = self.label_mlp(action_input)
        return self.compute_loss(actions, action_output) + self.compute_loss(action_labels, label_output)

    def compute_loss(self, actions, action_output):
        """Mean cross-entropy over the targets that are not `self.pad_id`.

        The same id also masks label targets. This presumably relies on the
        action and label fields sharing the pad token '<none>' at the same
        vocabulary index (TODO confirm). The gold label of 'sh' and 'end' steps
        is '<none>', so those steps do not contribute to the label loss.
        """
        targets = actions.flatten()
        pad_mask = (targets != self.pad_id).float()
        per_item_loss = self.loss(action_output, targets)
        # Clamp the denominator so that an all-padding batch gives 0 instead of NaN.
        return per_item_loss.dot(pad_mask) / pad_mask.sum().clamp(min=1.0)

    def encode_action_features(self, encoded, action_features):
        """Gather and concatenate encoder states at the feature positions.

        encoded: (n_sentences, n_words, dim); action_features: (n_sentences,
        n_actions, n_features). Returns (n_sentences * n_actions, n_features * dim).
        """
        n_sentences, n_actions, _ = action_features.shape
        row_ix = torch.arange(n_sentences, device=action_features.device, dtype=torch.long)
        # (n_features, n_actions, n_sentences): broadcasts against row_ix in the last dim.
        a_ix = action_features.transpose(0, 2)
        out = encoded[row_ix, a_ix].transpose(0, 2)
        return out.reshape(n_sentences * n_actions, -1)

    def predict(self, encoded, action_features, no_la_rows, no_ra_rows, no_sh_rows):
        """Predict one action and one label per sentence, masking inapplicable actions.

        Args:
            encoded: encoder output for the batch.
            action_features: (n_sentences, n_features) current feature positions.
            no_la_rows, no_ra_rows, no_sh_rows: rows where left-arc, right-arc
                and shift are not allowed.

        Returns:
            (actions, labels): 1-D tensors of predicted ids.
        """
        n_sentences, n_features = action_features.shape
        action_features = action_features.view(n_sentences, 1, n_features)
        action_input = self.encode_action_features(encoded, action_features)
        action_output = self.action_mlp(action_input)
        label_output = self.label_mlp(action_input)
        # Padding is never a valid action. ArcHybridState.update would ignore it,
        # so it could keep a sentence from ever reaching 'end'.
        action_output[:, self.pad_id] = -np.inf
        action_output[no_sh_rows, self.sh_id] = -np.inf
        action_output[no_la_rows, self.la_id] = -np.inf
        action_output[no_ra_rows, self.ra_id] = -np.inf
        return action_output.argmax(dim=1), label_output.argmax(dim=1)
And here is the sentence encoding part. This is a straightforward application of techniques we've seen in the past, with the small twist that we're using embeddings not only for the words but also the part-of-speech tags.
class RNNEncoder(nn.Module):
    """Sentence encoder: word and POS-tag embeddings followed by a BiLSTM.

    Each output position concatenates the forward and backward LSTM states,
    so its size is 2 * rnn_size.
    """

    def __init__(self, word_field, word_emb_dim, pos_field, pos_emb_dim, rnn_size, rnn_depth, update_pretrained):
        super().__init__()
        vocab_size = len(word_field.vocab)
        self.word_embedding = nn.Embedding(vocab_size, word_emb_dim)
        pretrained = word_field.vocab.vectors
        if pretrained is not None:
            # Replace the random table with the pre-trained vectors; they are
            # frozen unless update_pretrained is True.
            self.word_embedding.weight = nn.Parameter(pretrained, requires_grad=update_pretrained)
        self.pos_embedding = nn.Embedding(len(pos_field.vocab), pos_emb_dim)
        self.rnn = nn.LSTM(
            input_size=word_emb_dim + pos_emb_dim,
            hidden_size=rnn_size,
            num_layers=rnn_depth,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, words, postags):
        """Encode batch-first word and tag ids into BiLSTM states of size 2*rnn_size."""
        embedded = torch.cat((self.word_embedding(words), self.pos_embedding(postags)), dim=2)
        states, _ = self.rnn(embedded)
        return states
As usual, we build a `train` method that loads the dataset, creates a model, goes through the training loop, and prints some diagnostics at the end. This is similar to our previous examples, so we'll leave it without comment. The main difference is that there's a bit of additional overhead to compute the action sequences and to extract the relevant sentence positions.
While training, we print the unlabeled attachment score (UAS) and labeled attachment score (LAS) evaluated on the validation set. We usually reach UAS levels of about 0.88 when we use this English dataset, and LAS scores a bit lower.
At the end of the class, there are some auxiliary methods that call the action classifier step by step and update the parser configurations depending on the selected actions.
class DependencyParser:
    """Trains the transition-based parser and applies it to new sentences.

    Holds the torchtext fields, the arc-hybrid transition system, the neural
    classifier (`self.model`, created in `train`) and, while parsing, the
    batched parser configurations (`self.state`).
    """

    def __init__(self, lower=False, device=None):
        """
        Args:
            lower: whether the word field lowercases its input.
            device: torch device name. Defaults to 'cuda' when a GPU is
                available and 'cpu' otherwise.
        """
        bos = '<bos>'
        eos = '<eos>'
        none = '<none>'
        self.WORD = torchtext.data.Field(init_token=bos, eos_token=eos, sequential=True,
                                         lower=lower, batch_first=True)
        self.POS = torchtext.data.Field(init_token=bos, eos_token=eos, sequential=True,
                                        batch_first=True)
        self.HEAD = torchtext.data.Field(pad_token=-1, eos_token=-1, use_vocab=False,
                                         sequential=True, batch_first=True)
        self.LABEL = torchtext.data.Field(pad_token=none, sequential=True,
                                          unk_token=None, batch_first=True)
        self.ACTION = torchtext.data.Field(pad_token=none, unk_token=None, sequential=True, batch_first=True)
        self.fields = [('words', self.WORD), ('postags', self.POS),
                       ('heads', self.HEAD), ('labels', self.LABEL),
                       ('actions', self.ACTION), ('action_labels', self.LABEL)]
        self.transition_system = ArcHybridSystem(n_stack_features=3)
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

    def _compute_action_features(self, batches):
        """For each batch, return a tensor with the feature positions of every gold action."""
        features = []
        for batch in batches:
            batch_actions = batch.actions.cpu().numpy()
            action_features = self.transition_system.extract_action_features_batch(batch_actions)
            features.append(torch.as_tensor(action_features, device=self.device))
        return features

    def train(self):
        """Read the data, build the vocabularies and model, and run the training loop.

        After every epoch, prints the training and validation loss and the
        validation UAS and LAS. At the end, plots the loss and UAS curves.
        """
        torch.manual_seed(1234)
        random.seed(1234)

        # Read training and validation data according to the predefined split.
        train_examples = read_data('data/en_ewt-ud-train.conllu', self.fields, self.transition_system, False)
        val_examples = read_data('data/en_ewt-ud-dev.conllu', self.fields, self.transition_system, True)

        self.POS.build_vocab(train_examples)
        self.ACTION.build_vocab(train_examples)
        self.LABEL.build_vocab(train_examples)
        self.transition_system.set_vocab(self.ACTION.vocab)

        # Load the pre-trained word embeddings that come with the torchtext library.
        use_pretrained = True
        if use_pretrained:
            print('We are using pre-trained word embeddings.')
            self.WORD.build_vocab(train_examples, vectors="glove.840B.300d")
        else:
            print('We are training word embeddings from scratch.')
            self.WORD.build_vocab(train_examples, max_size=10000)

        self.model = TransitionClassifier(self.fields, word_emb_dim=300, pos_emb_dim=32,
                                          rnn_size=256, rnn_depth=3, mlp_hidden_size=256,
                                          n_stack_features=self.transition_system.n_stack_features,
                                          update_pretrained=False)
        self.model.to(self.device)

        train_iterator = torchtext.data.BucketIterator(
            train_examples,
            device=self.device,
            batch_size=64,
            sort_key=lambda x: len(x.words),
            repeat=False,
            train=True,
            sort=True)
        val_iterator = torchtext.data.BucketIterator(
            val_examples,
            device=self.device,
            batch_size=512,
            sort_key=lambda x: len(x.words),
            repeat=False,
            train=True,
            sort=True)

        train_batches = list(train_iterator)
        val_batches = list(val_iterator)
        # The gold feature positions do not change between epochs, so compute them once.
        train_batches = list(zip(train_batches, self._compute_action_features(train_batches)))
        val_batches = list(zip(val_batches, self._compute_action_features(val_batches)))

        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0025, weight_decay=1e-5)

        history = defaultdict(list)
        n_epochs = 25

        for i in range(1, n_epochs + 1):
            t0 = time.time()
            stats = Counter()

            random.shuffle(train_batches)
            self.model.train()
            for batch, batch_action_features in train_batches:
                loss = self.model(batch.words, batch.postags, batch.actions, batch.action_labels, batch_action_features)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                stats['train_loss'] += loss.item()
                if self.device == 'cpu':
                    print('.', end='')
                    sys.stdout.flush()
            if self.device == 'cpu':
                print()

            train_loss = stats['train_loss'] / len(train_batches)
            history['train_loss'].append(train_loss)

            self.model.eval()
            with torch.no_grad():
                for batch, batch_action_features in val_batches:
                    loss = self.model(batch.words, batch.postags, batch.actions, batch.action_labels, batch_action_features)
                    stats['val_loss'] += loss.item()
                    predicted_heads, predicted_labels = self.parse_batch(batch)
                    n_tokens, n_corr_u, n_corr_l = self.evaluate(batch.heads, batch.labels, predicted_heads, predicted_labels)
                    stats['val_n_tokens'] += n_tokens
                    stats['val_n_corr_u'] += n_corr_u
                    stats['val_n_corr_l'] += n_corr_l
            elapsed = time.time() - t0

            val_loss = stats['val_loss'] / len(val_batches)
            history['val_loss'].append(val_loss)
            uas = stats['val_n_corr_u'] / stats['val_n_tokens']
            las = stats['val_n_corr_l'] / stats['val_n_tokens']
            history['uas'].append(uas)
            history['las'].append(las)
            print(f'Epoch {i:2}: train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, UAS: {uas:.4f}, LAS: {las:.4f}, time: {elapsed:.4f}')

        plt.plot(history['train_loss'])
        plt.plot(history['val_loss'])
        plt.plot(history['uas'])
        plt.legend(['training loss', 'validation loss', 'UAS'])

    def evaluate(self, gold_heads, gold_labels, predicted_heads, predicted_labels):
        """Return (n_tokens, n_correct_heads, n_correct_heads_and_labels) for a batch.

        Positions whose gold head is -1 (root, end-of-sentence and padding) are
        not counted. The UAS is n_correct_heads / n_tokens, and the LAS is
        n_correct_heads_and_labels / n_tokens.
        """
        pad_mask = (gold_heads != -1)
        head_ok = (gold_heads == predicted_heads)
        label_ok = (gold_labels == predicted_labels)
        n_corr_u = (pad_mask & head_ok).sum().item()
        n_corr_l = (pad_mask & head_ok & label_ok).sum().item()
        n_tokens = pad_mask.sum().item()
        return n_tokens, n_corr_u, n_corr_l

    def sanity_check(self, actions):
        """Debugging helper: detect a shift chosen when the buffer is at its last position.

        If such a row exists, prints the buffer pointers and the actions and
        returns True; otherwise returns False.
        """
        shift_bug = (self.state.bp >= self.state.n_words - 1) & (actions == self.transition_system.sh_enc)
        if shift_bug.any():
            print(self.state.bp)
            print(actions)
            return True
        return False

    def step(self):
        """Carry out one parsing step: call the action classifier, then update the configurations."""
        # Inference only: avoid building autograd graphs when called outside no_grad.
        with torch.no_grad():
            action_features = self.state.extract_features()
            no_la, no_ra, no_sh = self.state.check_conditions()
            actions, labels = self.model.predict(self.state.encoded, action_features,
                                                 no_la, no_ra, no_sh)
            self.state.update(actions, labels)
        return actions, labels

    def steps(self, max_steps=None):
        """Carry out parsing steps until every sentence in the batch selects 'end'.

        Args:
            max_steps: safety limit on the number of steps. A complete parse
                needs roughly one shift and one reduction per word position plus
                a final 'end'. The default is a generous bound based on the
                padded sentence length. It prevents an endless loop if the
                classifier never selects 'end' for some sentence.
        """
        end_action_id = self.ACTION.vocab.stoi['end']
        if max_steps is None:
            max_steps = 4 * self.state.n_words + 10
        for _ in range(max_steps):
            actions, _ = self.step()
            if (actions != end_action_id).sum().item() == 0:
                return

    def init_batch(self, batch):
        """Initialize the parser state for a batch and encode its sentences."""
        self.state = self.transition_system.init_parsing(batch.words)
        self.state.encoded = self.model.encoder(batch.words, batch.postags)

    def parse_batch(self, batch):
        """Initialize and parse a batch of sentences; return (heads, labels) tensors."""
        self.init_batch(batch)
        self.steps()
        return self.state.heads, self.state.labels

    def init(self, sentences):
        """Initialize parsing for sentences given as lists of (word, tag) pairs."""
        examples = []
        for tagged_words in sentences:
            words = [w for w, _ in tagged_words]
            tags = [t for _, t in tagged_words]
            examples.append(torchtext.data.Example.fromlist([words, tags, [], [], [], []], self.fields))
        dataset = torchtext.data.Dataset(examples, self.fields)

        iterator = torchtext.data.Iterator(
            dataset,
            device=self.device,
            batch_size=len(examples),
            repeat=False,
            train=False,
            sort=False)

        self.model.eval()
        with torch.no_grad():
            for batch in iterator:
                self.init_batch(batch)
        # Hack: initialize each position's head to itself, only to make the
        # visualization of partial trees a bit nicer.
        _, n_heads = self.state.heads.shape
        self.state.heads += torch.arange(n_heads, dtype=torch.long, device=self.device)

    def get_heads_and_labels(self):
        """Return the predicted heads (numpy array) and label strings for each sentence."""
        heads = self.state.heads.cpu().numpy()
        label_enc = self.state.labels.cpu().numpy()
        labels = [[self.LABEL.vocab.itos[l] for l in row] for row in label_enc]
        return heads, labels

    def parse(self, sentences):
        """Parse sentences given as lists of (word, tag) pairs; return (heads, labels)."""
        self.init(sentences)
        self.steps()
        return self.get_heads_and_labels()
# Create the parser and train it on the UD English EWT files under data/.
# The demo code below uses `parser` as a global variable.
parser = DependencyParser()
parser.train()
The following interactive demo requires that you have NLTK and graphviz installed, as in the notebook from last week.
import nltk
# Download the tokenizer and part-of-speech tagger models if you haven't done it before.
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')
import warnings
# Suppress all warnings from here on.
warnings.filterwarnings('ignore')
# Put the directory where 'dot' is located first in the PATH.
# NOTE: this directory is specific to the author's machine; adjust it, or drop the
# line if graphviz's `dot` is already on your PATH.
import os
os.environ['PATH'] = '/opt/miniconda3/bin:' + os.environ['PATH']
The following class is used for the interactive demo.
class ParserDemo:
    """Step-by-step interactive demo of the trained global `parser`.

    The sentence is tokenized and tagged with NLTK. Each call to `step` runs
    one parsing action and prints the new stack and buffer position of the
    first sentence.
    """

    def __init__(self, sentence):
        tokens = nltk.word_tokenize(sentence)
        self.tokenized = tokens
        self.tagged = nltk.pos_tag(tokens)
        parser.init([self.tagged])
        self.show_state()

    def show_state(self):
        """Print the stack contents and the buffer position of the first sentence."""
        state = parser.state
        top_index = state.sp[0].item()
        stack_items = state.stacks[0, :top_index + 1].cpu().numpy()
        print('Stack:', stack_items)
        print('Buffer position:', state.bp[0].item())

    def step(self):
        """Run one parsing step and print the chosen action (plus its label for arcs)."""
        action_ids, label_ids = parser.step()
        action_names = [parser.ACTION.vocab.itos[a] for a in action_ids.cpu().numpy()]
        label_names = [parser.LABEL.vocab.itos[l] for l in label_ids.cpu().numpy()]
        chosen = action_names[0]
        if chosen in ['la', 'ra']:
            print(f'Selected action and label: {chosen} {label_names[0]}')
        else:
            print('Selected action:', chosen)
        self.show_state()

    def draw_tree(self):
        """Return an NLTK DependencyGraph of the current (possibly partial) tree."""
        heads, labels = parser.get_heads_and_labels()
        token_rows = zip(self.tokenized, heads[0][1:], labels[0][1:])
        conll_lines = [f'{word} _ {head} {label}' for word, head, label in token_rows]
        return nltk.DependencyGraph('\n'.join(conll_lines))
# Start the interactive demo; this prints the initial stack and buffer position.
demo = ParserDemo('She lives in a house in my garden')
# Draw the tree before any parsing step has been taken.
demo.draw_tree()
If you run the following cell repeatedly, you'll see how the parser selects actions and updates the configuration and edges. Eventually, you'll get a complete tree and the parser will go into the end state.
# Run one parsing step (prints the selected action and the new configuration),
# then draw the tree built so far.
step = demo.step()
demo.draw_tree()
Here is an example of how we can tokenize and tag a sentence using NLTK:
# Tokenize, then tag: the result pairs each token with its part-of-speech tag.
nltk.pos_tag(nltk.word_tokenize('The big dog lives in its little house.'))
We make a utility function that calls NLTK to tokenize and part-of-speech-tag the sentence, and then calls the parser we trained above to find the head position and edge label of each word. We then print the position, word, tag, head position and edge label of each word as separate columns.
def parse_sentence(sentence):
    """Tokenize, tag and parse `sentence` with the global `parser`.

    Prints one line per token: position, word, tag, predicted head position
    and edge label.
    """
    tagged = nltk.pos_tag(nltk.word_tokenize(sentence))
    all_heads, all_labels = parser.parse([tagged])
    # Skip position 0, which belongs to the root/beginning-of-sentence token.
    token_heads = all_heads[0][1:]
    token_labels = all_labels[0][1:]
    rows = zip(tagged, token_heads, token_labels)
    for position, ((word, tag), head, label) in enumerate(rows, start=1):
        print(f'{position:2} {word:10} {tag:4} {head} {label}')
Here is the result of parsing an example sentence.
# Parse the example sentence and print one line per token.
parse_sentence('The big dog lives in its little house.')
It's probably a bit more understandable to look at the parse trees visually than the column-based format above. NLTK also includes functionality to draw trees.
Drawing the trees in a notebook requires the Graphviz utilities (in particular the `dot` program) to be installed, which you can do with pip or conda.
When running this example, I had to add the location of the `dot` utility to the `PATH`. This might or might not be necessary in your case.
def draw_sentence(sentence):
    """Parse `sentence` with the global `parser` and return an NLTK DependencyGraph for drawing."""
    tokens = nltk.word_tokenize(sentence)
    heads, labels = parser.parse([nltk.pos_tag(tokens)])
    conll_lines = []
    # Skip position 0 (the root/beginning-of-sentence token) in the parser output.
    for word, head, label in zip(tokens, heads[0][1:], labels[0][1:]):
        conll_lines.append(f'{word} _ {head} {label}')
    return nltk.DependencyGraph('\n'.join(conll_lines))
We can now draw a tree for the example we saw above.
# Draw the tree for the example sentence above, and for a sentence with quotation marks.
draw_sentence('The big dog lives in its little house.')
draw_sentence('"You\'ve made a big mess," said my mother.')