Haskell ways to the 3n+1 problem

Here is a simple programming problem from SPOJ: http://www.spoj.com/problems/PROBTRES/.

Basically, you are asked to output the largest Collatz cycle length among all numbers between $i$ and $j$. (The cycle length of a number $n$ is the number of terms in its Collatz sequence from $n$ down to 1, counting both $n$ and 1.)
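To make the definition concrete, here is a brute-force sketch in Scala (the language used in the answer below). It assumes the classic 3n+1 convention that both n and the final 1 are counted, so the cycle length of 1 is 1 and that of 22 is 16. It also swaps i and j if they arrive out of order, which is a defensive assumption rather than something stated above. The names NaiveCycle, cycleLength and maxCycle are hypothetical.

```scala
import scala.annotation.tailrec

object NaiveCycle {
  // Unmemoized cycle length: counts n itself and the final 1.
  // Long is used because some trajectories starting below 10^6 exceed Int.MaxValue.
  @tailrec
  def cycleLength(n: Long, count: Int = 1): Int =
    if (n == 1) count
    else if ((n & 1L) == 1L) cycleLength(3 * n + 1, count + 1)
    else cycleLength(n / 2, count + 1)

  // Brute-force answer for one query: recompute every number in the range.
  def maxCycle(i: Int, j: Int): Int =
    (math.min(i, j) to math.max(i, j)).map(k => cycleLength(k.toLong)).max
}
```

This is the baseline that memoization is meant to beat: every query recomputes every sequence from scratch.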

I have been looking for a Haskell solution whose performance is comparable to that of Java or C++ (so that it fits within the allowed time limit). A simple Java solution that memoizes the cycle length of every number already computed works, but I haven't managed to apply the same idea in Haskell.
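The Java approach described here is usually a mutable array indexed by n. A rough Scala sketch of that idea follows. The cache size of 10^6 + 1, the choice to cache only starting values (not every intermediate value), and the name ArrayMemo are assumptions for illustration. Intermediate values are kept as Long because some trajectories starting below 10^6 exceed Int.MaxValue.

```scala
import scala.annotation.tailrec

object ArrayMemo {
  // Hypothetical cache size: indices 0..10^6, so any n <= 10^6 can be cached.
  val limit: Int = 1000001
  // cache(n) == 0 means "not computed yet"; cache(1) is seeded with 1.
  private val cache = new Array[Int](limit)
  cache(1) = 1

  def cycleLength(start: Int): Int = {
    require(start >= 1 && start < limit, "start must lie in [1, limit)")

    // Follow the sequence until it reaches a cached value, counting the steps taken.
    @tailrec
    def walk(n: Long, steps: Int): Int =
      if (n < limit && cache(n.toInt) != 0) cache(n.toInt) + steps
      else if ((n & 1L) == 1L) walk(3 * n + 1, steps + 1)
      else walk(n / 2, steps + 1)

    val len = walk(start.toLong, 0)
    cache(start) = len // remember only the starting value
    len
  }
}
```

Calling cycleLength in increasing order of n means most walks stop quickly at a smaller, already cached value.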

I have tried Data.Function.Memoize, as well as a home-brewed log-time memoization technique based on this post: https://stackoverflow.com/questions/3208258/memoization-in-haskell. Unfortunately, memoization actually makes computing cycle(n) even slower. I suspect the slowdown comes from the overhead of these memoization structures in Haskell. (I ran the compiled binary, not the interpreter.)

I also suspect that simply iterating over every number from $i$ to $j$ can be costly ($i, j \le 10^6$). So I even tried precomputing everything for range queries, using the idea from http://blog.openendings.net/2013/10/range-trees-and-profiling-in-haskell.html. However, this still gives a "Time Limit Exceeded" error.
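For the precomputation route, one common structure for range-maximum queries is a bottom-up segment tree over the array of cycle lengths. This is a swapped-in sketch, not the range tree from the linked blog post; the class name MaxSegTree and the 0-based, inclusive query convention are assumptions.

```scala
// Bottom-up segment tree answering "maximum of values(lo..hi)" queries.
final class MaxSegTree(values: Array[Int]) {
  private val n = values.length
  // Leaves live at tree(n until 2n); internal node i holds max(tree(2i), tree(2i + 1)).
  private val tree = new Array[Int](2 * n)
  Array.copy(values, 0, tree, n, n)
  for (i <- n - 1 until 0 by -1) tree(i) = math.max(tree(2 * i), tree(2 * i + 1))

  // Maximum of values(lo..hi), inclusive, 0-based; requires 0 <= lo <= hi < n.
  def query(lo: Int, hi: Int): Int = {
    var l = lo + n
    var r = hi + n + 1 // half-open interval [l, r) over leaf positions
    var best = Int.MinValue
    while (l < r) {
      if ((l & 1) == 1) { best = math.max(best, tree(l)); l += 1 }
      if ((r & 1) == 1) { r -= 1; best = math.max(best, tree(r)) }
      l >>= 1
      r >>= 1
    }
    best
  }
}
```

If index k - 1 holds the cycle length of k for k from 1 to 10^6, a query (i, j) with i <= j becomes query(i - 1, j - 1), answered in O(log n) after an O(n) build.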

Can anyone suggest a neat, competitive Haskell program for this?

Best Answer

I'll answer in Scala, because my Haskell isn't as fresh, and so people will believe this is a general functional programming algorithm question. I'll stick to data structures and concepts that are readily transferable.

We can start with a function that generates a Collatz sequence, which is relatively straightforward, except that the result must be passed along as an argument to make it tail recursive:

def collatz(n: Int, result: List[Int] = List()): List[Int] = {
  if (n == 1) {
    1 :: result
  } else if ((n & 1) == 1) {
    collatz(3 * n + 1, n :: result)
  } else {
    collatz(n / 2, n :: result)
  }
}
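The collatz above uses Int, but for starting values below 10^6 the sequence can climb past Int.MaxValue (starting from 113383, for example, it peaks above 2.4 billion), at which point 3 * n + 1 silently overflows. A variant sketch that switches to Long and adds @tailrec so the compiler checks the tail call might look like this; CollatzSeq is a hypothetical name.

```scala
import scala.annotation.tailrec

object CollatzSeq {
  // Same logic as the collatz above, but with Long values and a
  // compiler-checked tail call. Returns the sequence in reverse order.
  @tailrec
  def collatz(n: Long, result: List[Long] = Nil): List[Long] =
    if (n == 1) 1L :: result
    else if ((n & 1L) == 1L) collatz(3 * n + 1, n :: result)
    else collatz(n / 2, n :: result)
}
```

CollatzSeq.collatz(22) returns List(1, 2, 4, 8, 16, 5, 10, 20, 40, 13, 26, 52, 17, 34, 11, 22), the sequence from 22 down to 1 in reverse.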

This actually puts the sequence in reverse order, but that's perfect for our next step, which is to store the lengths in a map:

def calculateLengths(sequence: List[Int], length: Int,
  lengths: Map[Int, Int]): Map[Int, Int] = sequence match {
    case Nil     => lengths
    case x :: xs => calculateLengths(xs, length + 1, lengths + ((x, length)))
}

You would call this with the result of the first step, an initial length of 1, and an empty map, like calculateLengths(collatz(22), 1, Map.empty). This is how you memoize the result. Now we need to modify collatz to be able to use this:

def collatz(n: Int, lengths: Map[Int, Int], result: List[Int] = List()): (List[Int], Int) = {
  if (lengths contains n) {
    (result, lengths(n))
  } else if ((n & 1) == 1) {
    collatz(3 * n + 1, lengths, n :: result)
  } else {
    collatz(n / 2, lengths, n :: result)
  }
}

We eliminate the n == 1 check because we can initialize the map with 1 -> 1 instead, but that means calculateLengths must now add 1 to each length before putting it in the map. The new collatz also returns the memoized length of the number where it stopped recursing, which we use as the starting length for calculateLengths, like:

val initialMap = Map(1 -> 1)
val (result, length) = collatz(22, initialMap)
val newMap = calculateLengths(result, length, initialMap)
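The text says calculateLengths must add 1 before storing each length, but the revised version isn't shown. Here is one way it could look, bundled with the memoized collatz so it compiles on its own. MemoSteps is a hypothetical wrapper, and Int is kept to match the answer, which is fine for small inputs like these.

```scala
import scala.annotation.tailrec

object MemoSteps {
  // Memoized collatz from the answer: stops at the first number already in the map.
  @tailrec
  def collatz(n: Int, lengths: Map[Int, Int], result: List[Int] = List()): (List[Int], Int) =
    if (lengths contains n) (result, lengths(n))
    else if ((n & 1) == 1) collatz(3 * n + 1, lengths, n :: result)
    else collatz(n / 2, lengths, n :: result)

  // Revised calculateLengths: each new number is one step further from 1 than
  // the previous one, so add 1 before storing. Seed `length` with the memoized
  // length returned by collatz.
  @tailrec
  def calculateLengths(sequence: List[Int], length: Int, lengths: Map[Int, Int]): Map[Int, Int] =
    sequence match {
      case Nil     => lengths
      case x :: xs => calculateLengths(xs, length + 1, lengths + (x -> (length + 1)))
    }
}
```

Seeded with Map(1 -> 1), processing 22 stores 22 -> 16. Processing 7 afterwards stops as soon as it reaches 22, so only 7 itself is added, with length 17.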

Now that we have relatively efficient implementations of the pieces, we need a way to feed the result of each calculation into the input of the next one. This is called a fold, and looks like:

def iteration(lengths: Map[Int, Int], n: Int): Map[Int, Int] = {
  val (result, length) = collatz(n, lengths)
  calculateLengths(result, length, lengths)
}

val lengths = (1 to 10).foldLeft(Map(1 -> 1))(iteration)

To find the actual answer, we just need the maximum length among the numbers in the given range, giving a final result of:

def answer(start: Int, finish: Int): Int = {
  val lengths = (start to finish).foldLeft(Map(1 -> 1))(iteration)
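  // Every number in [start, finish] was folded into the map, so look each one up
  // directly instead of filtering the whole map (filterKeys is deprecated in Scala 2.13).
  (start to finish).map(lengths).max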
}
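Putting the pieces together, here is an end-to-end sketch that should compile on its own. Beyond the answer's code, it assumes Long keys (because 3 * n + 1 overflows Int for some starts below 10^6), normalizes the range in case i > j, and looks up each number in the range directly. CollatzFold is a hypothetical name.

```scala
import scala.annotation.tailrec

object CollatzFold {
  // Memoized walk: collect unseen numbers until one is already in the map.
  @tailrec
  def collatz(n: Long, lengths: Map[Long, Int], result: List[Long] = Nil): (List[Long], Int) =
    if (lengths.contains(n)) (result, lengths(n))
    else if ((n & 1L) == 1L) collatz(3 * n + 1, lengths, n :: result)
    else collatz(n / 2, lengths, n :: result)

  // Store lengths for the collected numbers, nearest-to-1 first, adding 1 each step.
  @tailrec
  def calculateLengths(sequence: List[Long], length: Int, lengths: Map[Long, Int]): Map[Long, Int] =
    sequence match {
      case Nil     => lengths
      case x :: xs => calculateLengths(xs, length + 1, lengths + (x -> (length + 1)))
    }

  // One fold step: extend the memo table with everything needed for n.
  def iteration(lengths: Map[Long, Int], n: Int): Map[Long, Int] = {
    val (result, length) = collatz(n.toLong, lengths)
    calculateLengths(result, length, lengths)
  }

  def answer(i: Int, j: Int): Int = {
    val (lo, hi) = (math.min(i, j), math.max(i, j))
    val lengths = (lo to hi).foldLeft(Map(1L -> 1))(iteration)
    // Every number in [lo, hi] was folded in, so each lookup succeeds.
    (lo to hi).map(k => lengths(k.toLong)).max
  }
}
```

Under the counting convention implied by the Map(1 -> 1) seed, answer(1, 10) is 20 and answer(900, 1000) is 174.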

In my REPL, for ranges of about 1000 numbers like the example input, answer returns almost instantly.