Training word embeddings with the SGNS algorithm

In this notebook, we'll see a PyTorch implementation of a well-known training algorithm for word embeddings: Mikolov's skip-gram with negative sampling (SGNS).

Please note that the example is somewhat incomplete: a realistic implementation would also save the embeddings when training is finished. Here, we'll just print the similarities to some test words.

In [1]:
import torch
import torch.nn as nn

import numpy as np

import sys, time, os
from collections import Counter

Preliminaries: building the vocabulary and negative sampling table

We will first make a function that goes through the training corpus and finds the most frequent words, which will be used for the vocabulary. A special dummy token will be used as a stand-in for the words that are less frequent.

In addition, we will create the table that will be used for negative sampling. Each word will be sampled with a probability that is proportional to its frequency to the power of a constant (called ns_exp here). Sampling words randomly can be a bit tricky to implement efficiently, and we'll use a trick that was used in the word2vec software: we'll make a large array where each word will occur a number of times that is roughly proportional to its probability.
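To get a feel for why this works, here is a toy sketch (with a made-up frequency dictionary, not part of the implementation below): if each word is given a number of slots roughly proportional to its smoothed frequency, drawing a uniform random index from the expanded array approximates sampling from the smoothed distribution.

import random

# Made-up frequencies, just for illustration.
toy_freqs = {'the': 100, 'cat': 10, 'aardvark': 1}
toy_ns_exp = 0.75

# Reserve a number of slots per word, roughly proportional to freq ** ns_exp.
weights = {w: f ** toy_ns_exp for w, f in toy_freqs.items()}
scale = 1000 / sum(weights.values())   # aim for a table of about 1000 entries
table = [w for w, wt in weights.items() for _ in range(int(round(wt * scale)))]

# A uniform draw from the table now approximates the smoothed distribution.
print([random.choice(table) for _ in range(5)])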

In [2]:
def make_ns_table(params):
    corpus = params['corpus']
    voc_size = params['voc-size']
    ns_table_size = params['ns-table-size']
    unk_str = params['unknown-str']
    lowercase = params['lowercase']
    ns_exp = params['ns-exp']

    # This is what we'll use to store the frequencies.
    freqs = Counter()

    print('Building vocabulary and sampling table...')    

    # First, build a full frequency table from the whole corpus.
    with open(corpus) as f:
        for i, line in enumerate(f, 1):
            if lowercase:
                line = line.lower()
            freqs.update(line.split())
            if i % 50000 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            if i % 1000000 == 0:
                sys.stdout.write(' ')
                sys.stdout.write(str(i))
                sys.stdout.write('\n')
                sys.stdout.flush()
    print()

    # Sort the frequencies, then select the most frequent words as the vocabulary.
    freqs_sorted = sorted(freqs.items(),
                          key=lambda p: (p[1], p[0]),
                          reverse=True)
    if len(freqs_sorted) > voc_size-1:
        sum_freq_pruned = sum(f for _, f in freqs_sorted[voc_size-1:])
    else:
        sum_freq_pruned = 1

    # We'll add a special dummy to represent the occurrences of low-frequency words.
    freqs_sorted = [(unk_str, sum_freq_pruned)] + freqs_sorted[:voc_size-1]

    # Now, we'll compute the negative sampling table.
    # The negative sampling probabilities are proportional to the frequencies
    # to the power of a constant (typically 0.75).
    ns_table = {}
    sum_freq = 0
    for w, freq in freqs_sorted:
        ns_freq = freq ** ns_exp
        ns_table[w] = ns_freq
        sum_freq += ns_freq

    # Convert the negative sampling probabilities to integers, in order to make
    # sampling a bit faster and easier.
    # We return a list of tuples consisting of:
    # - the word
    # - its frequency in the training data
    # - the number of positions reserved for this word in the negative sampling table
    scaler = ns_table_size / sum_freq
    return [(w, freq, int(round(ns_table[w]*scaler))) for w, freq in freqs_sorted]

And then two utility functions to load and save the negative sampling table.

In [3]:
def load_ns_table(filename):
    with open(filename) as f:
        out = []
        for l in f:
            t = l.split()
            out.append((t[0], int(t[1]), int(t[2])))
        return out

def save_ns_table(table, filename):
    with open(filename, 'w') as f:
        for w, fr, ns in table:
            print(f'{w} {fr} {ns}', file=f)

Generating target–context pairs

The following class goes through the training file line by line and generates positive training instances (pairs consisting of a target word and a context word). Here, we apply the preprocessing tricks described in Mikolov's paper, such as subsampling of frequent words and randomly sized context windows.
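For reference, the subsampling of frequent words follows the formula from Mikolov's paper: an occurrence of a word $w$ is discarded with probability

$P_{\mathrm{discard}}(w) = 1 - \sqrt{t / f(w)}$

where $f(w)$ is the relative frequency of $w$ in the corpus and $t$ is the prune-threshold parameter. Frequent words are thus dropped aggressively, while words rarer than the threshold get a negative value here and are always kept.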

The batches method will generate one batch at a time, containing a number of positive training instances coded as integers. The negative training instances will be created elsewhere.

In [4]:
class SGNSContextGenerator:

    def __init__(self, ns_table, params):

        # The name of the training file.
        self.corpus = params['corpus']
        
        # The string-to-integer mapping for the vocabulary.
        self.voc = { w:i for i, (w, _, _ ) in enumerate(ns_table) }

        # The number of positive instances we'll create in each batch.
        self.batch_size = params['batch-size']

        # The maximal width of the context window.
        self.ctx_width = params['context-width']

        # Whether we should lowercase the text.
        self.lowercase = params['lowercase']
        
        self.word_count = 0
        
        # We define the pruning probabilities for each word as in Mikolov's paper.
        total_freq = sum(f for _, f, _ in ns_table)
        self.prune_probs = {}
        for w, f, _ in ns_table:
            self.prune_probs[w] = 1 - np.sqrt(params['prune-threshold'] * total_freq / f)

    def prune(self, tokens):
        ps = np.random.random(size=len(tokens))
        # Remove some words from the input with probabilities defined by their frequencies.
        return [ w for w, p in zip(tokens, ps) if p >= self.prune_probs.get(w, 0) ]

    def batches(self):

        widths = np.random.randint(1, self.ctx_width+1, size=self.batch_size)
        width_ix = 0

        self.word_count = 0
        
        with open(self.corpus) as f:
            out_t = []
            out_c = []
            for line in f:

                # Process one line: lowercase and split into tokens.
                if self.lowercase:
                    line = line.lower()
                tokens = line.split()
                self.word_count += len(tokens)

                # Remove some words, then encode as integers.
                encoded = [ self.voc.get(t, 0) for t in self.prune(tokens) ]

                for i, t in enumerate(encoded):

                    # The context width is selected uniformly between 1 and the maximal width.
                    w = widths[width_ix]
                    width_ix += 1

                    # Compute start and end positions for the context.
                    start = max(0, i-w)
                    end = min(i+w+1, len(encoded))

                    # Finally, generate target--context pairs.
                    for j in range(start, end):
                        if j != i:
                            out_t.append(encoded[i])
                            out_c.append(encoded[j])
                            
                            # If we've generated enough pairs, yield a batch.
                            # Each batch is a list of targets and a list of corresponding contexts.
                            if len(out_t) == self.batch_size:
                                yield out_t, out_c
                                
                                # After resuming from the yield, reset the batch and redraw the context widths.
                                widths = np.random.randint(1, self.ctx_width+1, size=self.batch_size)
                                width_ix = 0
                                out_t = []
                                out_c = []
                    
            print('End of file.')
            if len(out_t) > 0:
                # Yield the final batch.
                yield out_t, out_c

Defining the model

Next, we implement the neural network that defines the model. The parameters just consist of two sets of embeddings: one for the target words, and one for the contexts.

The forward step is fairly trivial: we just compute the dot products of the target and context embeddings. As usual, the most annoying part is to keep track of the tensor shapes.

We also add a couple of methods that allow us to inspect the model: computing the cosine similarity between the embeddings for two words, and finding the nearest neighbor lists of a set of words.
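To make the shape bookkeeping concrete, here is a small standalone sketch (with made-up sizes; it does not use the class below) of the batch matrix multiplication that produces one dot product per context word:

import torch

n_batch, n_ctx, emb_dim = 4, 3, 8                       # made-up sizes
tgt_emb = torch.randn(n_batch, emb_dim)                 # (4, 8)
ctx_emb = torch.randn(n_batch, n_ctx, emb_dim)          # (4, 3, 8)

dots = tgt_emb.view(n_batch, 1, emb_dim).bmm(ctx_emb.transpose(1, 2))
print(dots.shape)                                       # torch.Size([4, 1, 3])
print(dots.view(n_batch, n_ctx).shape)                  # torch.Size([4, 3])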

In [5]:
class SGNSModel(nn.Module):

    def __init__(self, voc, params):
        super().__init__()
        
        voc_size = len(voc)
        
        # Target word embeddings
        self.w = nn.Embedding(voc_size, params['emb-dim'])
        # Context embeddings
        self.c = nn.Embedding(voc_size, params['emb-dim'])
        
        # Some things we need in order to print nearest-neighbor lists for diagnostics.
        self.voc = voc
        self.ivoc = { i:w for w, i in voc.items() }

    def forward(self, tgt, ctx):       
        # tgt is a 1-dimensional tensor containing target word ids
        # ctx is a 2-dimensional tensor containing positive and negative context ids for each target
        
        # Look up the embeddings for the target words.
        # shape: (batch size, embedding dimension)
        tgt_emb = self.w(tgt)
        
        n_batch, emb_dim = tgt_emb.shape
        n_ctx = ctx.shape[1]
        
        # View this as a 3-dimensional tensor, with
        # shape (batch size, 1, embedding dimension)
        tgt_emb = tgt_emb.view(n_batch, 1, emb_dim)

        # Look up the embeddings for the positive and negative context words.
        # shape: (batch size, nbr contexts, emb dim)
        ctx_emb = self.c(ctx)

        # Transpose the tensor for matrix multiplication
        # shape: (batch size, emb dim, nbr contexts)
        ctx_emb = ctx_emb.transpose(1, 2)

        # Compute the dot products between target word embeddings and context
        # embeddings. We express this as a batch matrix multiplication (bmm).
        # shape: (batch size, 1, nbr contexts)
        dots = tgt_emb.bmm(ctx_emb)

        # View this result as a 2-dimensional tensor.
        # shape: (batch size, nbr contexts)
        dots = dots.view(n_batch, n_ctx)

        return dots
    
    
    def nearest_neighbors(self, words, n_neighbors):
        
        # Encode the words as integers, and put them into a PyTorch tensor.
        words_ix = torch.as_tensor([self.voc[w] for w in words])
        
        # Look up the embeddings for the test words.
        voc_size, emb_dim = self.w.weight.shape
        test_emb = self.w(words_ix).view(len(words), 1, emb_dim)

        # Also, get the embeddings for all words in the vocabulary.
        all_emb = self.w.weight.view(1, voc_size, emb_dim)

        # We'll use a cosine similarity function to find the most similar words.
        # The .view kludgery above is needed for the batch-wise cosine similarity.
        sim_func = nn.CosineSimilarity(dim=2)
        scores = sim_func(test_emb, all_emb)
        # The shape of scores is (nbr of test words, total number of words)
                
        # Find the top-scoring columns in each row. We ask for one extra
        # neighbor, since the most similar word will be the word itself.
        near_nbr = scores.topk(n_neighbors+1, dim=1)
        values = near_nbr.values[:,1:]
        indices = near_nbr.indices[:, 1:]
        
        # Finally, map word indices back to strings, and put the result in a list.
        out = []
        for ixs, vals in zip(indices, values):
            out.append([ (self.ivoc[ix.item()], val.item()) for ix, val in zip(ixs, vals) ])
        return out
        
        
    def cosine_similarity(self, word1, word2):        
        # We just look up the two embeddings and use PyTorch's built-in cosine similarity.
        v1 = self.w(torch.as_tensor(self.voc[word1]))
        v2 = self.w(torch.as_tensor(self.voc[word2]))
        sim = nn.CosineSimilarity(dim=0)
        return sim(v1, v2).item()

Training

The following class contains the training loop: it creates a batch of positive target–context pairs, generates negative samples, and then updates the embedding model.
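As a reminder, for a positive pair $(t, c^{+})$ and $k$ negative samples $c^{-}_1, \dots, c^{-}_k$, the skip-gram negative sampling objective is

$-\log \sigma(\mathbf{w}_t \cdot \mathbf{c}^{+}) - \sum_{i=1}^{k} \log \sigma(-\mathbf{w}_t \cdot \mathbf{c}^{-}_i)$

which is exactly what a binary cross-entropy loss computes (up to averaging over the batch) when the dot product with the true context is labeled 1 and the dot products with the sampled contexts are labeled 0. The "gold standard" tensor built below encodes precisely this labeling.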

In [6]:
class SGNSTrainer:

    def __init__(self, instance_gen, model, ns_table, params):
        self.instance_gen = instance_gen
        self.model = model
        self.n_epochs = params['n-epochs']
        self.max_words = params.get('max-words')
        n_batch = params['batch-size']
        self.n_ns = params['n-neg-samples']

        if params['optimizer'] == 'adam':
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=params['lr'])
        elif params['optimizer'] == 'sgd':
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=params['lr'])

        # We'll use a binary cross-entropy loss, since we have a binary classification problem:
        # distinguishing positive from negative contexts.
        self.loss = nn.BCEWithLogitsLoss()

        # Build the negative sampling table.
        ns_table_expanded = []
        for i, (_, _, count) in enumerate(ns_table):
            ns_table_expanded.extend([i] * count)
        self.ns_table = torch.as_tensor(ns_table_expanded)
        
        # Define the "gold standard" that we'll use to compute the loss.
        # It consists of a column of ones, and then a number of columns of zeros.
        # This structure corresponds to the positive and negative contexts, respectively.
        y_pos = torch.ones((n_batch, 1))
        y_neg = torch.zeros((n_batch, self.n_ns))
        self.y = torch.cat([y_pos, y_neg], dim=1)

        # Some things we need in order to print nearest-neighbor lists for diagnostics.
        self.testwords = params['testwords']
        self.n_testwords_neighbors = params['n-testwords-neighbors']

        self.epoch = 0
        
    def print_test_nearest_neighbors(self):
                
        nn_lists = self.model.nearest_neighbors(self.testwords, self.n_testwords_neighbors)
        
        # For each test word, print the most similar words.
        for w, nn_list in zip(self.testwords, nn_lists):
            print(w, end=':\n')
            for nn, sim in nn_list:
                print(f' {nn} ({sim:.3f})', end='')
            print()
        
        print('------------------------------------')
        
    def make_negative_sample(self, batch_size):
        neg_sample_ixs = torch.randint(len(self.ns_table), (batch_size, self.n_ns))
        return self.ns_table.take(neg_sample_ixs)
            
    def train(self):

        print_interval = 5000000
        
        while self.epoch < self.n_epochs:
            print(f'Epoch {self.epoch+1}.')

            # For diagnostics.
            n_pairs = 0
            sum_loss = 0
            total_pairs = 0
            n_batches = 0
            t0 = time.time()
            
            for t, c_pos in self.instance_gen.batches():

                batch_size = len(t)
                
                # Put the encoded target words and contexts into PyTorch tensors.
                t = torch.as_tensor(t)                
                c_pos = torch.as_tensor(c_pos)
                c_pos = c_pos.view(batch_size, 1)
                
                # Generate a sample of fake context words.
                # shape: (batch size, number of negative samples)
                c_neg = self.make_negative_sample(batch_size)
                
                # Combine positive and negative contexts.
                # shape: (batch size, 1 + nbr neg samples)
                c = torch.cat([c_pos, c_neg], dim=1)
                
                self.optimizer.zero_grad()

                # Compute the output from the model.
                # That is, the dot products between target embeddings
                # and context embeddings.
                scores = self.model(t, c)

                # Compute the loss with respect to the gold standard.
                loss = self.loss(scores, self.y[:batch_size])

                # Compute gradients and update the embeddings.
                loss.backward()
                self.optimizer.step()

                # We'll print some diagnostics periodically.
                sum_loss += loss.item()
                n_pairs += batch_size
                n_batches += 1
                if n_pairs > print_interval:
                    total_words = self.instance_gen.word_count
                    total_pairs += n_pairs
                    t1 = time.time()                    
                    print(f'Pairs: {total_pairs}, words: {total_words}, loss: {sum_loss / n_batches:.4f}, time: {t1-t0:.2f}')
                    self.print_test_nearest_neighbors()
                    if self.max_words and total_words > self.max_words:
                        break
                    n_pairs = 0
                    sum_loss = 0
                    n_batches = 0
                    t0 = time.time()
                    
            self.epoch += 1

Putting all the pieces together

Now, we have all the pieces that we need to train the model. The following code just calls the other functions that we developed above. It also contains the parameters that control the program's behavior.

To keep things fast, we'll just train on the first 50 million words. In a realistic implementation, we'd probably use a larger dataset and also run for several epochs.

When we run this code, we will see that the similarity lists for the test words gradually start to make sense. After processing 50 million words, most of the lists should look sensible. The quality will improve further if we use more training data.

In [7]:
model = None

def main():
    global model
    params = {
        'corpus': '../data/wikipedia-2009-subset.txt', # Training data file
        'device': 'cuda', # Device

        'n-neg-samples': 5, # Number of negative samples per positive sample
        'emb-dim': 64, # Embedding dimensionality
        
        'n-epochs': 1, # Number of epochs
        'max-words': 50000000, # How many words to consider in one epoch
        
        'batch-size': 1<<20, # Number of positive training instances in one batch
        'context-width': 5, # Maximal possible context width
        'prune-threshold': 1e-3, # Pruning threshold (see Mikolov's paper)
        'voc-size': 100000, # Maximal vocabulary size
        'ns-table-file': 'ns_table.txt', # Where to store the negative sampling table
        'ns-table-size': 1<<24, # Size of negative sampling table
        'ns-exp': 0.75, # Smoothing parameter for negative sampling distribution (see paper)
        'unknown-str': '<UNKNOWN>', # Dummy token for low-frequency words
        'lowercase': True, # Whether to lowercase the text
        'optimizer': 'adam', # Which gradient descent optimizer to use
        'lr': 1e-1, # Learning rate for the optimizer

        # The test words for which we print the nearest neighbors periodically
        'testwords': ['apple', 'terrible', 'sweden', '1979', 'write', 'gothenburg'],
        # Number of nearest neighbors
        'n-testwords-neighbors': 5,
    }
    
    if params['device'] == 'cuda' and torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        print('Running on CUDA device.')
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
        print('Running on CPU.')

    # If we didn't already create the vocabulary and negative 
    # sampling table, we'll do that now.
    if os.path.exists(params['ns-table-file']):
        ns_table = load_ns_table(params['ns-table-file'])
    else:
        ns_table = make_ns_table(params)
        save_ns_table(ns_table, params['ns-table-file'])

    ctx_gen = SGNSContextGenerator(ns_table, params)
    model = SGNSModel(ctx_gen.voc, params)
    trainer = SGNSTrainer(ctx_gen, model, ns_table, params)

    trainer.train()
        
main()
Running on CUDA device.
Epoch 1.
Pairs: 5242880, words: 1488473, loss: 2.5872, time: 5.10
apple:
 bogies (0.543) nationale (0.494) unley (0.493) bien (0.480) asian (0.477)
terrible:
 henchmen (0.488) circassian (0.482) fidelis (0.465) uprightness (0.465) scaliger (0.450)
sweden:
 unsc (0.520) aurore (0.481) ghosting (0.472) szlachta (0.465) 41.2 (0.465)
1979:
 97.35 (0.517) 830 (0.493) mesothelioma (0.484) unpredictability (0.479) sustainer (0.468)
write:
 integrating (0.520) mwss (0.506) phidias (0.502) mns (0.499) sedaka (0.487)
gothenburg:
 disturbingly (0.547) fourths (0.511) sat (0.507) fester (0.501) linn (0.497)
------------------------------------
Pairs: 10485760, words: 2976823, loss: 1.3627, time: 5.09
apple:
 rarely (0.522) bogies (0.515) swissair (0.504) asian (0.503) robustus (0.466)
terrible:
 classrooms (0.489) henchmen (0.488) irritates (0.486) circassian (0.478) uprightness (0.454)
sweden:
 antichrist (0.518) aurore (0.500) 41.2 (0.498) science (0.495) barnett (0.490)
1979:
 cleared (0.590) 830 (0.558) akel (0.530) shield (0.498) peripheral (0.493)
write:
 integrating (0.613) serving (0.513) mns (0.501) drawing (0.492) designation (0.492)
gothenburg:
 disturbingly (0.544) fester (0.512) fourths (0.492) linn (0.488) helicarrier (0.481)
------------------------------------
Pairs: 15728640, words: 4465923, loss: 0.7809, time: 5.22
apple:
 rarely (0.658) least (0.646) family (0.607) women (0.605) role (0.589)
terrible:
 nevada (0.581) classrooms (0.565) merging (0.532) star (0.532) ports (0.529)
sweden:
 spoken (0.616) 9.8 (0.601) country (0.594) historic (0.591) p. (0.587)
1979:
 cleared (0.687) key (0.655) golden (0.624) controversy (0.621) clearing (0.617)
write:
 integrating (0.666) drawing (0.650) designation (0.643) promotion (0.604) remaining (0.595)
gothenburg:
 disturbingly (0.525) fester (0.507) helicarrier (0.491) fourths (0.470) resonances (0.466)
------------------------------------
Pairs: 20971520, words: 5954108, loss: 0.6051, time: 5.11
apple:
 family (0.719) least (0.707) rarely (0.703) doctor (0.688) night (0.679)
terrible:
 nevada (0.649) classrooms (0.620) merging (0.605) ports (0.589) wield (0.575)
sweden:
 brisbane (0.679) role (0.675) christians (0.673) theoretical (0.669) performances (0.665)
1979:
 mine (0.710) clearing (0.703) partners (0.696) frontier (0.690) mi (0.677)
write:
 drawing (0.723) integrating (0.703) designation (0.676) promotion (0.676) barry (0.660)
gothenburg:
 fester (0.492) deschanel (0.490) resonances (0.487) helicarrier (0.482) soundtracks (0.481)
------------------------------------
Pairs: 26214400, words: 7446624, loss: 0.5366, time: 5.11
apple:
 interpreter (0.707) yuan (0.687) 5.2 (0.679) rarely (0.678) parsons (0.676)
terrible:
 nevada (0.637) merging (0.632) wield (0.619) classrooms (0.616) discussing (0.608)
sweden:
 brisbane (0.683) institute (0.674) performances (0.646) jazz (0.641) harbor (0.641)
1979:
 mine (0.675) frontier (0.674) clearing (0.673) partners (0.671) crawford (0.666)
write:
 drawing (0.703) integrating (0.692) internet (0.661) tuesday (0.660) barry (0.642)
gothenburg:
 whalers (0.530) 155 (0.523) soundtracks (0.519) garmisch (0.510) resonances (0.509)
------------------------------------
Pairs: 31457280, words: 8936660, loss: 0.5010, time: 4.83
apple:
 interpreter (0.703) parsons (0.637) carlisle (0.628) albany (0.626) newcomer (0.626)
terrible:
 merging (0.618) wield (0.617) discussing (0.592) classrooms (0.574) silesia (0.564)
sweden:
 institute (0.630) brisbane (0.625) louisiana (0.599) harbor (0.590) painters (0.577)
1979:
 crawford (0.613) manual (0.585) tong (0.584) hanger (0.577) frontier (0.577)
write:
 integrating (0.629) internet (0.624) tuesday (0.620) drawing (0.611) ruled (0.597)
gothenburg:
 whalers (0.568) 155 (0.566) soundtracks (0.547) cyprus (0.540) nickelodeon (0.532)
------------------------------------
Pairs: 36700160, words: 10426293, loss: 0.4778, time: 4.69
apple:
 interpreter (0.643) monarch (0.599) newcomer (0.576) carlisle (0.573) taste (0.566)
terrible:
 merging (0.602) wield (0.599) plo (0.580) discussing (0.572) bundles (0.563)
sweden:
 brisbane (0.616) louisiana (0.611) northeast (0.602) institute (0.598) billionaire (0.593)
1979:
 tong (0.590) crawford (0.586) hiller (0.570) hanger (0.564) network (0.561)
write:
 internet (0.604) tuesday (0.593) gem (0.584) wary (0.578) integrating (0.575)
gothenburg:
 155 (0.586) whalers (0.578) soundtracks (0.560) nickelodeon (0.555) townspeople (0.554)
------------------------------------
Pairs: 41943040, words: 11917222, loss: 0.4605, time: 4.76
apple:
 interpreter (0.591) monarch (0.559) crimson (0.538) harbor (0.535) carlisle (0.535)
terrible:
 merging (0.590) revolutionary (0.586) bundles (0.583) weather (0.573) wield (0.573)
sweden:
 lloyd (0.638) louisiana (0.634) billionaire (0.632) brisbane (0.632) los (0.628)
1979:
 2007 (0.624) resigned (0.624) january (0.617) 1961 (0.615) diego (0.613)
write:
 internet (0.633) wary (0.608) kevin (0.593) instructions (0.592) catalytic (0.590)
gothenburg:
 155 (0.581) cavities (0.571) townspeople (0.567) whalers (0.564) nickelodeon (0.559)
------------------------------------
Pairs: 47185920, words: 13407210, loss: 0.4486, time: 4.72
apple:
 interpreter (0.578) profession (0.571) monarch (0.561) saskatchewan (0.559) crimson (0.554)
terrible:
 witness (0.608) bundles (0.590) weather (0.588) holder (0.578) webs (0.575)
sweden:
 nbc (0.678) lloyd (0.652) brisbane (0.650) 1970 (0.649) shadow (0.648)
1979:
 january (0.697) 2007 (0.697) 1961 (0.680) 13 (0.679) 16 (0.678)
write:
 internet (0.663) opportunity (0.639) instructions (0.636) beer (0.630) revenge (0.621)
gothenburg:
 cavities (0.582) vancouver (0.573) 155 (0.559) kaiserslautern (0.558) townspeople (0.554)
------------------------------------
Pairs: 52428800, words: 14897126, loss: 0.4391, time: 4.70
apple:
 profession (0.642) crimson (0.595) monarch (0.591) choice (0.590) excise (0.588)
terrible:
 witness (0.621) strict (0.613) webs (0.587) bundles (0.579) weather (0.579)
sweden:
 1970 (0.670) 1867 (0.667) 1967 (0.661) premiered (0.653) 1918 (0.652)
1979:
 january (0.749) 1845 (0.734) 1914 (0.730) december (0.722) 16 (0.720)
write:
 opportunity (0.676) anything (0.671) internet (0.656) revenge (0.650) stop (0.646)
gothenburg:
 cavities (0.583) vancouver (0.573) kaiserslautern (0.564) hoover (0.554) helicarrier (0.554)
------------------------------------
Pairs: 57671680, words: 16385059, loss: 0.4319, time: 4.73
apple:
 contribution (0.641) profession (0.630) chosen (0.619) morality (0.607) concentrated (0.598)
terrible:
 strict (0.651) witness (0.636) if (0.618) webs (0.606) unless (0.601)
sweden:
 1867 (0.701) denmark (0.691) 1967 (0.687) 1953 (0.684) 1939 (0.680)
1979:
 1845 (0.770) january (0.770) 1923 (0.769) 1921 (0.768) 1914 (0.766)
write:
 anything (0.692) someone (0.691) keep (0.687) opportunity (0.687) revenge (0.674)
gothenburg:
 cavities (0.577) kaiserslautern (0.560) vancouver (0.555) hoover (0.551) scepter (0.550)
------------------------------------
Pairs: 62914560, words: 17874409, loss: 0.4262, time: 4.96
apple:
 contribution (0.655) concentrated (0.590) planned (0.579) recognised (0.575) morality (0.571)
terrible:
 strict (0.671) witness (0.644) if (0.637) webs (0.634) crimes (0.633)
sweden:
 denmark (0.727) 1974 (0.720) 1923 (0.710) spain (0.710) 1939 (0.709)
1979:
 1923 (0.806) 1921 (0.801) 1924 (0.791) 1939 (0.783) 1906 (0.782)
write:
 someone (0.741) anything (0.714) revenge (0.713) keep (0.697) opportunity (0.681)
gothenburg:
 cavities (0.569) kaiserslautern (0.558) vancouver (0.547) helicarrier (0.539) hoover (0.539)
------------------------------------
Pairs: 68157440, words: 19366715, loss: 0.4216, time: 5.15
apple:
 contribution (0.619) log (0.578) honus (0.576) conversion (0.570) armour (0.566)
terrible:
 crimes (0.646) strict (0.643) hope (0.643) webs (0.642) good (0.638)
sweden:
 spain (0.764) poland (0.746) ireland (0.744) nova (0.732) greece (0.729)
1979:
 1923 (0.821) 1969 (0.813) 1977 (0.811) 1921 (0.806) 1906 (0.802)
write:
 someone (0.784) anything (0.726) revenge (0.723) discover (0.712) thinks (0.699)
gothenburg:
 kaiserslautern (0.557) cavities (0.545) vancouver (0.545) helicarrier (0.530) episcopalian (0.527)
------------------------------------
Pairs: 73400320, words: 20854947, loss: 0.4181, time: 5.19
apple:
 dynamic (0.609) conversion (0.595) manufactured (0.589) developers (0.589) interactive (0.586)
terrible:
 violent (0.671) crimes (0.661) webs (0.640) strict (0.639) hope (0.637)
sweden:
 spain (0.785) greece (0.781) poland (0.768) ireland (0.762) italy (0.756)
1979:
 1969 (0.835) 1977 (0.831) 1964 (0.810) 1953 (0.808) 1923 (0.805)
write:
 someone (0.792) discover (0.713) advice (0.711) revenge (0.709) anything (0.708)
gothenburg:
 kaiserslautern (0.559) vancouver (0.552) helicarrier (0.528) moultrie (0.526) episcopalian (0.526)
------------------------------------
Pairs: 78643200, words: 22347184, loss: 0.4146, time: 5.15
apple:
 microsoft (0.642) dynamic (0.638) proof (0.625) developers (0.624) manufactured (0.615)
terrible:
 violent (0.693) spies (0.654) crimes (0.649) strict (0.644) humanity (0.642)
sweden:
 greece (0.802) ireland (0.783) spain (0.782) russia (0.770) poland (0.766)
1979:
 1969 (0.846) 1977 (0.843) 1953 (0.833) 1961 (0.830) 1926 (0.830)
write:
 someone (0.775) says (0.704) knowing (0.694) discover (0.693) thinks (0.691)
gothenburg:
 vancouver (0.563) kaiserslautern (0.562) moultrie (0.534) helicarrier (0.533) auburn (0.528)
------------------------------------
Pairs: 83886080, words: 23836611, loss: 0.4121, time: 5.26
apple:
 microsoft (0.664) layout (0.656) embedded (0.655) chip (0.653) hybrid (0.648)
terrible:
 violent (0.694) spies (0.688) humanity (0.653) `` (0.648) identity (0.646)
sweden:
 ireland (0.795) greece (0.795) russia (0.784) spain (0.778) italy (0.776)
1979:
 1977 (0.875) 1969 (0.866) 1954 (0.858) 1953 (0.855) 1984 (0.853)
write:
 someone (0.736) writes (0.718) suggesting (0.690) learned (0.686) follow (0.685)
gothenburg:
 vancouver (0.569) kaiserslautern (0.565) auburn (0.557) helicarrier (0.543) laguardia (0.537)
------------------------------------
Pairs: 89128960, words: 25325296, loss: 0.4101, time: 5.17
apple:
 hybrid (0.681) chip (0.663) layout (0.663) interactive (0.662) mpeg (0.660)
terrible:
 spies (0.699) violent (0.691) rejection (0.684) stronger (0.673) hatred (0.671)
sweden:
 italy (0.800) greece (0.788) ireland (0.782) poland (0.763) spain (0.762)
1979:
 1977 (0.880) 1959 (0.873) 1954 (0.870) 1984 (0.869) 1969 (0.867)
write:
 writes (0.708) learned (0.695) publish (0.694) message (0.690) suggesting (0.689)
gothenburg:
 auburn (0.574) kaiserslautern (0.566) vancouver (0.563) geauga (0.546) helicarrier (0.546)
------------------------------------
Pairs: 94371840, words: 26815142, loss: 0.4079, time: 5.15
apple:
 chip (0.696) hybrid (0.692) interactive (0.686) hardware (0.684) microsoft (0.678)
terrible:
 spies (0.701) hatred (0.694) violent (0.677) reveals (0.675) did (0.672)
sweden:
 italy (0.775) greece (0.767) ireland (0.751) france (0.729) spain (0.728)
1979:
 1977 (0.883) 1959 (0.873) 1984 (0.871) 1969 (0.861) 1973 (0.859)
write:
 publish (0.706) writes (0.689) learned (0.687) writing (0.686) readers (0.677)
gothenburg:
 auburn (0.572) kaiserslautern (0.559) palmdale (0.549) vancouver (0.549) geauga (0.543)
------------------------------------
Pairs: 99614720, words: 28305703, loss: 0.4073, time: 5.26
apple:
 video (0.715) hybrid (0.707) chip (0.706) hardware (0.701) cameras (0.689)
terrible:
 spies (0.685) hatred (0.684) danger (0.674) sudden (0.673) rise (0.657)
sweden:
 italy (0.792) greece (0.754) spain (0.735) ireland (0.734) germany (0.730)
1979:
 1977 (0.885) 1959 (0.877) 1984 (0.874) 1983 (0.870) 1996 (0.868)
write:
 publish (0.727) writes (0.686) readers (0.684) writing (0.676) learned (0.669)
gothenburg:
 palmdale (0.566) auburn (0.565) kaiserslautern (0.549) helicarrier (0.542) geauga (0.538)
------------------------------------
Pairs: 104857600, words: 29794372, loss: 0.4056, time: 5.23
apple:
 video (0.734) hardware (0.732) hybrid (0.714) chip (0.707) microsoft (0.699)
terrible:
 danger (0.693) sudden (0.685) fate (0.652) spies (0.651) fall (0.648)
sweden:
 italy (0.806) germany (0.773) greece (0.771) denmark (0.760) austria (0.751)
1979:
 1977 (0.899) 1969 (0.886) 1976 (0.884) 1983 (0.883) 1973 (0.883)
write:
 publish (0.717) readers (0.698) discuss (0.694) writes (0.686) writing (0.680)
gothenburg:
 palmdale (0.575) auburn (0.559) helicarrier (0.539) kaiserslautern (0.539) linn (0.538)
------------------------------------
Pairs: 110100480, words: 31287751, loss: 0.4048, time: 5.29
apple:
 hardware (0.768) video (0.754) interactive (0.699) macintosh (0.696) computer (0.695)
terrible:
 sudden (0.725) danger (0.691) dying (0.645) fall (0.644) after (0.634)
sweden:
 denmark (0.817) italy (0.810) germany (0.804) greece (0.778) hungary (0.777)
1979:
 1977 (0.901) 1987 (0.900) 1974 (0.892) 1975 (0.891) 1973 (0.889)
write:
 readers (0.703) publish (0.696) how (0.687) discuss (0.677) writes (0.673)
gothenburg:
 palmdale (0.581) auburn (0.555) geauga (0.547) shafer (0.539) thomasville (0.538)
------------------------------------
Pairs: 115343360, words: 32774622, loss: 0.4032, time: 5.09
apple:
 hardware (0.761) video (0.748) macintosh (0.686) interactive (0.679) computer (0.678)
terrible:
 sudden (0.750) danger (0.672) dying (0.663) victim (0.659) fall (0.653)
sweden:
 denmark (0.815) germany (0.800) italy (0.792) netherlands (0.788) austria (0.784)
1979:
 1987 (0.909) 1973 (0.906) 1977 (0.903) 1975 (0.902) 1974 (0.898)
write:
 learn (0.687) knowing (0.682) how (0.679) writing (0.678) wants (0.675)
gothenburg:
 palmdale (0.584) geauga (0.555) auburn (0.546) shafer (0.545) thomasville (0.538)
------------------------------------
Pairs: 120586240, words: 34263801, loss: 0.4024, time: 5.22
apple:
 video (0.733) hardware (0.731) computer (0.691) vintage (0.690) google (0.678)
terrible:
 sudden (0.740) because (0.699) victim (0.691) dying (0.682) prolonged (0.681)
sweden:
 denmark (0.808) netherlands (0.797) germany (0.796) austria (0.787) italy (0.780)
1979:
 1973 (0.918) 1977 (0.907) 1975 (0.905) 1974 (0.899) 1987 (0.896)
write:
 writing (0.711) learn (0.697) read (0.695) how (0.665) tell (0.658)
gothenburg:
 palmdale (0.583) geauga (0.562) shafer (0.543) thomasville (0.540) auburn (0.536)
------------------------------------
Pairs: 125829120, words: 35753906, loss: 0.4014, time: 5.15
apple:
 video (0.709) hardware (0.692) google (0.688) mpeg (0.672) computer (0.670)
terrible:
 victim (0.717) sudden (0.711) because (0.710) fight (0.689) prolonged (0.682)
sweden:
 denmark (0.799) norway (0.776) germany (0.775) austria (0.771) netherlands (0.767)
1979:
 1977 (0.911) 1973 (0.907) 1974 (0.903) 1975 (0.898) 1985 (0.895)
write:
 read (0.716) writing (0.712) learn (0.681) tell (0.672) lyrics (0.668)
gothenburg:
 palmdale (0.581) geauga (0.557) thomasville (0.552) mennonite (0.534) bantamweight (0.531)
------------------------------------
Pairs: 131072000, words: 37242042, loss: 0.4004, time: 5.13
apple:
 mpeg (0.675) google (0.660) video (0.655) hardware (0.653) macintosh (0.648)
terrible:
 victim (0.726) because (0.686) sudden (0.685) fate (0.683) demon (0.681)
sweden:
 denmark (0.805) germany (0.766) finland (0.764) norway (0.762) austria (0.757)
1979:
 1977 (0.912) 1974 (0.908) 1971 (0.903) 1976 (0.899) 1969 (0.895)
write:
 read (0.732) writing (0.699) listen (0.680) tell (0.670) lyrics (0.663)
gothenburg:
 palmdale (0.575) geauga (0.552) thomasville (0.551) mennonite (0.546) riverhead (0.544)
------------------------------------
Pairs: 136314880, words: 38731863, loss: 0.4000, time: 5.18
apple:
 hardware (0.659) mpeg (0.659) linux (0.650) computers (0.631) amiga (0.630)
terrible:
 victim (0.699) death (0.698) sudden (0.697) fate (0.693) believing (0.680)
sweden:
 denmark (0.820) norway (0.778) finland (0.776) germany (0.769) austria (0.765)
1979:
 1977 (0.920) 1982 (0.904) 1976 (0.904) 1974 (0.902) 1975 (0.893)
write:
 read (0.736) listen (0.729) writing (0.693) how (0.686) teach (0.683)
gothenburg:
 geauga (0.566) thomasville (0.559) riverhead (0.551) palmdale (0.550) mennonite (0.548)
------------------------------------
Pairs: 141557760, words: 40222659, loss: 0.3991, time: 5.11
apple:
 linux (0.648) mpeg (0.640) hardware (0.639) desktop (0.632) store (0.630)
terrible:
 sudden (0.716) death (0.684) danger (0.666) victim (0.664) witnessed (0.661)
sweden:
 denmark (0.788) germany (0.785) norway (0.784) netherlands (0.766) austria (0.758)
1979:
 1977 (0.920) 1982 (0.906) 1974 (0.905) 1993 (0.899) 1969 (0.898)
write:
 listen (0.727) read (0.710) publish (0.689) tell (0.684) how (0.677)
gothenburg:
 thomasville (0.571) geauga (0.567) crowder (0.551) bantamweight (0.546) waverley (0.546)
------------------------------------
Pairs: 146800640, words: 41714422, loss: 0.3992, time: 5.20
apple:
 proprietary (0.655) linux (0.631) mpeg (0.630) desktop (0.624) hardware (0.620)
terrible:
 sudden (0.725) bitter (0.679) death (0.661) dies (0.661) habit (0.660)
sweden:
 netherlands (0.789) denmark (0.784) hungary (0.783) norway (0.772) finland (0.772)
1979:
 1982 (0.910) 1977 (0.909) 1974 (0.909) 1970 (0.909) 1993 (0.903)
write:
 publish (0.703) listen (0.703) follow (0.691) read (0.691) tell (0.690)
gothenburg:
 thomasville (0.576) bantamweight (0.563) geauga (0.555) crowder (0.550) waverley (0.549)
------------------------------------
Pairs: 152043520, words: 43203227, loss: 0.3986, time: 5.26
apple:
 proprietary (0.668) desktop (0.652) nintendo (0.650) graphical (0.647) hardware (0.646)
terrible:
 sudden (0.703) bitter (0.668) aftermath (0.653) violent (0.650) darkness (0.647)
sweden:
 norway (0.804) netherlands (0.804) denmark (0.798) hungary (0.789) finland (0.783)
1979:
 1974 (0.909) 1970 (0.909) 1982 (0.907) 1983 (0.906) 1976 (0.897)
write:
 tell (0.697) publish (0.694) follow (0.692) listen (0.688) read (0.685)
gothenburg:
 thomasville (0.559) geauga (0.557) bantamweight (0.554) 75th (0.549) fiba (0.543)
------------------------------------
Pairs: 157286400, words: 44692069, loss: 0.3975, time: 5.11
apple:
 desktop (0.678) nintendo (0.672) toy (0.663) graphical (0.661) proprietary (0.661)
terrible:
 sudden (0.720) darkness (0.683) panic (0.667) deadly (0.664) aftermath (0.658)
sweden:
 denmark (0.806) netherlands (0.798) hungary (0.798) belgium (0.794) norway (0.790)
1979:
 1983 (0.916) 1974 (0.904) 1970 (0.904) 1982 (0.893) 1971 (0.892)
write:
 writing (0.691) read (0.680) tell (0.678) speak (0.677) publish (0.670)
gothenburg:
 geauga (0.559) bantamweight (0.554) fiba (0.546) thomasville (0.546) palmdale (0.545)
------------------------------------
Pairs: 162529280, words: 46179590, loss: 0.3978, time: 5.09
apple:
 computer (0.689) software (0.683) toy (0.677) handheld (0.675) macintosh (0.675)
terrible:
 panic (0.736) sudden (0.715) witnessed (0.686) darkness (0.679) dies (0.665)
sweden:
 denmark (0.828) belgium (0.820) germany (0.803) netherlands (0.796) austria (0.794)
1979:
 1983 (0.914) 1971 (0.899) 1969 (0.897) 1970 (0.896) 1974 (0.896)
write:
 read (0.678) let (0.664) tell (0.658) writing (0.652) ask (0.638)
gothenburg:
 bantamweight (0.561) fiba (0.554) geauga (0.553) thomasville (0.545) sfo (0.544)
------------------------------------
Pairs: 167772160, words: 47669494, loss: 0.3971, time: 4.87
apple:
 macintosh (0.702) software (0.684) computer (0.681) toy (0.679) handheld (0.678)
terrible:
 panic (0.766) witnessed (0.697) sudden (0.685) fear (0.671) darkness (0.666)
sweden:
 denmark (0.832) belgium (0.822) germany (0.807) austria (0.785) finland (0.779)
1979:
 1983 (0.918) 1970 (0.912) 1969 (0.911) 1973 (0.903) 1977 (0.901)
write:
 read (0.682) tell (0.663) let (0.653) writing (0.643) call (0.638)
gothenburg:
 bantamweight (0.563) fiba (0.554) sfo (0.554) nordic (0.550) geauga (0.547)
------------------------------------
Pairs: 173015040, words: 49158007, loss: 0.3964, time: 5.08
apple:
 macintosh (0.693) computer (0.689) toy (0.683) portable (0.671) handheld (0.669)
terrible:
 panic (0.759) fear (0.677) realizing (0.663) sudden (0.662) despair (0.655)
sweden:
 denmark (0.823) germany (0.818) belgium (0.815) finland (0.800) norway (0.791)
1979:
 1970 (0.919) 1969 (0.918) 1983 (0.908) 1973 (0.905) 1981 (0.903)
write:
 read (0.714) tell (0.658) listen (0.657) ask (0.656) say (0.645)
gothenburg:
 bantamweight (0.563) sfo (0.558) nordic (0.547) fiba (0.547) chachapoyas (0.545)
------------------------------------
Pairs: 178257920, words: 50647554, loss: 0.3964, time: 5.00
apple:
 macintosh (0.691) computer (0.683) graphics (0.681) software (0.677) package (0.665)
terrible:
 fear (0.691) panic (0.688) guilt (0.665) desperate (0.663) isn (0.662)
sweden:
 germany (0.801) belgium (0.798) norway (0.797) hungary (0.790) finland (0.786)
1979:
 1970 (0.911) 1988 (0.908) 1976 (0.904) 1969 (0.903) 1973 (0.900)
write:
 read (0.739) writing (0.678) listen (0.672) tell (0.671) ask (0.666)
gothenburg:
 sfo (0.543) chachapoyas (0.543) bantamweight (0.543) flashpoint (0.541) nordic (0.533)
------------------------------------

Inspecting the result

In addition to the examples above, we can also inspect the results interactively. For instance, we can show the nearest-neighbor lists (according to the cosine similarity) of some other test words.

In [8]:
model.nearest_neighbors(['garlic'], 5)
Out[8]:
[[('onion', 0.6637248992919922),
  ('flour', 0.6544929146766663),
  ('baked', 0.6539584994316101),
  ('fried', 0.6472417712211609),
  ('fruit', 0.6423918008804321)]]

We can also show the cosine similarity for a given pair of words. Apparently, a dog is distributionally more similar to a cat than to a gorilla.

In [9]:
model.cosine_similarity('dog', 'cat')
Out[9]:
0.7240018248558044
In [10]:
model.cosine_similarity('dog', 'gorilla')
Out[10]:
0.5336264371871948
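
As mentioned at the top of the notebook, a realistic implementation would also save the embeddings after training. A minimal sketch of how this could be done (assuming the trained model object from above; the word2vec-style text format and the file name are just one possible choice) might look as follows:

# Sketch: save the target-word embeddings in a word2vec-style text format.
def save_embeddings(model, filename):
    emb = model.w.weight.detach().cpu().numpy()
    voc_size, emb_dim = emb.shape
    with open(filename, 'w') as f:
        # First line: vocabulary size and embedding dimensionality.
        print(voc_size, emb_dim, file=f)
        for i in range(voc_size):
            vector = ' '.join(f'{x:.6f}' for x in emb[i])
            print(model.ivoc[i], vector, file=f)

# save_embeddings(model, 'sgns_embeddings.txt')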