Added wasserstein_distance() dev
authorDmitriy Morozov <dmitriy@mrzv.org>
Thu, 10 May 2012 15:42:21 -0700
branchdev
changeset 247 ad3aefb5a0e0
parent 246 88f7806633e0
child 248 45589223382c
Added wasserstein_distance()
bindings/python/CMakeLists.txt
bindings/python/persistence-diagram.cpp
include/topology/persistence-diagram.h
include/topology/persistence-diagram.hpp
include/utilities/munkres/matrix.cpp
include/utilities/munkres/matrix.h
include/utilities/munkres/munkres.cpp
include/utilities/munkres/munkres.h
--- a/bindings/python/CMakeLists.txt	Thu May 10 14:24:40 2012 -0700
+++ b/bindings/python/CMakeLists.txt	Thu May 10 15:42:21 2012 -0700
@@ -18,6 +18,7 @@
                                                 cohomology-persistence.cpp
                                                 rips.cpp
                                                 distances.cpp
+                                                ../../include/utilities/munkres/munkres.cpp
                             )
 set                         (bindings_libraries ${libraries})
 
--- a/bindings/python/persistence-diagram.cpp	Thu May 10 14:24:40 2012 -0700
+++ b/bindings/python/persistence-diagram.cpp	Thu May 10 15:42:21 2012 -0700
@@ -223,4 +223,5 @@
                                       bp::arg("data")=bp::object()));
 
     bp::def("bottleneck_distance",  &bottleneck_distance_adapter);
+    bp::def("wasserstein_distance", &wasserstein_distance<dp::PersistenceDiagramD>);
 }
--- a/include/topology/persistence-diagram.h	Thu May 10 14:24:40 2012 -0700
+++ b/include/topology/persistence-diagram.h	Thu May 10 15:42:21 2012 -0700
@@ -169,6 +169,9 @@
 RealType                bottleneck_distance(const Diagram1& dgm1, const Diagram2& dgm2)
 { return bottleneck_distance(dgm1, dgm2, Linfty<typename Diagram1::Point, typename Diagram2::Point>()); }
 
+template<class Diagram>
+RealType                wasserstein_distance(const Diagram& dgm1, const Diagram& dgm2, unsigned p);
+
 
 #include "persistence-diagram.hpp"
 
--- a/include/topology/persistence-diagram.hpp	Thu May 10 14:24:40 2012 -0700
+++ b/include/topology/persistence-diagram.hpp	Thu May 10 15:42:21 2012 -0700
@@ -1,6 +1,8 @@
 #include <boost/serialization/vector.hpp>
 #include <boost/serialization/nvp.hpp>
 
+#include "utilities/munkres/munkres.h"
+
 using boost::serialization::make_nvp;
 
 template<class D>
