Machine Learning in Swift: Document Classification with Core ML

It's an exciting time for Swift developers to get into machine learning. With Apple's new Core ML framework, which is supported on all four of Apple's platforms, and a growing library of models available for it, incorporating machine learning into Swift apps has never been easier.

In this tutorial, we'll use a Core ML model to classify news articles into one of five categories: Business, Entertainment, Politics, Sports, and Technology. The finished project is available as an open-source framework here.

Classification Problems

The goal in solving a text classification problem is to train a computer to take a string of text it has never seen before and accurately predict which category it belongs in. The document classification problem we are solving here is one example. Others include sentiment analysis (whether the text is positive or negative) and spam detection. So how do you train a computer to tell you whether a sentence is positive or negative, spam or trusted, business or tech?

First, consider how you would classify the following paragraph (as Business, Entertainment, Politics, Sports, or Technology):

The iPhone maker announced that it would be shortening its name from "Apple Computer, Inc." to “Apple, Inc.” in recognition that the company has moved beyond the Mac. The tech giant is listed on the NASDAQ under the symbol AAPL. Investors will be looking to see continued gains in the company’s newer product categories on Tuesday’s earnings call.

Most of us would put this in the "Business" category, even though it's about a tech company and contains the words "Apple", "iPhone", "computer", and "tech". When we read it, we understand that "NASDAQ" is a stock exchange, "symbol" refers to a ticker symbol, and an "earnings call" is a quarterly event in which a company discloses its financial results. We know that investing is the dominant theme here.

The model we'll be using in this tutorial correctly predicts that this paragraph is about "Business". It gives a 22.2% probability for "Business" and 20.8% for "Technology" (remember there are five categories, so an even split would give each one 20%). But the computer doesn't know that "NASDAQ" is a stock exchange or what an "earnings call" is, and out of context the word "symbol" could have many different meanings.

Training The Model

It turns out that for this task, the computer doesn't need to know that "NASDAQ" is a stock exchange or to understand context. It just needs to know that "NASDAQ" and "earnings" are strongly associated with "Business", and that this paragraph contains either more words associated with "Business" than with "Technology", or words whose association with "Business" is stronger.
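To make that idea concrete, here is a toy sketch in plain Swift. The per-word weights below are made up for illustration, and summing them is just one simple way to combine word associations; it is not a description of how the Core ML model used in this tutorial computes its scores.

```swift
// Made-up association weights: word -> (category -> weight).
// These numbers are illustrative only, not taken from any trained model.
let associations: [String: [String: Double]] = [
    "nasdaq": ["Business": 0.9, "Technology": 0.1],
    "earnings": ["Business": 0.8, "Technology": 0.2],
    "iphone": ["Business": 0.2, "Technology": 0.8],
    "tech": ["Business": 0.3, "Technology": 0.7],
]

// Adds up each category's weight over every known word in the input.
// Words with no entry in `associations` are simply ignored.
func associationScores(for words: [String]) -> [String: Double] {
    var totals = [String: Double]()
    for word in words {
        guard let weights = associations[word.lowercased()] else { continue }
        for (category, weight) in weights {
            totals[category, default: 0] += weight
        }
    }
    return totals
}

let scores = associationScores(for: ["iPhone", "tech", "NASDAQ", "earnings", "investors"])
```

Here two words lean toward "Technology" and two toward "Business", but the "Business" words carry larger weights, so "Business" ends up with the higher total (about 2.2 versus 1.8).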

Since the computer doesn't yet know anything about our five categories, we need to train a model to associate words with categories using data that humans have pre-classified. The model we are using here was trained with 1,500 news articles from the BBC, with 300 assigned to each category. To compute associations between words and categories and train the model, we can take each article and count how many times each word in the article appears.
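A simplified sketch of that counting step follows. It assumes a made-up three-article training set and naive splitting on spaces; the BBC data and the actual training pipeline behind the tutorial's model are not shown here. It tallies how often each word appears in each category's articles, which is the raw material for word-category associations.

```swift
// A tiny, made-up training set of (text, label) pairs standing in for
// pre-classified articles.
let trainingSet: [(text: String, label: String)] = [
    ("Shares rose on the NASDAQ after strong earnings", "Business"),
    ("Quarterly earnings beat analyst forecasts", "Business"),
    ("The new iPhone chip is faster", "Technology"),
]

// For each category label, count how many times each lowercased word
// appears across all of that category's examples.
func wordCountsByCategory(_ examples: [(text: String, label: String)]) -> [String: [String: Int]] {
    var counts = [String: [String: Int]]()
    for example in examples {
        for word in example.text.lowercased().split(separator: " ") {
            counts[example.label, default: [:]][String(word), default: 0] += 1
        }
    }
    return counts
}

let counts = wordCountsByCategory(trainingSet)
```

With more data, words like "earnings" accumulate high counts under "Business" and low or zero counts elsewhere, which is what lets a model associate them with that category.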

Once the model has been trained on pre-classified data, we can generate word counts for a string of text the computer has never seen before and compare them against what the model has learned. The model then gives us a probability for each category it knows about, and the category with the highest probability is its prediction.
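Picking the prediction from a set of per-category probabilities amounts to taking the entry with the largest value. A minimal sketch, with made-up probabilities standing in for a model's output:

```swift
// Made-up probabilities for illustration; a real model supplies these.
let probabilities: [String: Double] = [
    "Business": 0.40,
    "Technology": 0.30,
    "Politics": 0.15,
    "Entertainment": 0.10,
    "Sports": 0.05,
]

// Returns the category with the highest probability, or nil if the
// dictionary is empty.
func predictedCategory(from probabilities: [String: Double]) -> String? {
    return probabilities.max { $0.value < $1.value }?.key
}
```

Because dictionary order is unspecified, exact ties would be broken arbitrarily by this helper.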

Tokenization

Our model has been trained to accept a [String: Double] dictionary, with the key being a word and the value the number of times it appears in the input text. To generate this dictionary, we need to extract all of the words from the input and count the occurrences of each. The process of breaking down the text into words is called tokenization.

We can use Apple's NSLinguisticTagger class to do this. First, create an Xcode project (a single view iOS app will work fine) and add a new Swift class called "DocumentClassifier". We'll start by configuring the tagger.

import Foundation

final class DocumentClassifier {

    // Skip whitespace, punctuation, and "other" tokens such as symbols.
    let options: NSLinguisticTagger.Options = [.omitWhitespace, .omitPunctuation, .omitOther]

    // Created lazily so the initializer can read `options` from self.
    lazy var tagger: NSLinguisticTagger = {
        let tagSchemes = NSLinguisticTagger.availableTagSchemes(forLanguage: "en")
        return NSLinguisticTagger(tagSchemes: tagSchemes, options: Int(self.options.rawValue))
    }()

}

NSLinguisticTagger is a feature-rich class with lots of natural language processing functionality, and not all of its capabilities are available in all languages or on all platforms. So we begin by specifying that we want the available tag schemes for English. Then we pass in options telling the tagger we want to omit whitespace, punctuation, and tags in the "other" category (e.g. symbols). These tokens are not helpful for our document classification task.
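NSLinguisticTagger only exists on Apple platforms. If you want to experiment with the tokenization step elsewhere (for example, in a quick script or a Linux unit test), a rough stand-in is to split on the same kinds of characters the options above omit. This is an approximation, not a reimplementation of the tagger: it splits "company's" into "company" and "s", and it applies no language-specific rules. The function name is hypothetical.

```swift
// A rough approximation of tokenizing with
// [.omitWhitespace, .omitPunctuation, .omitOther]: whitespace, punctuation,
// and symbol characters act as separators and never appear in the output.
func roughTokens(in text: String) -> [String] {
    return text
        .split(whereSeparator: { $0.isWhitespace || $0.isPunctuation || $0.isSymbol })
        .map(String.init)
}

let tokens = roughTokens(in: "Apple, Inc. ($AAPL) rose 2%!")
```

Here the commas, periods, parentheses, "$", "%", and "!" all disappear, leaving just the word and number tokens.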

Next let's use the tagger to enumerate the tokens in an input string and generate the word counts dictionary we'll need to pass to the model.

func extractFeatures(from text: String) -> [String: Double] {
    var wordCounts = [String: Double]()
    tagger.string = text
    // NSRange counts UTF-16 code units, not Characters, so use the NSString length.
    let nsText = text as NSString
    let range = NSRange(location: 0, length: nsText.length)
    tagger.enumerateTags(in: range, scheme: .tokenType, options: options) { _, tokenRange, _, _ in
        let token = nsText.substring(with: tokenRange).lowercased()
        // Ignore very short words such as "a", "an", and "of".
        guard token.count >= 3 else { return }
        guard let value = wordCounts[token] else {
            wordCounts[token] = 1.0
            return
        }
        wordCounts[token] = value + 1.0
    }
    return wordCounts
}

The key piece is the call to enumerateTags on NSLinguisticTagger. We pass in the same options as before, so whitespace, punctuation, and other tokens are skipped, along with a block that receives the range of each remaining token. Inside the block, we use that range to extract the token, lowercase it, and ignore it if it's shorter than three characters. We then check whether the word is already in the word counts dictionary: if it is, we increment its count; otherwise, we set its count to 1.0.
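The check-then-increment logic can also be written with Dictionary's default-value subscript. The sketch below applies the same lowercasing and three-character minimum to an already-tokenized array, so it runs without NSLinguisticTagger; the function name is hypothetical.

```swift
// Builds the [String: Double] word-count dictionary the model expects from a
// list of tokens, lowercasing each token and skipping tokens shorter than
// `minimumLength` characters.
func wordCounts(from tokens: [String], minimumLength: Int = 3) -> [String: Double] {
    return tokens.reduce(into: [String: Double]()) { counts, token in
        let word = token.lowercased()
        guard word.count >= minimumLength else { return }
        counts[word, default: 0] += 1
    }
}

let counts = wordCounts(from: ["The", "NASDAQ", "is", "up", "the"])
```

"The" and "the" merge into one key after lowercasing, while "is" and "up" are dropped by the length filter.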

Using The Model

We're now finally ready to use the Core ML model. You can download it here; then add it to your project. If you select the model file in Xcode, you'll see its inputs and outputs in the "Model Evaluation Parameters" section. As expected, the model takes a [String: Double] input. It produces two outputs: the predicted category as a String, and a dictionary with the probability of each category.

Let's continue by modeling the output. Add a new file with the following Classification struct:

struct Classification {

    enum Category: String {
        case business = "Business"
        case entertainment = "Entertainment"
        case politics = "Politics"
        case sports = "Sports"
        case technology = "Technology"
    }

    struct Result {
        public let category: Category
        public let probability: Double
    }

    let prediction: Result
    let allResults: [Result]

}

Here we define two nested types: a Category enum listing the possible document categories, and a Result struct that pairs a category with a probability. With these, we can expose both the model's prediction and the probabilities for all categories as Result values.

Next, we need to convert the model's output to a Classification value. Xcode autogenerates Swift classes for the model that we can use to do this. If you build your app and select the model file in Xcode, you'll see a section called "Model Class" containing the model class name, with an arrow next to it that links to the generated code. You should see three classes: DocumentClassification (the model class), DocumentClassificationInput, and DocumentClassificationOutput.

Let's add an initializer for Classification that takes a DocumentClassificationOutput object.

extension Classification {

    init?(output: DocumentClassificationOutput) {
        guard let category = Category(rawValue: output.classLabel),
            let probability = output.classProbability[output.classLabel]
            else { return nil }
        let prediction = Result(category: category, probability: probability)
        // Labels that don't match a Category produce nil and are dropped.
        let allResults = output.classProbability.compactMap(Classification.result)
        self.init(prediction: prediction, allResults: allResults)
    }

    static func result(from classProbability: (key: String, value: Double)) -> Result? {
        guard let category = Category(rawValue: classProbability.key) else { return nil }
        return Result(category: category, probability: classProbability.value)
    }
}

We first build the prediction: the classLabel property gives us the predicted category, and we look up its probability in the classProbability dictionary. Then we compactMap the classProbability dictionary to an array of Result values to get allResults; any label that doesn't match one of our categories produces nil and is dropped. Note that allResults follows the dictionary's iteration order, which is unspecified.
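Because DocumentClassificationOutput is generated by Xcode and can only be produced by running the model, one way to test this conversion on its own is to have the initializer depend on a small protocol instead. The sketch below assumes a hypothetical ClassifierOutput protocol with the two properties the initializer reads, plus a hypothetical StubOutput struct; it repeats the Classification types (without `public`) so it compiles standalone.

```swift
// Hypothetical protocol exposing just what Classification's initializer reads.
protocol ClassifierOutput {
    var classLabel: String { get }
    var classProbability: [String: Double] { get }
}

// Hypothetical stand-in for the generated output class, for tests.
struct StubOutput: ClassifierOutput {
    let classLabel: String
    let classProbability: [String: Double]
}

struct Classification {
    enum Category: String {
        case business = "Business"
        case entertainment = "Entertainment"
        case politics = "Politics"
        case sports = "Sports"
        case technology = "Technology"
    }

    struct Result {
        let category: Category
        let probability: Double
    }

    let prediction: Result
    let allResults: [Result]
}

extension Classification {
    init?<Output: ClassifierOutput>(output: Output) {
        // Fail if the predicted label isn't a known category or has no probability.
        guard let category = Category(rawValue: output.classLabel),
              let probability = output.classProbability[output.classLabel]
        else { return nil }
        let prediction = Result(category: category, probability: probability)
        // Unknown labels make result(from:) return nil, so compactMap drops them.
        let allResults = output.classProbability.compactMap(Classification.result(from:))
        self.init(prediction: prediction, allResults: allResults)
    }

    static func result(from classProbability: (key: String, value: Double)) -> Result? {
        guard let category = Category(rawValue: classProbability.key) else { return nil }
        return Result(category: category, probability: classProbability.value)
    }
}

let output = StubOutput(classLabel: "Business",
                        classProbability: ["Business": 0.6, "Technology": 0.3, "Weather": 0.1])
let classification = Classification(output: output)
```

Assuming the generated DocumentClassificationOutput exposes classLabel and classProbability as readable properties (which the initializer above already relies on), an empty `extension DocumentClassificationOutput: ClassifierOutput {}` should be enough to use this version with the real model; check that against your generated code.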

All that is left now is to add a property for our model and write a function that will use it to classify an input string. Add the model property at the top of the DocumentClassifier class:

let model = DocumentClassification()

Finally, let's add our classify function:

func classify(_ text: String) -> Classification? {
    let features = extractFeatures(from: text)
    // Require at least three distinct words, and give up if prediction throws.
    guard
        features.count > 2,
        let output = try? model.prediction(input: features) else { return nil }
    return Classification(output: output)
}

We use the extractFeatures function from earlier to generate the word counts dictionary and check that it contains at least three distinct words before passing it to the model. Because features.count counts dictionary keys, repeating the same word several times doesn't satisfy the check. If we get a valid output, we use the initializer we just wrote to return a Classification value.

Testing

To test the model, first add a DocumentClassifier property to the top of your app's main view controller.

let classifier = DocumentClassifier()

Now add a string constant named text containing the input to classify. It can be whatever you want, or you can use the sample paragraph from the "Classification Problems" section above. Then add the following to viewDidLoad to see the results:

guard let classification = classifier.classify(text) else { return }
print(classification.prediction)
print(classification.allResults)
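Printing Result values directly gives fairly raw output. A small formatting helper (hypothetical, not part of the tutorial's framework) could rank the results and show probabilities as percentages. It takes plain (category, probability) pairs so it runs on its own; the sample values are the two the article reports for the Apple paragraph.

```swift
import Foundation

// Sorts results from most to least likely and formats each probability as a
// percentage with one decimal place.
func formattedResults(_ results: [(category: String, probability: Double)]) -> [String] {
    return results
        .sorted { $0.probability > $1.probability }
        .map { result -> String in
            let percent = String(format: "%.1f", result.probability * 100)
            return "\(result.category): \(percent)%"
        }
}

let lines = formattedResults([(category: "Technology", probability: 0.208),
                              (category: "Business", probability: 0.222)])
```

With real output, you would map allResults to pairs such as `(category: $0.category.rawValue, probability: $0.probability)` before calling the helper.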

Core ML is a powerful new framework that, together with tools like NSLinguisticTagger, makes difficult machine learning and natural language processing tasks in Swift significantly easier. In this tutorial, we saw how to use the framework and a Core ML model to write a document classifier with relatively little code. Adding machine learning to your apps is a great way to enrich the user experience and make your app stand out. Using Core ML to process user data entirely on the device also presents opportunities to improve security and protect user privacy.

Next Steps