A Random Walk With Drift.

Paper: Accelerating Inference for Sparse XMR Trees

Philip A. Etter

Published in the Proceedings of the Web Conference 2022

Introduction

In this paper we study how to accelerate inference for sparse extreme multi-label ranking (XMR) tree models. Tree-based models are the workhorse of many modern search engines and recommender systems, and they have several advantages over their neural network counterparts. First, they are extremely fast to train: many models can be trained in minutes, as opposed to days for neural networks. They are also very fast at inference time, since the bulk of inference consists of sparse matrix multiplications. Furthermore, they take up substantially less memory. In enterprise-scale applications, these benefits translate into tremendous savings in compute and storage. However, there is still room for improvement. In particular, while the scientific community has devoted much attention to optimizing dense linear algebra kernels for neural networks, it has devoted significantly less attention to optimizing sparse linear algebra kernels for tree-based models. To fill this gap in the literature, we propose a class of sparse linear algebra kernels that we call masked sparse chunk multiplication (MSCM) techniques. These kernels are tailored specifically to XMR tree models and improve inference performance by a factor of 8 to 20 on our benchmark models. The technique drives inference in Amazon's Predictions for Enormous and Correlated Output Spaces (PECOS) framework, where it has significantly reduced the cost of both online and batch inference.

Extreme Multi-Label Ranking Trees

An XMR problem can be characterized as follows: given a query $\mathbf{x}$ from some embedding $\mathbb{R}^d$ and a set of labels $\mathcal{Y}$, produce a model that gives an (implicit) ranking of the relevance of the labels in $\mathcal{Y}$ to query $\mathbf{x}$. In addition, for any query $\mathbf{x}$, one must be able to efficiently retrieve the top $k$ most relevant labels in $\mathcal{Y}$ according to the model, noting that $d$ is typically very large and $\mathbf{x}$ very sparse.

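A linear XMR tree model is a hierarchical linear model that constructs a hierarchical clustering of the labels $\mathcal{Y}$, forming a tree structure. These clusters are denoted $\mathcal{C}^{(l)}_i$, where $l$ denotes the depth (i.e., layer) of $\mathcal{C}^{(l)}_i$ in the model tree and $i$ denotes the index of $\mathcal{C}^{(l)}_i$ in that layer. The root of the tree (layer $l = 1$) is the full label set $\mathcal{Y}$, and the leaves of the tree are the individual labels of $\mathcal{Y}$.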

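Every layer of the model has a ranker model that scores the relevance of a cluster $\mathcal{C}^{(l)}_i$ to a query $\mathbf{x}$. This ranker model may take on different forms, but for this paper we assume that the model is logistic-like. This means that, at the second layer, for example, the relevance of a cluster $\mathcal{C}^{(2)}_i$ is given by

$$f\left(\mathbf{x}, \mathcal{C}^{(2)}_i\right) = \sigma\left(\mathbf{x} \cdot \mathbf{w}^{(2)}_i\right),$$

where $\sigma$ denotes an activation function (e.g., sigmoid) and $\mathbf{w}^{(2)}_i \in \mathbb{R}^d$ denotes a very sparse vector of weight parameters.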

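At subsequent layers, rankers are composed with those of previous layers, mimicking the notion of conditional probability; hence the score of a cluster $\mathcal{C} \subseteq \mathcal{Y}$ is defined by the model as

$$f(\mathbf{x}, \mathcal{C}) = \prod_{\mathbf{w} \in \mathcal{A}(\mathcal{C})} \sigma\left(\mathbf{x} \cdot \mathbf{w}\right),$$

where $\mathcal{A}(\mathcal{C}) = \{\mathbf{w}^{(l)}_i : \mathcal{C} \subseteq \mathcal{C}^{(l)}_i,\ l \neq 1\}$ collects the weight vectors of all tree nodes on the path from $\mathcal{C}$ to the root $\mathcal{Y}$ (including $\mathcal{C}$ and excluding $\mathcal{Y}$). Naturally, this definition extends all the way to the individual labels at the bottom of the tree. For simplicity, we assume here that the leaves of the tree all occur on the same layer, but the techniques described in this paper can be extended to the general case.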
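To make the composition of rankers concrete, here is a minimal Python sketch that scores a cluster by multiplying sigmoid rankers along its path to the root. It assumes sparse vectors are plain `dict`s mapping feature index to value, and that the path is passed in as a list of weight vectors; the helpers `sparse_dot` and `path_score` are hypothetical names for illustration, not part of PECOS.

```python
import math

def sparse_dot(x, w):
    """Dot product of two sparse vectors stored as {index: value} dicts."""
    if len(x) > len(w):
        x, w = w, x  # iterate over the smaller dict, probe the larger one
    return sum(v * w[k] for k, v in x.items() if k in w)

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def path_score(x, path_weights):
    """Product of sigma(x . w) over the weight vectors on the path from a
    cluster up to (but excluding) the root."""
    score = 1.0
    for w in path_weights:
        score *= sigmoid(sparse_dot(x, w))
    return score

# A query with two nonzero features and a path through layers 2 and 3.
x = {0: 1.0, 5: 2.0}
w_layer2 = {0: 0.5, 3: -1.0}   # x . w = 0.5
w_layer3 = {5: -0.25}          # x . w = -0.5
print(path_score(x, [w_layer2, w_layer3]))
```

Because every factor lies in $(0, 1)$, a child can never score higher than its parent, which is what makes greedy pruning of low-scoring subtrees a sensible approximation.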

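As a practical aside, the column weight vectors $\mathbf{w}^{(l)}_i$ for each layer $l$ are stored in a weight matrix

$$\mathbf{W}^{(l)} = \begin{bmatrix} \mathbf{w}^{(l)}_1 & \mathbf{w}^{(l)}_2 & \cdots & \mathbf{w}^{(l)}_{K_l} \end{bmatrix} \in \mathbb{R}^{d \times K_l},$$

where $K_l$ denotes the number of clusters in layer $l$. The tree topology at layer $l$ is usually represented using a cluster indicator matrix $\mathbf{C}^{(l)} \in \{0, 1\}^{K_{l+1} \times K_l}$, defined as

$$\mathbf{C}^{(l)}_{ij} = \operatorname{bool}\left(\mathcal{C}^{(l+1)}_i \subseteq \mathcal{C}^{(l)}_j\right),$$

i.e., it is one when $\mathcal{C}^{(l+1)}_i$ is a child of $\mathcal{C}^{(l)}_j$ in the tree. Here, $\operatorname{bool}(\cdot)$ is $1$ when the condition is true and $0$ otherwise.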
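A minimal sketch of building the cluster indicator matrix, assuming the tree is handed to us as a hypothetical parent array where `parent[i]` is the index of child cluster `i`'s parent in the layer above. The matrix is returned as a dense list of lists for readability; since every row contains exactly one 1, a real implementation would store it sparsely.

```python
def cluster_indicator(parent, num_parents):
    """Build the K_{l+1} x K_l indicator: C[i][j] = 1 iff child i's parent is j."""
    C = [[0] * num_parents for _ in parent]
    for i, j in enumerate(parent):
        C[i][j] = 1
    return C

# Four children: children 0 and 1 belong to parent 0, children 2 and 3 to parent 1.
C = cluster_indicator([0, 0, 1, 1], num_parents=2)
for row in C:
    print(row)

# Column j of C lists the children of parent j.
children = [[i for i, row in enumerate(C) if row[j]] for j in range(2)]
print(children)
```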

Inference and Beam Search

In general, there are two different inference settings:

  1. Batch Inference: inference is performed for a batch of $n$ queries represented by a sparse matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$, where every row of $\mathbf{X}$ is an individual query $\mathbf{x}_i$.
  2. Online Inference: a special case of the batch setting with only one query, i.e., the matrix $\mathbf{X}$ has only one row.

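When performing inference, the XMR model $f$ prescribes a score to all query-cluster pairs $(\mathbf{x}_i, \mathcal{C}^{(l)}_j)$. Hence, in the batch setting, one can define the prediction matrices

$$\mathbf{P}^{(l)}_{ij} = f\left(\mathbf{x}_i, \mathcal{C}^{(l)}_j\right).$$

Batch inference entails collecting the top $k$ most relevant labels (leaves) for each query $\mathbf{x}_i$ and returning their respective prediction scores $\mathbf{P}^{(L)}_{ij}$, where $L$ is the depth of the tree.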

However, exact inference is typically intractable, as it requires searching the entire model tree. To sidestep this issue, models usually use a greedy beam search of the tree as an approximation. For a query $\mathbf{x}_i$, this approach discards any clusters on a given level that do not fall into the top $b$ most relevant clusters examined at that level, where $b$ is the beam size. Hence, instead of $\mathbf{P}^{(l)}$, we compute beamed prediction matrices $\widetilde{\mathbf{P}}^{(l)}$, where each row has only $b$ nonzero entries whose values are equal to their respective counterparts in $\mathbf{P}^{(l)}$.
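The greedy beam search can be sketched in a few lines. This is a simplified single-query illustration, not PECOS's implementation: the tree is assumed to be given as one hypothetical `(weights, parent)` pair per layer below the root, where `weights[j]` is the sparse weight `dict` of cluster `j` and `parent[j]` is its parent's index in the previous layer (the root is cluster `0`). A child's score is its parent's score times the sigmoid ranker, and only the `b` best clusters survive each layer.

```python
import heapq
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sparse_dot(x, w):
    return sum(v * w[k] for k, v in x.items() if k in w)

def beam_search(x, layers, b, k):
    """Return the top-k (leaf index, score) pairs found by a beam of width b."""
    beam = {0: 1.0}  # surviving clusters of the previous layer -> score
    for weights, parent in layers:
        # Score only the children of surviving parents.
        scores = {
            j: beam[p] * sigmoid(sparse_dot(x, weights[j]))
            for j, p in enumerate(parent)
            if p in beam
        }
        # Keep the b most relevant clusters on this layer; drop the rest.
        beam = dict(heapq.nlargest(b, scores.items(), key=lambda kv: kv[1]))
    return heapq.nlargest(k, beam.items(), key=lambda kv: kv[1])

x = {0: 1.0, 1: 1.0}
layers = [
    ([{0: 2.0}, {1: -2.0}], [0, 0]),                           # layer 2: two clusters
    ([{0: 1.0}, {1: 0.5}, {0: 3.0}, {1: 1.0}], [0, 0, 1, 1]),  # layer 3: four leaves
]
print(beam_search(x, layers, b=1, k=1))
```

With `b=1`, the second layer-2 cluster is pruned, so leaves 2 and 3 are never scored; this skipped work is exactly what the mask in the next section encodes.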

Masked Sparse Chunk Multiplication

Our contribution is a method of evaluating masked sparse matrix multiplication that leverages the unique sparsity structure of the beam search to reduce unnecessary traversal, optimize memory locality, and minimize cache misses. The core prediction step of linear XMR tree models is the evaluation of a masked matrix product, i.e.,

$$\mathbf{A}^{(l)} = \mathbf{M}^{(l)} \odot \left(\mathbf{X} \mathbf{W}^{(l)}\right),$$

where $\mathbf{A}^{(l)} \in \mathbb{R}^{n \times K_l}$ denotes ranker activations at layer $l$, $\mathbf{M}^{(l)} \in \{0, 1\}^{n \times K_l}$ denotes a dynamic mask matrix determined by beam search, $\mathbf{X} \in \mathbb{R}^{n \times d}$ is a sparse matrix whose rows correspond to queries in the embedding space, $\mathbf{W}^{(l)} \in \mathbb{R}^{d \times K_l}$ is the sparse weight matrix of our tree model at layer $l$, and $\odot$ denotes entry-wise multiplication.
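As a point of comparison, here is a minimal sketch of the "vanilla" way to evaluate the masked product: for every nonzero mask entry, intersect the query's nonzeros with one CSC-style column of the weight matrix. Rows and columns are assumed to be `dict`s and the mask a list of `(i, j)` pairs; these formats and the name `masked_product_vanilla` are illustrative choices.

```python
def masked_product_vanilla(X_rows, W_cols, mask):
    """Compute M ⊙ (X W) only at masked entries, one column at a time.

    X_rows: list of sparse query rows, each a {feature: value} dict.
    W_cols: list of sparse weight columns, each a {feature: value} dict.
    mask:   iterable of (i, j) pairs where the mask is 1.
    Returns {(i, j): value} for the masked entries.
    """
    A = {}
    for i, j in mask:
        x, w = X_rows[i], W_cols[j]
        # A separate intersection of the query's nonzeros with column j per entry.
        A[(i, j)] = sum(v * w[k] for k, v in x.items() if k in w)
    return A

X_rows = [{0: 1.0, 2: 3.0}]
W_cols = [{0: 2.0}, {2: 1.0, 1: 5.0}, {1: 4.0}]
A = masked_product_vanilla(X_rows, W_cols, mask=[(0, 0), (0, 1)])
print(A)
```

Column 2 is never touched because it is masked out. However, if columns 0 and 1 are siblings, the query's nonzeros are still walked twice, once per column. The chunked layout below removes that repetition.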

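Observations in the paper about the structure of the sparsity of $\mathbf{M}^{(l)}$ and $\mathbf{W}^{(l)}$ lead to the idea of the column chunked matrix data structure for the weight matrix $\mathbf{W}^{(l)}$. In this data structure, we store the matrix as a horizontal array of matrix chunks $\mathbf{K}^{(l)}_i$,

$$\mathbf{W}^{(l)} = \begin{bmatrix} \mathbf{K}^{(l)}_1 & \mathbf{K}^{(l)}_2 & \cdots & \mathbf{K}^{(l)}_{K_{l-1}} \end{bmatrix},$$

where each chunk $\mathbf{K}^{(l)}_i \in \mathbb{R}^{d \times B}$ ($B$ is the branching factor, i.e., the number of children of the parent node) is stored as a vertical sparse array of sparse horizontal vectors $\mathbf{v}^{(l)}_{j,i} \in \mathbb{R}^{1 \times B}$, of which only the nonzero rows are actually stored:

$$\mathbf{K}^{(l)}_i = \begin{bmatrix} \mathbf{v}^{(l)}_{1,i} \\ \mathbf{v}^{(l)}_{2,i} \\ \vdots \\ \mathbf{v}^{(l)}_{d,i} \end{bmatrix}.$$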

We identify each chunk $\mathbf{K}^{(l)}_i$ with a parent node $\mathcal{C}^{(l-1)}_i$ in layer $l-1$ of the model tree, and the columns of the chunk with the set of siblings in layer $l$ of the model tree that share that parent node.
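A minimal sketch of converting per-column sparse weights into the column chunked layout. It assumes, hypothetically, that the children of each parent occupy contiguous columns and that every parent has exactly `B` children. Each chunk maps a nonzero row index to a sparse horizontal vector `{column offset within chunk: value}`. This dict-of-rows layout is only an illustrative stand-in for the sparse arrays described above.

```python
def to_column_chunked(W_cols, B):
    """Group CSC-style columns into chunks of B sibling columns.

    W_cols: list of sparse columns ({row: value} dicts); columns
            j*B .. j*B + B - 1 are assumed to be the children of parent j.
    Returns one chunk per parent: {row index: {column offset: value}}.
    """
    assert len(W_cols) % B == 0, "assumes a uniform branching factor B"
    chunks = []
    for start in range(0, len(W_cols), B):
        chunk = {}
        for offset in range(B):
            for k, value in W_cols[start + offset].items():
                chunk.setdefault(k, {})[offset] = value
        # Store rows in sorted order so they can be scanned or binary-searched.
        chunks.append(dict(sorted(chunk.items())))
    return chunks

W_cols = [{0: 1.0, 4: 2.0}, {4: 3.0}, {1: 5.0}, {}]
chunks = to_column_chunked(W_cols, B=2)
print(chunks)
```

Row 4 is nonzero in both sibling columns of the first chunk, and after chunking both values live in a single row entry. This sharing is what lets one lookup serve a whole block of outputs.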

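To see why this data structure can accelerate the masked matrix multiplication, note that the mask matrix $\mathbf{M}^{(l)}$ can be viewed as composed of blocks,

$$\mathbf{M}^{(l)} = \begin{bmatrix} \mathbf{M}^{(l)}_{11} & \cdots & \mathbf{M}^{(l)}_{1K_{l-1}} \\ \vdots & \ddots & \vdots \\ \mathbf{M}^{(l)}_{n1} & \cdots & \mathbf{M}^{(l)}_{nK_{l-1}} \end{bmatrix},$$

where the block column partition is the same as that in the definition of the column chunked weight matrix $\mathbf{W}^{(l)}$, and every block $\mathbf{M}^{(l)}_{ij} \in \{0, 1\}^{1 \times B}$ has one row and corresponds to a single query. Furthermore, every block must be composed either entirely of zeros or entirely of ones: beam search either keeps the parent cluster $\mathcal{C}^{(l-1)}_j$ for query $\mathbf{x}_i$, in which case all of its children are scored, or discards it, in which case none of them are.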
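The all-zeros-or-all-ones property can be seen by building the mask directly from the beam. A minimal sketch, assuming the beam for each query is given as a set of surviving parent indices from the previous layer and that the children of parent `j` occupy columns `j*B` through `j*B + B - 1` (both hypothetical conventions):

```python
def mask_from_beam(beams, num_parents, B):
    """Dense mask: row i has ones exactly on the B columns of each parent
    kept in query i's beam at the previous layer."""
    M = []
    for kept in beams:
        row = []
        for j in range(num_parents):
            row.extend([1 if j in kept else 0] * B)
        M.append(row)
    return M

M = mask_from_beam([{0}, {1, 2}], num_parents=3, B=2)
# Every length-B block of every row is constant: all zeros or all ones.
blocks = [row[j * 2:(j + 1) * 2] for row in M for j in range(3)]
print(all(len(set(blk)) == 1 for blk in blocks))
```

Because each block is constant, the kernel only needs one bit per (query, parent) pair rather than one per column.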

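Therefore, since $\mathbf{A}^{(l)}$ and $\mathbf{M}^{(l)}$ share the same sparsity pattern, the ranker activation matrix $\mathbf{A}^{(l)}$ is also composed of the same block partition as $\mathbf{M}^{(l)}$,

$$\mathbf{A}^{(l)} = \begin{bmatrix} \mathbf{A}^{(l)}_{11} & \cdots & \mathbf{A}^{(l)}_{1K_{l-1}} \\ \vdots & \ddots & \vdots \\ \mathbf{A}^{(l)}_{n1} & \cdots & \mathbf{A}^{(l)}_{nK_{l-1}} \end{bmatrix}.$$

Hence, for all mask blocks $\mathbf{M}^{(l)}_{ij}$ that are all ones, we have

$$\mathbf{A}^{(l)}_{ij} = \mathbf{x}_i \mathbf{K}^{(l)}_j = \sum_{k \in \mathcal{S}(\mathbf{x}_i) \cap \mathcal{S}(\mathbf{K}^{(l)}_j)} x_{ik}\, \mathbf{v}^{(l)}_{k,j},$$

where $\mathcal{S}(\mathbf{x}_i)$ and $\mathcal{S}(\mathbf{K}^{(l)}_j)$ denote the indices of the nonzero entries of $\mathbf{x}_i$ and the nonzero rows of $\mathbf{K}^{(l)}_j$, respectively. The above equation says that for all entries of $\mathbf{A}^{(l)}_{ij}$ in the same block, the intersection $\mathcal{S}(\mathbf{x}_i) \cap \mathcal{S}(\mathbf{K}^{(l)}_j)$ only needs to be iterated through once per chunk, as opposed to once per column as in a vanilla implementation. Moreover, the values actively participating in the product are physically closer in memory than they are when $\mathbf{W}^{(l)}$ is stored in CSC format, which improves memory locality.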
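The block formula translates almost directly into code. A single pass over the feature indices shared by the query and the chunk's nonzero rows accumulates into all `B` outputs of the block at once. This sketch assumes the query is a `dict` and the chunk maps each nonzero row index to a sparse row `{column offset: value}`. The intersection here is found with dictionary lookups, which is only one of several possible strategies. The name `vector_chunk_product` is hypothetical.

```python
def vector_chunk_product(x, chunk, B):
    """Compute the 1 x B activation block x_i K_j in one pass.

    x:     sparse query {feature: value}.
    chunk: {row index k: {column offset: value}}, nonzero rows only.
    """
    out = [0.0] * B
    for k, x_k in x.items():
        row = chunk.get(k)
        if row is None:  # k is not a nonzero row of the chunk
            continue
        # One visit to a shared row updates every sibling column of the block.
        for offset, w in row.items():
            out[offset] += x_k * w
    return out

x = {0: 1.0, 4: 2.0, 7: -1.0}
chunk = {0: {0: 1.0}, 4: {0: 2.0, 1: 3.0}}
print(vector_chunk_product(x, chunk, B=2))
```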

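It remains to specify how to efficiently iterate over the intersection of the nonzero entries $\mathcal{S}(\mathbf{x}_i)$ of $\mathbf{x}_i$ and the nonzero rows $\mathcal{S}(\mathbf{K}^{(l)}_j)$ of $\mathbf{K}^{(l)}_j$. This is essential for computing the vector-chunk product $\mathbf{x}_i \mathbf{K}^{(l)}_j$ efficiently. There are a number of ways to do this, each with potential benefits and drawbacks: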

  1. Marching Pointers: The simplest method is to use a marching pointer scheme over $\mathcal{S}(\mathbf{x}_i)$ and $\mathcal{S}(\mathbf{K}^{(l)}_j)$, advancing whichever pointer currently points to the smaller index until either list is exhausted.
  2. Binary Search: Since $\mathbf{x}_i$ can be highly sparse, the second possibility is to march pointers as above, but instead of incrementing the pointers one by one to find all intersections, use binary search to quickly find the next intersection.
  3. Hash-Map: The third possibility is to enable fast random access to the rows of $\mathbf{K}^{(l)}_j$ via a hash-map that maps indices $k \in \mathcal{S}(\mathbf{K}^{(l)}_j)$ to the corresponding nonzero rows of $\mathbf{K}^{(l)}_j$. One then iterates over $\mathcal{S}(\mathbf{x}_i)$ and looks up each index.
  4. Dense Lookup: The last possibility is to accelerate the hash-map access above by copying the contents of the hash-map into a dense array of length $d$. A random access to a row of $\mathbf{K}^{(l)}_j$ is then a single array lookup. This consumes the most memory.
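The four strategies differ only in how they locate the shared indices. Below is a hedged sketch of each, written against sorted index lists: `xi` holds the query's sorted nonzero feature indices and `ki` holds the chunk's sorted nonzero row indices. This representation and the function names are assumptions for illustration. Each function returns position pairs `(p, q)` with `xi[p] == ki[q]`, which is what a vector-chunk product needs to fetch the matching values.

```python
import bisect

def intersect_marching(xi, ki):
    """Advance whichever pointer holds the smaller index."""
    p, q, out = 0, 0, []
    while p < len(xi) and q < len(ki):
        if xi[p] == ki[q]:
            out.append((p, q))
            p += 1
            q += 1
        elif xi[p] < ki[q]:
            p += 1
        else:
            q += 1
    return out

def intersect_binary(xi, ki):
    """Jump through ki with binary search instead of stepping one at a time."""
    q, out = 0, []
    for p, k in enumerate(xi):
        q = bisect.bisect_left(ki, k, q)  # search only past the last position
        if q == len(ki):
            break
        if ki[q] == k:
            out.append((p, q))
    return out

def intersect_hash(xi, ki):
    """Hash-map from row index to its position in the chunk."""
    pos = {k: q for q, k in enumerate(ki)}  # built once per chunk in practice
    return [(p, pos[k]) for p, k in enumerate(xi) if k in pos]

def intersect_dense(xi, ki, d):
    """Dense array of length d (the feature dimension); -1 marks absent rows."""
    pos = [-1] * d  # built once per chunk in practice; O(d) memory
    for q, k in enumerate(ki):
        pos[k] = q
    return [(p, pos[k]) for p, k in enumerate(xi) if pos[k] >= 0]

xi, ki = [1, 4, 7, 9], [0, 4, 5, 9, 12]
print(intersect_marching(xi, ki))
```

All four return the same pairs. The trade-offs are in cost: marching pointers touch every index of both lists; binary search pays a logarithmic cost per query nonzero but can skip long runs of chunk rows; the hash-map and dense versions turn each query nonzero into a single lookup, with the dense array trading memory proportional to $d$ per chunk for the cheapest lookup.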

There is one final optimization that we have found particularly helpful in reducing inference time: evaluating the nonzero blocks $\mathbf{A}^{(l)}_{ij}$ in order of column chunk $j$. This ensures that a single chunk $\mathbf{K}^{(l)}_j$ ideally only has to enter the cache once for all the nonzero blocks $\mathbf{A}^{(l)}_{ij}$ whose values depend on it.
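A minimal sketch of that loop order. It assumes the nonzero mask blocks arrive as a hypothetical list of `(query i, chunk j)` pairs. The blocks are grouped by chunk, and every block that reads the same chunk is evaluated back to back, so the chunk can stay resident in cache across those queries. Python will not show the cache effect itself; the point is the iteration order.

```python
from collections import defaultdict

def evaluate_in_chunk_order(blocks, X_rows, chunks, B):
    """Evaluate nonzero blocks A_ij grouped by chunk j.

    blocks: iterable of (i, j) pairs whose mask block is all ones.
    X_rows: sparse queries {feature: value}.
    chunks: per-parent chunks {row index: {column offset: value}}.
    Returns {(i, j): list of B activations}, in evaluation order.
    """
    by_chunk = defaultdict(list)
    for i, j in blocks:
        by_chunk[j].append(i)
    A = {}
    for j in sorted(by_chunk):      # visit each chunk once...
        chunk = chunks[j]
        for i in by_chunk[j]:       # ...for every query that needs it
            out = [0.0] * B
            for k, x_k in X_rows[i].items():
                for offset, w in chunk.get(k, {}).items():
                    out[offset] += x_k * w
            A[(i, j)] = out
    return A

X_rows = [{0: 1.0}, {1: 2.0}]
chunks = [{0: {0: 1.0, 1: 2.0}}, {1: {1: 3.0}}]
A = evaluate_in_chunk_order([(0, 1), (1, 0), (0, 0), (1, 1)], X_rows, chunks, B=2)
print(list(A.items()))
```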

Performance Benchmarks

The paper contains more thorough performance benchmarks. For this blog post, I will simply show a side-by-side comparison of our Hash MSCM implementation against a Hash CSC implementation from an existing XMR library, NapkinXC. Since NapkinXC models can be run in PECOS (our library), this is an apples-to-apples performance comparison. We measure performance on several different data sets:

Thanks to MSCM, our implementation is around 10x faster than NapkinXC. See the full text for a more exhaustive performance analysis.