Tail-Recursive Tree Search in OCaml Example

Implement a function find : ('a -> bool) -> 'a tree -> 'a option such that find p t traverses t in depth-first order and returns Option.Some x, where x is the first element of t encountered in that order that satisfies p. If no such value is found, then find p t returns Option.None. Use the following definition for the type 'a tree, and make sure find is tail-recursive.

ocaml
(** A multi-way tree whose nodes carry values of type ['a]. A tree is either
    [Empty] or a [Node] holding one value and a list of subtrees. *)
type 'a tree =
  | Empty  (** A tree containing no elements. *)
  | Node of {
    contents : 'a;  (** The value stored at this node. *)
    children : 'a tree list  (** The subtrees rooted at this node. *)
  }

Solution

A naive implementation of the find function would be as follows:

ocaml
(** [find p t] visits [t] depth-first (a node's own value before its
    children) and returns [Some x] for the first value [x] with [p x = true],
    or [None] when no value in [t] satisfies [p].

    This version is NOT tail-recursive: the recursive calls happen inside the
    callback given to the list search, so each level of the tree consumes a
    stack frame. *)
let rec find : ('a -> bool) -> 'a tree -> 'a option =
  fun p t ->
    match t with
    | Empty -> Option.none
    | Node { contents; children } ->
        if p contents then
          Option.some contents
        else
          (* [List.find_map] (OCaml >= 4.10) returns the first [Some _]
             produced by the callback. [List.find_opt] cannot be used here:
             its callback must return [bool], whereas [find p] returns an
             ['a option]. *)
          List.find_map (find p) children

Recursive calls to find occur inside the callback passed to the standard-library list search, so they are not in tail position and the function is not tail-recursive. To resolve this, we use continuation-passing style and implement our own search over the list of children instead of relying on the List module.

We introduce the mutually tail-recursive helper functions find_in_tree_tl and find_in_tree_list_tl that respectively search for the value in a tree and in a list of trees.

ocaml
(** [find_in_tree_tl p t succeed fail] searches [t] depth-first, testing a
    node's own value before descending into its children.

    - If a value [x] with [p x = true] is found, the result is [succeed x].
    - If [t] contains no such value, the result is [fail ()].

    Every call made by this function is in tail position; the work that
    remains to be done is carried by the [fail] continuation rather than by
    the call stack. *)
let rec find_in_tree_tl :
    ('a -> bool) -> 'a tree ->
    (* Success continuation *) ('a -> 'b) ->
    (* Failure continuation *) (unit -> 'b) ->
    'b =
  fun p t succeed fail ->
    match t with
    | Empty -> fail ()
    | Node { contents; children } ->
        if p contents then
          succeed contents
        else
          (* Continuations must be passed in the declared order (success
             first, failure second); swapping them would route a match found
             in a subtree to the failure path. *)
          find_in_tree_list_tl p children succeed fail

(** [find_in_tree_list_tl p ts succeed fail] searches the trees of [ts] from
    left to right using [find_in_tree_tl]. The result is [succeed x] for the
    first matching value [x], or [fail ()] when no tree in [ts] contains a
    match. *)
and find_in_tree_list_tl :
    ('a -> bool) -> 'a tree list ->
    (* Success continuation *) ('a -> 'b) ->
    (* Failure continuation *) (unit -> 'b) ->
    'b =
  fun p ts succeed fail ->
    match ts with
    | [] -> fail ()
    | t :: rest ->
        find_in_tree_tl p t
          succeed
          (* If [t] holds no match, resume with the remaining siblings; the
             closure captures [rest] and the caller's own [fail]. *)
          (fun () -> find_in_tree_list_tl p rest succeed fail)

(** [find p t] returns [Some x] when the search performed by
    [find_in_tree_tl] succeeds with a value [x] satisfying [p], and [None]
    when it fails.

    The function does not call itself, so it is a plain [let], not
    [let rec]; the traversal is delegated to [find_in_tree_tl]. *)
let find : ('a -> bool) -> 'a tree -> 'a option =
  fun p t ->
    find_in_tree_tl p t
      (* Success: wrap the found value in [Some]. *)
      Option.some
      (* Failure: nothing matched. *)
      (fun () -> Option.none)

The use of continuation-passing style replaces stack memory allocations with heap memory allocations (the continuation closures); both scale with the height of the input tree.