Text classification using a character-based convolutional neural network

As our third example, we will replicate the system described by Zhang et al. (2015), which uses a CNN based on characters instead of words.

I wasn't able to get accuracies as good as those we saw for the word-based CNN and the CBoW classifier. Even so, this example is still useful to study as a comparison.

In [1]:
import torch
from torch import nn
import time
import torchtext

from collections import defaultdict

import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina' 
plt.style.use('seaborn')

Defining the character-based CNN

We follow the system description given by Zhang et al. (2015). The model structure is identical to theirs.

As usual, look into the code for more detailed comments.

In [2]:
class CharCNNTextClassifier(nn.Module):
    def __init__(self, text_field, class_field):
        super().__init__()
        self.voc_size = len(text_field.vocab)
        n_classes = len(class_field.vocab)
        
        n_channels = 256
        dropout_prob = 0.5
        fc_size = 1024
        
        # The model by Zhang et al. is applied to one-hot encoded characters. Character-based
        # models that have been proposed more recently have used an embedding layer to represent
        # the characters, but we'll just stick with the one-hot encoding here.
        
        # We first define the stack of convolutional and pooling layers, following the
        # description by Zhang et al. exactly.
        # We use an nn.Sequential, which is a container for modules (objects of type nn.Module).
        # They are applied in order: the output of one step is fed as the input to the
        # next step in the sequence.
        self.conv_stack = nn.Sequential(
            nn.Conv1d(in_channels=self.voc_size, out_channels=n_channels, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
            nn.Conv1d(in_channels=n_channels, out_channels=n_channels, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
            nn.Conv1d(in_channels=n_channels, out_channels=n_channels, kernel_size=3),
            nn.ReLU(),
            nn.Conv1d(in_channels=n_channels, out_channels=n_channels, kernel_size=3),
            nn.ReLU(),
            nn.Conv1d(in_channels=n_channels, out_channels=n_channels, kernel_size=3),
            nn.ReLU(),
            nn.Conv1d(in_channels=n_channels, out_channels=n_channels, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
        )
        
        # We define another Sequential stack for the fully connected (fc) part of the network.        
        # The size of the input will be 256*34 = 8704 because of the structure of the conv_stack
        # network, and because the input to conv_stack has a fixed size of 1014 characters.        
        n_in = 256*34
        self.fc = nn.Sequential(
            nn.Linear(in_features=n_in, out_features=fc_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(in_features=fc_size, out_features=fc_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(in_features=fc_size, out_features=n_classes)
        )
        
        
    def onehot_encode(self, texts):
        # A helper function to do the one-hot encoding of characters.
        # The input has the shape (sequence length, batch size) and contains the characters
        # encoded as integers.
        sen_len, batch_size = texts.shape
        out = torch.zeros(size=(sen_len, batch_size, self.voc_size), device=texts.device)
        # scatter_ writes a 1 at each character's index along the vocabulary dimension.
        out.scatter_(2, texts.view(sen_len, batch_size, 1), 1)
        # Permute to (batch size, vocabulary size, sequence length), the channels-first
        # layout that nn.Conv1d expects.
        return out.permute(1, 2, 0)
        
        
    def forward(self, texts):
        # One-hot-encode the sequences of characters.
        onehot = self.onehot_encode(texts)

        # Apply the convolution stack. Because the size of the input is fixed to 1014 characters,
        # and because we set the number of output channels to 256, 
        # the shape will now be (batch size, 256, 34)
        conv = self.conv_stack(onehot)

        # We view the result as a tensor of shape (batch_size, 256*34)
        # so that it fits the input shape of the fully connected layer.
        conv = conv.view(conv.shape[0], -1)

        # Finally apply the fully connected layer, and then return the output.
        scores = self.fc(conv)
        return scores
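
If you want to verify where the 256*34 = 8704 figure comes from, here is a small shape check. This is just a sketch, not a cell from the original notebook: it rebuilds the same convolution stack outside the class, using a hypothetical vocabulary size of 70 one-hot channels, and pushes a dummy batch of 1014 characters through it.

import torch
from torch import nn

voc_size = 70        # hypothetical character vocabulary size, only for this check
n_channels = 256

conv_stack = nn.Sequential(
    nn.Conv1d(voc_size, n_channels, kernel_size=7), nn.ReLU(), nn.MaxPool1d(3),
    nn.Conv1d(n_channels, n_channels, kernel_size=7), nn.ReLU(), nn.MaxPool1d(3),
    nn.Conv1d(n_channels, n_channels, kernel_size=3), nn.ReLU(),
    nn.Conv1d(n_channels, n_channels, kernel_size=3), nn.ReLU(),
    nn.Conv1d(n_channels, n_channels, kernel_size=3), nn.ReLU(),
    nn.Conv1d(n_channels, n_channels, kernel_size=3), nn.ReLU(), nn.MaxPool1d(3),
)

# A dummy one-hot batch: 4 documents, 70 channels, 1014 character positions.
dummy = torch.zeros(4, voc_size, 1014)
print(conv_stack(dummy).shape)    # torch.Size([4, 256, 34]), i.e. 256*34 = 8704 after flattening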

Training the classifier

Again, we use the same training setup as in the previous two notebooks (for CBoW and word-based CNN).

The code is almost identical, with some notable differences:

  • the torchtext Field representing the texts will produce a character sequence instead of a word sequence (illustrated in the short sketch after this list),
  • the learning rate is decreased gradually, as described in Zhang's paper (but we use Adam instead of SGD).
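
To illustrate the first point, here is a small sketch (not a cell from the original notebook) of what the character-level Field does in the legacy torchtext API used here: tokenize=list splits a string into individual characters, and fix_length=1014 later pads or truncates every example to exactly 1014 positions when the batches are built.

import torchtext

TEXT = torchtext.data.Field(sequential=True, tokenize=list, fix_length=1014)
print(TEXT.preprocess('Good film!'))
# ['G', 'o', 'o', 'd', ' ', 'f', 'i', 'l', 'm', '!']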

I had some difficulty getting this model to learn anything, and despite quite a bit of tweaking, I never got the accuracy above about 0.81, compared to the 0.85-0.86 we typically see for the CBoW and word-based CNN classifiers.

In [3]:
def read_data(corpus_file, datafields, label_column, doc_start):
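    # Read a corpus where each line contains a label in the column given by label_column,
    # and where the document text consists of everything from column doc_start onwards
    # (columns are whitespace-separated).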
    with open(corpus_file, encoding='utf-8') as f:
        examples = []
        for line in f:
            columns = line.strip().split(maxsplit=doc_start)
            doc = columns[-1]
            label = columns[label_column]
            examples.append(torchtext.data.Example.fromlist([doc, label], datafields))
    return torchtext.data.Dataset(examples, datafields)

def evaluate_validation(scores, loss_function, gold):
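    # Compute the number of correctly classified instances and the loss for one validation batch.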
    guesses = scores.argmax(dim=1)
    n_correct = (guesses == gold).sum().item()
    return n_correct, loss_function(scores, gold).item()

def main():
    # NOTE that the tokenization is done differently here compared to the previous examples.
    # The output of this tokenization will be a sequence of 1014 characters.
    TEXT = torchtext.data.Field(sequential=True, tokenize=list, fix_length=1014)
    LABEL = torchtext.data.LabelField(is_target=True)
    datafields = [('text', TEXT), ('label', LABEL)]
    
    corpus = 'amazon'
    if corpus == 'amazon':
        data = read_data('data/all_sentiment_shuffled.txt', datafields, label_column=1, doc_start=3)
        train, valid = data.split([0.8, 0.2])
    elif corpus == 'ag':
        train = read_data('data/ag_news.train', datafields, label_column=0, doc_start=2) 
        valid = read_data('data/ag_news.test', datafields, label_column=0, doc_start=2) 
        
    TEXT.build_vocab(train, max_size=10000)
    LABEL.build_vocab(train)
    
    model = CharCNNTextClassifier(TEXT, LABEL)

    device = 'cuda'
    model.to(device)
    
    train_iterator = torchtext.data.BucketIterator(
        train,
        device=device,
        batch_size=128,
        sort_key=lambda x: len(x.text),
        repeat=False,
        train=True)
    
    valid_iterator = torchtext.data.Iterator(
        valid,
        device=device,
        batch_size=128,
        repeat=False,
        train=False,
        sort=False)

    loss_function = torch.nn.CrossEntropyLoss()
    learning_rate = 0.0005
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    print(f'Setting the learning rate to {learning_rate}.')
    
    train_batches = list(train_iterator)
    valid_batches = list(valid_iterator)
    
    history = defaultdict(list)
    
    for i in range(1, 31):
        
        t0 = time.time()
        
        loss_sum = 0
        n_batches = 0

        model.train()
        
        for batch in train_batches:
            scores = model(batch.text)
            loss = loss_function(scores, batch.label)

            optimizer.zero_grad()            
            loss.backward()
            optimizer.step()
    
            loss_sum += loss.item()
            n_batches += 1
        
        train_loss = loss_sum / n_batches
        history['train_loss'].append(train_loss)
        
        n_correct = 0
        n_valid = len(valid)
        loss_sum = 0
        n_batches = 0

        model.eval()
        
        for batch in valid_batches:
            scores = model(batch.text)
            n_corr_batch, loss_batch = evaluate_validation(scores, loss_function, batch.label)
            loss_sum += loss_batch
            n_correct += n_corr_batch
            n_batches += 1
        val_acc = n_correct / n_valid
        val_loss = loss_sum / n_batches

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)        
        
        t1 = time.time()
        print(f'Epoch {i}: train loss = {train_loss:.4f}, val loss = {val_loss:.4f}, val acc: {val_acc:.4f}, time = {t1-t0:.4f}')

        # Decrease the learning rate gradually, as mentioned above: halve it every five epochs.
        if i % 5 == 0:
            learning_rate *= 0.5
            print(f'Setting the learning rate to {learning_rate}.')
            for g in optimizer.param_groups:
                g['lr'] = learning_rate            
            
    plt.plot(history['train_loss'])
    plt.plot(history['val_loss'])
    plt.plot(history['val_acc'])
    plt.legend(['training loss', 'validation loss', 'validation accuracy'])
    
    
main()
Setting the learning rate to 0.0005.
Epoch 1: train loss = 0.6936, val loss = 0.6935, val acc: 0.4889, time = 6.3469
Epoch 2: train loss = 0.6917, val loss = 0.6907, val acc: 0.5376, time = 6.1841
Epoch 3: train loss = 0.6651, val loss = 0.6545, val acc: 0.6530, time = 6.2011
Epoch 4: train loss = 0.6058, val loss = 0.5997, val acc: 0.6693, time = 6.2503
Epoch 5: train loss = 0.5335, val loss = 0.5258, val acc: 0.7398, time = 6.2551
Setting the learning rate to 0.00025.
Epoch 6: train loss = 0.4339, val loss = 0.4749, val acc: 0.7768, time = 6.2590
Epoch 7: train loss = 0.3978, val loss = 0.4584, val acc: 0.7881, time = 6.2691
Epoch 8: train loss = 0.3475, val loss = 0.5166, val acc: 0.7482, time = 6.8954
Epoch 9: train loss = 0.2861, val loss = 0.4536, val acc: 0.8023, time = 6.9594
Epoch 10: train loss = 0.2536, val loss = 0.4704, val acc: 0.8070, time = 6.9581
Setting the learning rate to 0.000125.
Epoch 11: train loss = 0.2125, val loss = 0.4650, val acc: 0.8154, time = 6.9565
Epoch 12: train loss = 0.1825, val loss = 0.5313, val acc: 0.8137, time = 6.9556
Epoch 13: train loss = 0.1464, val loss = 0.5848, val acc: 0.8212, time = 6.9578
Epoch 14: train loss = 0.1274, val loss = 0.6998, val acc: 0.7944, time = 6.9592
Epoch 15: train loss = 0.1287, val loss = 0.6520, val acc: 0.8091, time = 6.9600
Setting the learning rate to 6.25e-05.
Epoch 16: train loss = 0.0869, val loss = 0.7602, val acc: 0.8112, time = 6.9599
Epoch 17: train loss = 0.0753, val loss = 0.8614, val acc: 0.8032, time = 6.9580
Epoch 18: train loss = 0.0663, val loss = 0.9758, val acc: 0.7965, time = 6.9583
Epoch 19: train loss = 0.0614, val loss = 0.9389, val acc: 0.8040, time = 6.9575
Epoch 20: train loss = 0.0624, val loss = 0.9861, val acc: 0.8057, time = 6.9599
Setting the learning rate to 3.125e-05.
Epoch 21: train loss = 0.0417, val loss = 0.8273, val acc: 0.8196, time = 6.9630
Epoch 22: train loss = 0.0302, val loss = 0.8684, val acc: 0.8179, time = 6.9668
Epoch 23: train loss = 0.0270, val loss = 0.8964, val acc: 0.8175, time = 6.9666
Epoch 24: train loss = 0.0220, val loss = 0.9450, val acc: 0.8145, time = 6.9657
Epoch 25: train loss = 0.0188, val loss = 0.9949, val acc: 0.8154, time = 6.9643
Setting the learning rate to 1.5625e-05.
Epoch 26: train loss = 0.0167, val loss = 0.9953, val acc: 0.8162, time = 6.9661
Epoch 27: train loss = 0.0154, val loss = 1.0160, val acc: 0.8133, time = 6.9674
Epoch 28: train loss = 0.0146, val loss = 1.0286, val acc: 0.8137, time = 6.9651
Epoch 29: train loss = 0.0135, val loss = 1.0379, val acc: 0.8141, time = 6.9649
Epoch 30: train loss = 0.0136, val loss = 1.0496, val acc: 0.8149, time = 6.9657
Setting the learning rate to 7.8125e-06.