25 Nov 2018

Pointer Networks

The paper can be found here

What is this about?

In layman's terms, pointer networks are a class of neural networks obtained by a simple modification of attention-based models. They can be used to solve combinatorial optimization problems such as sorting, finding the convex hull, or the travelling salesman problem.

The key contributions of this approach are twofold:

  • A model whose output sequence consists of indexes (discrete tokens) corresponding to the input elements.
  • The size of the output dictionary (the set of elements that can be output) depends on the size of the input (the number of input elements).

A standard seq2seq model, such as an LSTM encoder-decoder, cannot solve these problems directly. Its output dictionary has a fixed size, whereas here the output dictionary depends on the number of input elements. A separate model would therefore have to be trained for each n.

Formal Problem Definition & the Model Architecture

Problem Statement:

Given a sequence P of n vectors, find a sequence Cp of m(P) indices, each between 1 and n, that solves the problem.

Example: Sorting

P = {2,3,8,1,5} (here P is a sequence of 5 vectors of size 1)
Output: Cp = {4,1,2,5,3}. These are the (1-based) indices of the elements of P, listed in ascending order of their values.
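The target sequence for the sorting example can be produced with an argsort that returns 1-based indices. Below is a minimal sketch; the function name `sort_targets` is hypothetical. It builds the index sequence a pointer network would be trained to emit, and it does not implement a model.

```python
def sort_targets(values):
    """Return the 1-based indices of `values` in ascending order of value.

    This is the kind of target sequence Cp a pointer network would be
    trained to output for the sorting task (ties keep their input order,
    since Python's sort is stable).
    """
    order = sorted(range(len(values)), key=lambda i: values[i])
    return [i + 1 for i in order]


print(sort_targets([2, 3, 8, 1, 5]))  # [4, 1, 2, 5, 3]
```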

The above can be modelled with an encoder-decoder approach using LSTMs or similar recurrent networks. Before moving on to pointer networks, it helps to recall the attention-based encoder-decoder model, since pointer networks follow from it almost directly.

Attention-based seq2seq models compute an attention vector at each output step. This vector holds one weight per hidden state produced by the encoder. The weighted sum of the encoder hidden states is then combined with the decoder's hidden state to produce the output. Each weight is a score computed by a small MLP from the decoder's current hidden state and the encoder hidden state under consideration. A softmax over these scores then yields a normalized distribution.
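The same computation can be written in the notation commonly used for this additive (content-based) attention. Here e_j are the encoder hidden states, d_i is the decoder hidden state at output step i, and v, W_1, W_2 are learnable parameters. The weighted sum d'_i is then combined with d_i, for example by concatenation, to predict the output.

$$
\begin{aligned}
u^i_j &= v^\top \tanh\left(W_1 e_j + W_2 d_i\right), \qquad j \in \{1, \dots, n\} \\
a^i_j &= \operatorname{softmax}\left(u^i\right)_j \\
d'_i &= \sum_{j=1}^{n} a^i_j \, e_j
\end{aligned}
$$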

The softmax turns the attention vector into a distribution over the input positions. Pointer networks do not use this normalized attention vector to blend the encoder states into a context vector. Instead, they use it directly as a pointer to an input element, i.e. as the output distribution itself. So, technically, pointer networks are a tweak that removes a step from the traditional attention-based seq2seq model. This makes it possible to solve problems whose outputs correspond to positions in the input.
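To make the difference concrete, here is a minimal pure-Python sketch of a single pointer-network output step. It assumes the additive score u_j = v^T tanh(W1 e_j + W2 d_i). The helper names and random parameters are hypothetical stand-ins for what a trained model (e.g. with LSTM encoder and decoder) would provide. The key point is that the output distribution has exactly one entry per input element, whatever n is.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W, x):
    """Multiply a matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def pointer_distribution(encoder_states, decoder_state, W1, W2, v):
    """One pointer-network output step.

    The scores u_j = v^T tanh(W1 e_j + W2 d_i) are the same as in additive
    attention, but their softmax is returned directly as the output
    distribution over input positions j = 0..n-1 (no weighted sum of e_j).
    """
    w2d = matvec(W2, decoder_state)
    scores = []
    for e_j in encoder_states:
        hidden = [math.tanh(a + b) for a, b in zip(matvec(W1, e_j), w2d)]
        scores.append(sum(vk * hk for vk, hk in zip(v, hidden)))
    return softmax(scores)


def rand_matrix(rows, cols):
    """Hypothetical random values standing in for learned parameters/states."""
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
H = 4  # hidden size (assumed)
W1, W2 = rand_matrix(H, H), rand_matrix(H, H)
v = rand_matrix(1, H)[0]
d_i = rand_matrix(1, H)[0]  # decoder state at some output step i

# The same parameters handle inputs of any length n: the output
# "dictionary" is simply the n input positions.
for n in (3, 5):
    probs = pointer_distribution(rand_matrix(n, H), d_i, W1, W2, v)
    pointer = max(range(n), key=probs.__getitem__)  # greedy pick
    print(n, len(probs), pointer)
```

A full decoder would repeat this step, feeding the chosen input element back in at each step. Greedy selection is used here only for simplicity.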

Sample Applications

Sorting: While the authors don't write about sorting, it is one of the simplest combinatorial optimization problems that comes to mind. Several projects on GitHub implement exactly this and report good results.

Note: An interesting finding is that the order in which the inputs are fed to the network matters. So for the problems below, the inputs were fed in a specific order (in cases where reordering the inputs does not change the problem, of course).

Convex Hull: Exact algorithms solve this problem in O(n log n) time. The pointer network can be considered O(n^2), since at each output step it goes over all n inputs to build the attention vector, and there are at most n output steps. The authors used this problem as the primary benchmark while developing pointer networks, and the results they report are promising.
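For comparison with the model's output format, here is a sketch of a classical exact O(n log n) algorithm. It uses Andrew's monotone chain, chosen here for brevity; the paper does not prescribe a particular one. The algorithm returns the hull as a sequence of 1-based input indices, the same kind of output a pointer network produces, so a routine like this could generate training targets. Two conventions are assumptions: the hull starts from the lowest (x, y) point and runs counter-clockwise, and collinear boundary points are dropped.

```python
def cross(o, a, b):
    """z-component of (a - o) x (b - o); > 0 means a counter-clockwise turn."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def convex_hull_indices(points):
    """Andrew's monotone chain, O(n log n).

    Returns 1-based indices of the hull vertices in counter-clockwise order,
    starting from the lowest (x, y) point.
    """
    order = sorted(range(len(points)), key=lambda i: points[i])
    if len(order) <= 2:
        return [i + 1 for i in order]

    def half_hull(indices):
        hull = []
        for i in indices:
            # Pop while the last two hull points and the new one do not turn left.
            while len(hull) >= 2 and cross(points[hull[-2]], points[hull[-1]], points[i]) <= 0:
                hull.pop()
            hull.append(i)
        return hull

    lower = half_hull(order)
    upper = half_hull(reversed(order))
    # Each chain's last point is the first point of the other chain.
    return [i + 1 for i in lower[:-1] + upper[:-1]]


# A square with one interior point (index 3, 1-based).
print(convex_hull_indices([(0, 0), (2, 0), (1, 1), (2, 2), (0, 2)]))  # [1, 2, 4, 5]
```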

Delaunay Triangulation: This is also a well-understood problem, and exact algorithms again run in O(n log n). While the paper claims good results for this too, the results have not been shared explicitly.
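A Delaunay triangulation is defined by the empty-circumcircle property: no input point lies strictly inside the circumcircle of any triangle. That property gives a simple way to check a predicted set of triangles. Below is a brute-force sketch, not taken from the paper; the function names are hypothetical and triangles are assumed to be 0-based index triples. It only tests the Delaunay property and does not verify that the triangles tile the convex hull.

```python
def orient(a, b, c):
    """> 0 if a, b, c are in counter-clockwise order, < 0 if clockwise."""
    return (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])


def in_circumcircle(a, b, c, d):
    """True if d lies strictly inside the circumcircle of triangle abc."""
    rows = []
    for p in (a, b, c):
        dx, dy = p[0] - d[0], p[1] - d[1]
        rows.append((dx, dy, dx * dx + dy * dy))
    (ax, ay, aw), (bx, by, bw), (cx, cy, cw) = rows
    det = (ax * (by * cw - bw * cy)
           - ay * (bx * cw - bw * cx)
           + aw * (bx * cy - by * cx))
    # det > 0 means "inside" for a CCW triangle; the sign flips for CW.
    return det * orient(a, b, c) > 0


def is_delaunay(points, triangles):
    """Brute-force empty-circumcircle check, O(n * number of triangles)."""
    for tri in triangles:
        a, b, c = (points[i] for i in tri)
        for k, p in enumerate(points):
            if k not in tri and in_circumcircle(a, b, c, p):
                return False
    return True


pts = [(0, 0), (2, 0), (0, 2), (3, 3)]
print(is_delaunay(pts, [(0, 1, 2), (1, 3, 2)]))  # True
print(is_delaunay(pts, [(0, 1, 3), (0, 3, 2)]))  # False: point 2 is inside circle(0, 1, 3)
```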

Travelling Salesman Problem: This is an interesting one. The previous problems can be solved exactly in O(n log n), but TSP is NP-hard. Even so, the O(n^2) pointer network does quite well on it: it outperforms some heuristic-based algorithms and is competitive with others.
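For small instances the optimal tour can be computed exactly, which is useful as a reference when judging a model's tours. Below is a brute-force sketch; the function names are hypothetical. Its cost is O(n!), so it is only practical for very small n, and a dynamic-programming solver or a heuristic would be needed beyond that. The tour is reported as 1-based indices starting at city 1, matching the Cp notation above.

```python
import itertools
import math


def tour_length(points, tour):
    """Length of the closed tour visiting points in the given (0-based) order."""
    return sum(math.dist(points[tour[i]], points[tour[(i + 1) % len(tour)]])
               for i in range(len(tour)))


def brute_force_tsp(points):
    """Exact TSP by enumerating every tour that starts at city 0 (small n only)."""
    n = len(points)
    best_tour, best_len = None, math.inf
    for perm in itertools.permutations(range(1, n)):
        tour = (0,) + perm
        length = tour_length(points, tour)
        if length < best_len:
            best_tour, best_len = tour, length
    return [i + 1 for i in best_tour], best_len


print(brute_force_tsp([(0, 0), (1, 0), (1, 1), (0, 1)]))  # ([1, 2, 3, 4], 4.0)
```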

The same model was used for all of the above problems, and the hyperparameters were not fine-tuned for any particular one of them. This also demonstrates the generality of the model.

More interesting reading on attention-augmented models can be found here