@@ -19,7 +21,7 @@
 PersistenceDiagram(const PersistenceDiagram<OtherData>& other)
 {
     points_.reserve(other.size());
-    for (typename PersistenceDiagram<OtherData>::PointVector::const_iterator cur = points_.begin(); 
+    for (typename PersistenceDiagram<OtherData>::PointVector::const_iterator cur = points_.begin();
                                                                              cur != points_.end(); ++cur)
         push_back(Point(cur->x(), cur->y()));
 }
@@ -73,8 +75,8 @@
 
 template<class Diagrams, class Iterator, class Evaluator, class DimensionExtractor>
 void    init_diagrams(Diagrams& diagrams,
-                      Iterator bg, Iterator end, 
-                      const Evaluator& evaluator, 
+                      Iterator bg, Iterator end,
+                      const Evaluator& evaluator,
                       const DimensionExtractor& dimension)
 {
     // FIXME: this is specialized for Diagrams that is std::map
@@ -85,8 +87,8 @@
 
 template<class Diagrams, class Iterator, class Evaluator, class DimensionExtractor, class Visitor>
 void    init_diagrams(Diagrams& diagrams,
-                      Iterator bg, Iterator end, 
-                      const Evaluator& evaluator, 
+                      Iterator bg, Iterator end,
+                      const Evaluator& evaluator,
                       const DimensionExtractor& dimension,
                       const Visitor& visitor)
 {
@@ -114,18 +116,18 @@
 
 template<class D>
 template<class Archive>
-void 
+void
 PDPoint<D>::
 serialize(Archive& ar, version_type )
 {
-    ar & make_nvp("x", x()); 
-    ar & make_nvp("y", y()); 
+    ar & make_nvp("x", x());
+    ar & make_nvp("y", y());
     ar & make_nvp("data", data());
 }
 
 template<class D>
 template<class Archive>
-void 
+void
 PersistenceDiagram<D>::
 serialize(Archive& ar, version_type )
 {
@@ -134,7 +136,7 @@
 
 
 /**
- * Some structures to compute bottleneck distance between two persistence diagrams (in bottleneck_distance() function below) 
+ * Some structures to compute bottleneck distance between two persistence diagrams (in bottleneck_distance() function below)
  * by setting up bipartite graphs, and finding maximum cardinality matchings in them using Boost Graph Library.
  */
 #include <boost/iterator/counting_iterator.hpp>
@@ -172,7 +174,7 @@
 
         // FIXME: the matching is being recomputed from scratch every time, this should be fixed
         if (i2 > last)
-            do 
+            do
             {
                 ++last;
                 boost::add_edge(last->first, last->second, g);
@@ -238,10 +240,87 @@
 
     // Perform cardinality based binary search
     typedef boost::counting_iterator<EV_const_iterator>         EV_counting_const_iterator;
-    EV_counting_const_iterator bdistance = std::upper_bound(EV_counting_const_iterator(edges.begin()), 
-                                                            EV_counting_const_iterator(edges.end()), 
+    EV_counting_const_iterator bdistance = std::upper_bound(EV_counting_const_iterator(edges.begin()),
+                                                            EV_counting_const_iterator(edges.end()),
                                                             edges.begin(),
                                                             CardinaliyComparison(max_size, edges.begin()));
 
     return (*bdistance)->distance;
 }
+
+// Wasserstein distance
+template<class Diagram>
+RealType
+wasserstein_distance(const Diagram& dgm1, const Diagram& dgm2, unsigned p)
+{
+    typedef         RealType                    Distance;
+    typedef         typename Diagram::Point     Point;
+    typedef         Linfty<Point, Point>        Norm;
+
+    unsigned size = dgm1.size() + dgm2.size();
+    Norm norm;
+
+    // Setup the matrix
+    Matrix<Distance>        m(size,size);
+    for (unsigned i = 0; i < dgm1.size(); ++i)
+        for (unsigned j = 0; j < dgm2.size(); ++j)
+        {
+            const Point& p1 = *(dgm1.begin() + i);
+            const Point& p2 = *(dgm2.begin() + j);
+            m(i,j) = pow(norm(p1, p2),  p);
+            m(j + dgm1.size(), i + dgm2.size()) = 0;
+        }
+
+    for (unsigned i = 0; i < dgm1.size(); ++i)
+        for (unsigned j = dgm2.size(); j < size; ++j)
+        {
+            const Point& p1 = *(dgm1.begin() + i);
+            m(i,j) = pow(norm.diagonal(p1), p);
+        }
+
+    for (unsigned j = 0; j < dgm2.size(); ++j)
+        for (unsigned i = dgm1.size(); i < size; ++i)
+        {
+            const Point& p2 = *(dgm2.begin() + j);
+            m(i,j) = pow(norm.diagonal(p2), p);
+        }
+
+    // Compute weighted matching
+    Munkres munkres;
+    munkres.solve(m);
+
+    // Assume everything is assigned (i.e., that we have a perfect matching)
+    Distance sum = 0;
+    for (unsigned i = 0; i < size; i++)
+        for (unsigned j = 0; j < size; j++)
+            if (m(i,j) == 0)
+            {
+                //std::cout << i << ": " << j << '\n';
+                //sum += m[i][j];
+                if (i >= dgm1.size())
+                {
+                    if (j >= dgm2.size())
+                        sum += 0;
+                    else
+                    {
+                        const Point& p2 = *(dgm2.begin() + j);
+                        sum += pow(norm.diagonal(p2), p);
+                    }
+                } else
+                {
+                    if (j >= dgm2.size())
+                    {
+                        const Point& p1 = *(dgm1.begin() + i);
+                        sum += pow(norm.diagonal(p1), p);
+                    } else
+                    {
+                        const Point& p1 = *(dgm1.begin() + i);
+                        const Point& p2 = *(dgm2.begin() + j);
+                        sum += pow(norm(p1, p2),  p);
+                    }
+                }
+                break;
+            }
+
+    return sum;
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/include/utilities/munkres/matrix.cpp	Thu May 10 15:42:21 2012 -0700
@@ -0,0 +1,231 @@
+/*
+ *   Copyright (c) 2007 John Weaver
+ *
+ *   This program is free software; you can redistribute it and/or modify
+ *   it under the terms of the GNU General Public License as published by
+ *   the Free Software Foundation; either version 2 of the License, or
+ *   (at your option) any later version.
+ *
+ *   This program is distributed in the hope that it will be useful,
+ *   but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *   GNU General Public License for more details.
+ *
+ *   You should have received a copy of the GNU General Public License
+ *   along with this program; if not, write to the Free Software
+ *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
+ */
+
+#include "matrix.h"
+
+#include <cassert>
+#include <cstdlib>
+#include <algorithm>
+
+/*export*/ template <class T>
+Matrix<T>::Matrix() {
+	m_rows = 0;
+	m_columns = 0;
+	m_matrix = NULL;
+}
+
+/*export*/ template <class T>
+Matrix<T>::Matrix(const Matrix<T> &other) {
+	if ( other.m_matrix != NULL ) {
+		// copy arrays
+		m_matrix = NULL;
+		resize(other.m_rows, other.m_columns);
+		for ( int i = 0 ; i < m_rows ; i++ )
+			for ( int j = 0 ; j < m_columns ; j++ )
+				m_matrix[i][j] = other.m_matrix[i][j];
+	} else {
+		m_matrix = NULL;
+		m_rows = 0;
+		m_columns = 0;
+	}
+}
+
+/*export*/ template <class T>
+Matrix<T>::Matrix(int rows, int columns) {
+	m_matrix = NULL;
+	resize(rows, columns);
+}
+
+/*export*/ template <class T>
+Matrix<T> &
+Matrix<T>::operator= (const Matrix<T> &other) {
+	if ( other.m_matrix != NULL ) {
+		// copy arrays
+		resize(other.m_rows, other.m_columns);
+		for ( int i = 0 ; i < m_rows ; i++ )
+			for ( int j = 0 ; j < m_columns ; j++ )
+				m_matrix[i][j] = other.m_matrix[i][j];
+	} else {
+		// free arrays
+		for ( int i = 0 ; i < m_columns ; i++ )
+			delete [] m_matrix[i];
+
+		delete [] m_matrix;
+
+		m_matrix = NULL;
+		m_rows = 0;
+		m_columns = 0;
+	}
+	
+	return *this;
+}
+
+/*export*/ template <class T>
+Matrix<T>::~Matrix() {
+	if ( m_matrix != NULL ) {
+		// free arrays
+		for ( int i = 0 ; i < m_rows ; i++ )
+			delete [] m_matrix[i];
+
+		delete [] m_matrix;
+	}
+	m_matrix = NULL;
+}
+
+/*export*/ template <class T>
+void
+Matrix<T>::resize(int rows, int columns) {
+	if ( m_matrix == NULL ) {
+		// alloc arrays
+		m_matrix = new T*[rows]; // rows
+		for ( int i = 0 ; i < rows ; i++ )
+			m_matrix[i] = new T[columns]; // columns
+
+		m_rows = rows;
+		m_columns = columns;
+		clear();
+	} else {
+		// save array pointer
+		T **new_matrix;
+		// alloc new arrays
+		new_matrix = new T*[rows]; // rows
+		for ( int i = 0 ; i < rows ; i++ ) {
+			new_matrix[i] = new T[columns]; // columns
+			for ( int j = 0 ; j < columns ; j++ )
+				new_matrix[i][j] = 0;
+		}
+
+		// copy data from saved pointer to new arrays
+		int minrows = std::min<int>(rows, m_rows);
+		int mincols = std::min<int>(columns, m_columns);
+		for ( int x = 0 ; x < minrows ; x++ )
+			for ( int y = 0 ; y < mincols ; y++ )
+				new_matrix[x][y] = m_matrix[x][y];
+
+		// delete old arrays
+		if ( m_matrix != NULL ) {
+			for ( int i = 0 ; i < m_rows ; i++ )
+				delete [] m_matrix[i];
+
+			delete [] m_matrix;
+		}
+
+		m_matrix = new_matrix;
+	}
+
+	m_rows = rows;
+	m_columns = columns;
+}
+
+/*export*/ template <class T>
+void
+Matrix<T>::identity() {
+	assert( m_matrix != NULL );
+
+	clear();
+
+	int x = std::min<int>(m_rows, m_columns);
+	for ( int i = 0 ; i < x ; i++ )
+		m_matrix[i][i] = 1;
+}
+
+/*export*/ template <class T>
+void
+Matrix<T>::clear() {
+	assert( m_matrix != NULL );
+
+	for ( int i = 0 ; i < m_rows ; i++ )
+		for ( int j = 0 ; j < m_columns ; j++ )
+			m_matrix[i][j] = 0;
+}
+
+/*export*/ template <class T>
+T 
+Matrix<T>::trace() {
+	assert( m_matrix != NULL );
+
+	T value = 0;
+
+	int x = std::min<int>(m_rows, m_columns);
+	for ( int i = 0 ; i < x ; i++ )
+		value += m_matrix[i][i];
+
+	return value;
+}
+
+/*export*/ template <class T>
+Matrix<T>& 
+Matrix<T>::transpose() {
+	assert( m_rows > 0 );
+	assert( m_columns > 0 );
+
+	int new_rows = m_columns;
+	int new_columns = m_rows;
+
+	if ( m_rows != m_columns ) {
+		// expand matrix
+		int m = std::max<int>(m_rows, m_columns);
+		resize(m,m);
+	}
+
+	for ( int i = 0 ; i < m_rows ; i++ ) {
+		for ( int j = i+1 ; j < m_columns ; j++ ) {
+			T tmp = m_matrix[i][j];
+			m_matrix[i][j] = m_matrix[j][i];
+			m_matrix[j][i] = tmp;
+		}
+	}
+
+	if ( new_columns != new_rows ) {
+		// trim off excess.
+		resize(new_rows, new_columns);
+	}
+
+	return *this;
+}
+
+/*export*/ template <class T>
+Matrix<T> 
+Matrix<T>::product(Matrix<T> &other) {
+	assert( m_matrix != NULL );
+	assert( other.m_matrix != NULL );
+	assert ( m_columns == other.m_rows );
+
+	Matrix<T> out(m_rows, other.m_columns);
+
+	for ( int i = 0 ; i < out.m_rows ; i++ ) {
+		for ( int j = 0 ; j < out.m_columns ; j++ ) {
+			for ( int x = 0 ; x < m_columns ; x++ ) {
+				out(i,j) += m_matrix[i][x] * other.m_matrix[x][j];
+			}
+		}
+	}
+
+	return out;
+}
+
+/*export*/ template <class T>
+T&
+Matrix<T>::operator ()(int x, int y) {
+	assert ( x >= 0 );
+	assert ( y >= 0 );
+	assert ( x < m_rows );
+	assert ( y < m_columns );
+	assert ( m_matrix != NULL );
+	return m_matrix[x][y];
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/include/utilities/munkres/matrix.h	Thu May 10 15:42:21 2012 -0700
@@ -0,0 +1,56 @@
+/*
+ *   Copyright (c) 2007 John Weaver
+ *
+ *   This program is free software; you can redistribute it and/or modify
+ *   it under the terms of the GNU General Public License as published by
+ *   the Free Software Foundation; either version 2 of the License, or
+ *   (at your option) any later version.
+ *
+ *   This program is distributed in the hope that it will be useful,
+ *   but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *   GNU General Public License for more details.
+ *
+ *   You should have received a copy of the GNU General Public License
+ *   along with this program; if not, write to the Free Software
+ *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
+ */
+
+#if !defined(_MATRIX_H_)
+#define _MATRIX_H_
+
+template <class T>
+class Matrix {
+public:
+	Matrix();
+	Matrix(int rows, int columns);
+	Matrix(const Matrix<T> &other);
+	Matrix<T> & operator= (const Matrix<T> &other);
+	~Matrix();
+	// all operations except product modify the matrix in-place.
+	void resize(int rows, int columns);
+	void identity(void);
+	void clear(void);
+	T& operator () (int x, int y);
+	T trace(void);
+	Matrix<T>& transpose(void);
+	Matrix<T> product(Matrix<T> &other);
+	int minsize(void) {
+		return ((m_rows < m_columns) ? m_rows : m_columns);
+	}
+	int columns(void) {
+		return m_columns;
+	}
+	int rows(void) {
+		return m_rows;
+	}
+private:
+	T **m_matrix;
+	int m_rows;
+	int m_columns;
+};
+
+#include "matrix.hpp"
+
+#endif /* !defined(_MATRIX_H_) */
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/include/utilities/munkres/munkres.cpp	Thu May 10 15:42:21 2012 -0700
@@ -0,0 +1,359 @@
+/*
+ *   Copyright (c) 2007 John Weaver
+ *
+ *   This program is free software; you can redistribute it and/or modify
+ *   it under the terms of the GNU General Public License as published by
+ *   the Free Software Foundation; either version 2 of the License, or
+ *   (at your option) any later version.
+ *
+ *   This program is distributed in the hope that it will be useful,
+ *   but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *   GNU General Public License for more details.
+ *
+ *   You should have received a copy of the GNU General Public License
+ *   along with this program; if not, write to the Free Software
+ *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
+ */
+
+#include "munkres.h"
+
+#include <iostream>
+#include <cmath>
+
+bool 
+Munkres::find_uncovered_in_matrix(double item, int &row, int &col) {
+  for ( row = 0 ; row < matrix.rows() ; row++ )
+    if ( !row_mask[row] )
+      for ( col = 0 ; col < matrix.columns() ; col++ )
+        if ( !col_mask[col] )
+          if ( matrix(row,col) == item )
+            return true;
+
+  return false;
+}
+
+bool 
+Munkres::pair_in_list(const std::pair<int,int> &needle, const std::list<std::pair<int,int> > &haystack) {
+  for ( std::list<std::pair<int,int> >::const_iterator i = haystack.begin() ; i != haystack.end() ; i++ ) {
+    if ( needle == *i )
+      return true;
+  }
+  
+  return false;
+}
+
+int 
+Munkres::step1(void) {
+  for ( int row = 0 ; row < matrix.rows() ; row++ )
+    for ( int col = 0 ; col < matrix.columns() ; col++ )
+      if ( matrix(row,col) == 0 ) {
+        bool isstarred = false;
+        for ( int nrow = 0 ; nrow < matrix.rows() ; nrow++ )
+          if ( mask_matrix(nrow,col) == STAR ) {
+            isstarred = true;
+            break;
+          }
+
+        if ( !isstarred ) {
+          for ( int ncol = 0 ; ncol < matrix.columns() ; ncol++ )
+            if ( mask_matrix(row,ncol) == STAR ) {
+              isstarred = true;
+              break;
+            }
+        }
+              
+        if ( !isstarred ) {
+          mask_matrix(row,col) = STAR;
+        }
+      }
+
+  return 2;
+}
+
+int 
+Munkres::step2(void) {
+  int rows = matrix.rows();
+  int cols = matrix.columns();
+  int covercount = 0;
+  for ( int row = 0 ; row < rows ; row++ )
+    for ( int col = 0 ; col < cols ; col++ )
+      if ( mask_matrix(row,col) == STAR ) {
+        col_mask[col] = true;
+        covercount++;
+      }
+      
+  int k = matrix.minsize();
+
+  if ( covercount >= k ) {
+#ifdef DEBUG
+    std::cout << "Final cover count: " << covercount << std::endl;
+#endif
+    return 0;
+  }
+
+#ifdef DEBUG
+  std::cout << "Munkres matrix has " << covercount << " of " << k << " Columns covered:" << std::endl;
+  for ( int row = 0 ; row < rows ; row++ ) {
+    for ( int col = 0 ; col < cols ; col++ ) {
+      std::cout.width(8);
+      std::cout << matrix(row,col) << ",";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << std::endl;
+#endif
+
+
+  return 3;
+}
+
+int 
+Munkres::step3(void) {
+  /*
+  Main Zero Search
+
+   1. Find an uncovered Z in the distance matrix and prime it. If no such zero exists, go to Step 5
+   2. If No Z* exists in the row of the Z', go to Step 4.
+   3. If a Z* exists, cover this row and uncover the column of the Z*. Return to Step 3.1 to find a new Z
+  */
+  if ( find_uncovered_in_matrix(0, saverow, savecol) ) {
+    mask_matrix(saverow,savecol) = PRIME; // prime it.
+  } else {
+    return 5;
+  }
+
+  for ( int ncol = 0 ; ncol < matrix.columns() ; ncol++ )
+    if ( mask_matrix(saverow,ncol) == STAR ) {
+      row_mask[saverow] = true; //cover this row and
+      col_mask[ncol] = false; // uncover the column containing the starred zero
+      return 3; // repeat
+    }
+
+  return 4; // no starred zero in the row containing this primed zero
+}
+
+int 
+Munkres::step4(void) {
+  int rows = matrix.rows();
+  int cols = matrix.columns();
+
+  std::list<std::pair<int,int> > seq;
+  // use saverow, savecol from step 3.
+  std::pair<int,int> z0(saverow, savecol);
+  std::pair<int,int> z1(-1,-1);
+  std::pair<int,int> z2n(-1,-1);
+  seq.insert(seq.end(), z0);
+  int row, col = savecol;
+  /*
+  Increment Set of Starred Zeros
+
+   1. Construct the ``alternating sequence'' of primed and starred zeros:
+
+         Z0 : Unpaired Z' from Step 4.2 
+         Z1 : The Z* in the column of Z0
+         Z[2N] : The Z' in the row of Z[2N-1], if such a zero exists 
+         Z[2N+1] : The Z* in the column of Z[2N]
+
+      The sequence eventually terminates with an unpaired Z' = Z[2N] for some N.
+  */
+  bool madepair;
+  do {
+    madepair = false;
+    for ( row = 0 ; row < rows ; row++ )
+      if ( mask_matrix(row,col) == STAR ) {
+        z1.first = row;
+        z1.second = col;
+        if ( pair_in_list(z1, seq) )
+          continue;
+        
+        madepair = true;
+        seq.insert(seq.end(), z1);
+        break;
+      }
+
+    if ( !madepair )
+      break;
+
+    madepair = false;
+
+    for ( col = 0 ; col < cols ; col++ )
+      if ( mask_matrix(row,col) == PRIME ) {
+        z2n.first = row;
+        z2n.second = col;
+        if ( pair_in_list(z2n, seq) )
+          continue;
+        madepair = true;
+        seq.insert(seq.end(), z2n);
+        break;
+      }
+  } while ( madepair );
+
+  for ( std::list<std::pair<int,int> >::iterator i = seq.begin() ;
+      i != seq.end() ;
+      i++ ) {
+    // 2. Unstar each starred zero of the sequence.
+    if ( mask_matrix(i->first,i->second) == STAR )
+      mask_matrix(i->first,i->second) = NORMAL;
+
+    // 3. Star each primed zero of the sequence,
+    // thus increasing the number of starred zeros by one.
+    if ( mask_matrix(i->first,i->second) == PRIME )
+      mask_matrix(i->first,i->second) = STAR;
+  }
+
+  // 4. Erase all primes, uncover all columns and rows, 
+  for ( int row = 0 ; row < mask_matrix.rows() ; row++ )
+    for ( int col = 0 ; col < mask_matrix.columns() ; col++ )
+      if ( mask_matrix(row,col) == PRIME )
+        mask_matrix(row,col) = NORMAL;
+  
+  for ( int i = 0 ; i < rows ; i++ ) {
+    row_mask[i] = false;
+  }
+
+  for ( int i = 0 ; i < cols ; i++ ) {
+    col_mask[i] = false;
+  }
+
+  // and return to Step 2. 
+  return 2;
+}
+
+int 
+Munkres::step5(void) {
+  int rows = matrix.rows();
+  int cols = matrix.columns();
+  /*
+  New Zero Manufactures
+
+   1. Let h be the smallest uncovered entry in the (modified) distance matrix.
+   2. Add h to all covered rows.
+   3. Subtract h from all uncovered columns
+   4. Return to Step 3, without altering stars, primes, or covers. 
+  */
+  double h = 0;
+  for ( int row = 0 ; row < rows ; row++ ) {
+    if ( !row_mask[row] ) {
+      for ( int col = 0 ; col < cols ; col++ ) {
+        if ( !col_mask[col] ) {
+          if ( (h > matrix(row,col) && matrix(row,col) != 0) || h == 0 ) {
+            h = matrix(row,col);
+          }
+        }
+      }
+    }
+  }
+
+  for ( int row = 0 ; row < rows ; row++ )
+    if ( row_mask[row] )
+      for ( int col = 0 ; col < cols ; col++ )
+        matrix(row,col) += h;
+  
+  for ( int col = 0 ; col < cols ; col++ )
+    if ( !col_mask[col] )
+      for ( int row = 0 ; row < rows ; row++ )
+        matrix(row,col) -= h;
+
+  return 3;
+}
+
+void 
+Munkres::solve(Matrix<double> &m) {
+  // Linear assignment problem solution
+  // [modifies matrix in-place.]
+  // matrix(row,col): row major format assumed.
+
+  // Assignments are remaining 0 values
+  // (extra 0 values are replaced with -1)
+#ifdef DEBUG
+  std::cout << "Munkres input matrix:" << std::endl;
+  for ( int row = 0 ; row < m.rows() ; row++ ) {
+    for ( int col = 0 ; col < m.columns() ; col++ ) {
+      std::cout.width(8);
+      std::cout << m(row,col) << ",";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << std::endl;
+#endif
+
+  double highValue = 0;
+  for ( int row = 0 ; row < m.rows() ; row++ ) {
+    for ( int col = 0 ; col < m.columns() ; col++ ) {
+      if ( m(row,col) != INFINITY && m(row,col) > highValue )
+        highValue = m(row,col);
+    }
+  }
+  highValue++;
+  
+  for ( int row = 0 ; row < m.rows() ; row++ )
+    for ( int col = 0 ; col < m.columns() ; col++ )
+      if ( m(row,col) == INFINITY )
+        m(row,col) = highValue;
+
+  bool notdone = true;
+  int step = 1;
+
+  this->matrix = m;
+  // STAR == 1 == starred, PRIME == 2 == primed
+  mask_matrix.resize(matrix.rows(), matrix.columns());
+
+  row_mask = new bool[matrix.rows()];
+  col_mask = new bool[matrix.columns()];
+  for ( int i = 0 ; i < matrix.rows() ; i++ ) {
+    row_mask[i] = false;
+  }
+
+  for ( int i = 0 ; i < matrix.columns() ; i++ ) {
+    col_mask[i] = false;
+  }
+
+  while ( notdone ) {
+    switch ( step ) {
+      case 0:
+        notdone = false;
+        break;
+      case 1:
+        step = step1();
+        break;
+      case 2:
+        step = step2();
+        break;
+      case 3:
+        step = step3();
+        break;
+      case 4:
+        step = step4();
+        break;
+      case 5:
+        step = step5();
+        break;
+    }
+  }
+
+  // Store results
+  for ( int row = 0 ; row < matrix.rows() ; row++ )
+    for ( int col = 0 ; col < matrix.columns() ; col++ )
+      if ( mask_matrix(row,col) == STAR )
+        matrix(row,col) = 0;
+      else
+        matrix(row,col) = -1;
+
+#ifdef DEBUG
+  std::cout << "Munkres output matrix:" << std::endl;
+  for ( int row = 0 ; row < matrix.rows() ; row++ ) {
+    for ( int col = 0 ; col < matrix.columns() ; col++ ) {
+      std::cout.width(1);
+      std::cout << matrix(row,col) << ",";
+    }
+    std::cout << std::endl;
+  }
+  std::cout << std::endl;
+#endif
+
+  m = matrix;
+
+  delete [] row_mask;
+  delete [] col_mask;
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/include/utilities/munkres/munkres.h	Thu May 10 15:42:21 2012 -0700
@@ -0,0 +1,49 @@
+/*
+ *   Copyright (c) 2007 John Weaver
+ *
+ *   This program is free software; you can redistribute it and/or modify
+ *   it under the terms of the GNU General Public License as published by
+ *   the Free Software Foundation; either version 2 of the License, or
+ *   (at your option) any later version.
+ *
+ *   This program is distributed in the hope that it will be useful,
+ *   but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *   GNU General Public License for more details.
+ *
+ *   You should have received a copy of the GNU General Public License
+ *   along with this program; if not, write to the Free Software
+ *   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
+ */
+
+#if !defined(_MUNKRES_H_)
+#define _MUNKRES_H_
+
+#include "matrix.h"
+
+#include <list>
+#include <utility>
+
+class Munkres {
+public:
+	void solve(Matrix<double> &m);
+private:
+  static const int NORMAL = 0;
+  static const int STAR = 1;
+  static const int PRIME = 2; 
+	inline bool find_uncovered_in_matrix(double,int&,int&);
+	inline bool pair_in_list(const std::pair<int,int> &, const std::list<std::pair<int,int> > &);
+	int step1(void);
+	int step2(void);
+	int step3(void);
+	int step4(void);
+	int step5(void);
+	int step6(void);
+	Matrix<int> mask_matrix;
+	Matrix<double> matrix;
+	bool *row_mask;
+	bool *col_mask;
+	int saverow, savecol;
+};
+
+#endif /* !defined(_MUNKRES_H_) */