------------------------------------------------------------------------
-- Based on "Deriving Target Code as a Representation of Continuation
-- Semantics" by Mitchell Wand
------------------------------------------------------------------------

open import Relation.Binary.PropositionalEquality

-- The whole development is parameterised by two assumptions:
--
-- • ext: function extensionality for functions between types in Set.
--   It is passed on to MultiComposition and is also used directly in
--   the correctness proofs of Step₂.
--
-- • Id: an arbitrary type of identifiers. Identifiers are never
--   inspected in this module; they are only used to look values up in
--   stores.
module Wand
         -- Assume extensionality; it is used in proofs below.
         (ext : forall {a b : Set} {f g : a -> b} ->
                  (forall x -> f x  g x) -> f  g)
         -- The type used for identifiers.
         (Id : Set)
         where

import MultiComposition as MC
open MC ext
open TypeList using (_⟶_)

open import Data.Nat
open import Derivation
open ≡-Reasoning
open import Function hiding (id)
open import Data.Vec

------------------------------------------------------------------------
-- Language

module Language where

  -- Wand's simple expression language with variables:
  --
  -- • id I        refers to the variable I;
  -- • [ e₁ + e₂ ] is the sum of two subexpressions.

  data Exp : Set where
    id    : Id -> Exp
    [_+_] : Exp -> Exp -> Exp

  -- Domains for values, stores, commands and continuations.

  -- Values. Addition on V below is Data.Nat's _+_.
  V : Set
  V = 

  -- Stores map identifiers to values.
  S : Set
  S = Id -> V

  -- Commands consume a store and produce the final answer, a value.
  C : Set
  C = S -> V

  -- Expression continuations receive the value of an expression and
  -- then behave as a command.
  K : Set
  K = V -> C

  -- The semantics, in continuation-passing style.
  --
  -- A variable looks itself up in the current store and passes the
  -- value, together with the unchanged store, to the continuation.
  -- A sum evaluates e₁ first, then e₂, and passes v₁ + v₂ to the
  -- continuation. No construct modifies the store.

  E : Exp -> K -> C
  E (id I)      = \κ σ -> κ (σ I) σ
  E [ e₁ + e₂ ] = \κ -> E e₁ (\v₁ -> E e₂ (\v₂ -> κ (v₁ + v₂)))

  -- Programs: evaluate the expression with an initial continuation
  -- that ignores the store and returns the value.

  P : Exp -> C
  P e = E e (\v σ -> v)

------------------------------------------------------------------------
-- Deriving a compiler

module Step₁ where

  -- Step 1: by equational reasoning, rewrite E into an equal function
  -- E' in which recursive calls are combined by multi-composition
  -- rather than nested inside λ-abstractions; then do the same for P.

  open Language

  mutual

    -- Let us derive a function E', equal to E, which does not use E
    -- recursively under binders, and in which all recursive
    -- occurrences of E' <something> are applied using _⟪_⟫_.

    -- First the remaining continuations are named. (Even though this
    -- is the last step of the derivation, these functions have to be
    -- defined before E'P, because otherwise they don't evaluate in
    -- the proof.)

    -- fetch I κ σ: look I up in the store σ and pass the value, with
    -- the store unchanged, to κ. This is exactly the body of E (id I).

    fetch : Id -> K -> C
    fetch I = \κ σ -> κ (σ I) σ

    -- add κ v₁ v₂: pass v₁ + v₂ to κ. The operands arrive in
    -- evaluation order, so v₁ is the value of the left operand.

    add : K -> V -> V -> C
    add = \κ v₁ v₂ -> κ (v₁ + v₂)

    -- The derived semantics, read off from the derivation E'P e.
    -- proof (E'P e) is the equation E e ≡ E' e.

    E' : Exp -> K -> C
    E' e = witness (E'P e)

    -- The derivation, by recursion on e. Every step is refl (it holds
    -- by unfolding definitions) except the one that replaces E by E'
    -- in the sum case, which uses the results for e₁ and e₂.

    E'P : (e : Exp) -> EqualTo (E e)
    E'P (id I) =  begin
      (\κ σ -> κ (σ I) σ)
        ≡⟨ refl 
      fetch I
        
    E'P [ e₁ + e₂ ] =  begin
      -- E [ e₁ + e₂ ], unfolded and η-expanded in the store σ.
      (\κ σ -> E e₁ (\v₁ -> E e₂ (\v₂ -> κ (v₁ + v₂))) σ)
        ≡⟨ refl 
      (\κ -> E e₁ (\v₁ -> E e₂ (\v₂ -> κ (v₁ + v₂))))
        ≡⟨ refl 
      -- Express the call to E e₁ as a multi-composition.
      E e₁  K  ε  (\κ v₁ -> E e₂ (\v₂ -> κ (v₁ + v₂)))
        ≡⟨ refl 
      -- Likewise for E e₂; what remains only performs the addition.
      E e₁  K  ε  (E e₂  K  V  ε  (\κ v₁ v₂ -> κ (v₁ + v₂)))
        -- Replace E by E' for both subexpressions, using congruence of
        -- multi-composition.
        ≡⟨ proof (E'P e₁)  K  ε ⟫-cong
             (proof (E'P e₂)  K  V  ε ⟫-cong refl) 
      E' e₁  K  ε  (E' e₂  K  V  ε  (\κ v₁ v₂ -> κ (v₁ + v₂)))
        ≡⟨ refl 
      -- Name the remaining continuation transformer add.
      E' e₁  K  ε  (E' e₂  K  V  ε  add)
        

  mutual

    -- Let's give P the same treatment.

    -- The initial continuation used by P: return the value and ignore
    -- the store.

    halt : K
    halt = \v σ -> v

    -- The derived program semantics; proof (P'P e) is the equation
    -- P e ≡ P' e.

    P' : Exp -> C
    P' e = witness (P'P e)

    -- Unfold P, replace E by E' using E'P e, and name the initial
    -- continuation halt.

    P'P : (e : Exp) -> EqualTo (P e)
    P'P e =  begin
      E e (\v σ -> v)
        ≡⟨ refl 
      E e  ε  (\v σ -> v)
        ≡⟨ proof (E'P e)  ε ⟫-cong refl 
      E' e  ε  halt
        

module Step₂ where

  -- Step 2: represent the terms produced by Step₁ as first-order
  -- syntax (Exp'), with a denotation function ⟦_⟧ and a compiler comp
  -- from Exp to Exp'.

  open Language

  -- Wrap up what we have achieved so far in a data type.
  --
  -- Exp' t is a term whose denotation has type t (it is indexed by a
  -- Set, hence lives in Set1):
  --
  -- • B ts e₁ e₂ denotes the multi-composition, at the argument list
  --   ts, of the denotations of e₁ and e₂;
  -- • add, fetch I and halt stand for the functions of the same name
  --   in Step₁.

  data Exp' : Set -> Set1 where
    B     : forall {a b} ts ->
            Exp' (a -> b) -> Exp' (ts  a) -> Exp' (ts  b)
    add   : Exp' (K -> V -> V -> C)
    fetch : Id -> Exp' (K -> C)
    halt  : Exp' K

  -- The semantics of values of this data type.

  ⟦_⟧ : forall {t} -> Exp' t -> t
   B ts e₁ e₂  =  e₁   ts   e₂ 
   add         = Step₁.add
   fetch I     = Step₁.fetch I
   halt        = Step₁.halt

  -- Translation (compiler) from Exp to Exp'. Its clauses mirror the
  -- final right-hand sides of Step₁.E'P, with each E' e replaced by
  -- comp e.

  comp : Exp -> Exp' (K -> C)
  comp (id I)      = fetch I
  comp [ e₁ + e₂ ] =
    B (K  ε) (comp e₁) (B (K  V  ε) (comp e₂) add)

  -- Correctness proof: Step₁.E' e equals the denotation of comp e.
  -- The variable case holds by refl; the sum case uses extensionality
  -- (ext) together with the results for e₁ and e₂.

  correct : forall e -> Step₁.E' e   comp e 
  correct (id I)      = refl
  correct [ e₁ + e₂ ] =
    ext (\κ -> cong₂ _$_
               (correct e₁)
               (ext \v₁ -> cong (\e -> e (Step₁.add κ v₁))
                                (correct e₂)))

  -- Note that we could derive the compiler, but there are several
  -- reasons for not doing this:
  --
  -- • The implementation is obvious, we're just transforming the
  --   result from Step₁.
  --
  -- • The correctness result is of the form
  --     ... ≡ f (comp e)
  --   instead of the simpler form
  --     ... ≡ comp e.
  --   This means that the witness cannot be inferred automatically,
  --   so we have to write it out, and things start to look messy. See
  --   below.

  module DerivedCompiler where

   -- A dependent pair whose first component lives in Set1, as values
   -- of Exp' do; witness₁ and proof₁ are its projections.
   -- NOTE(review): presumably introduced because EqualTo from
   -- Derivation only handles Set-level witnesses — confirm.

   data Σ₁ (a : Set1) (b : a -> Set) : Set1 where
     _,_ : (x : a) -> b x -> Σ₁ a b

   witness₁ : forall {a b} -> Σ₁ a b -> a
   witness₁ (x , y) = x

   proof₁ : forall {a b} -> (p : Σ₁ a b) -> b (witness₁ p)
   proof₁ (x , y) = y

   mutual

    -- The compiler, read off from the derivation comp'P.

    comp' : Exp -> Exp' (K -> C)
    comp' e = witness₁ (comp'P e)

    -- Each clause gives the compiled term (written out by hand, since
    -- it cannot be inferred) together with a proof that Step₁.E' e
    -- equals its denotation; the proofs repeat those of correct.

    comp'P : (e : Exp) -> Σ₁ _ \e' -> Step₁.E' e   e' 
    comp'P (id I) =
      ( fetch I
      , (begin
          Step₁.fetch I
            ≡⟨ refl 
           fetch I 
            )
      )
    comp'P [ e₁ + e₂ ] =
      ( B (K  ε) (comp' e₁) (B (K  V  ε) (comp' e₂) add)
      ,
        (begin
          Step₁.E' e₁  K  ε  (Step₁.E' e₂  K  V  ε  Step₁.add)
            ≡⟨ ext (\κ -> cong₂ _$_
                 (proof₁ (comp'P e₁))
                 (ext \v₁ -> cong (\e -> e (Step₁.add κ v₁))
                                  (proof₁ (comp'P e₂)))) 
           B (K  ε) (comp' e₁) (B (K  V  ε) (comp' e₂) add) 
            )
      )

  -- We can combine the correctness proofs from above: the derivation
  -- of Step₁ (E e ≡ Step₁.E' e) followed by correct.

  correct' : forall e -> E e   comp e 
  correct' e = begin
    E e
      ≡⟨ proof (Step₁.E'P e) 
    Step₁.E' e
      ≡⟨ correct e 
     comp e 
      

module Step₂′ where

  open Language

  -- Step₂ can be simplified by indexing the Exp' type on the
  -- semantics of expressions:
  --
  -- • B ts e₁ e₂ denotes the multi-composition, at the argument list
  --   ts, of the denotations of e₁ and e₂;
  -- • add, fetch I and halt denote the functions of the same name in
  --   Step₁.

  data Exp' : {t : Set} -> t -> Set1 where
    B     : forall {a b} ts {f : a -> b} {g : ts  a} ->
            Exp' f -> Exp' g -> Exp' (f  ts  g)
    add   : Exp' Step₁.add
    fetch : (I : Id) -> Exp' (Step₁.fetch I)
    halt  : Exp' Step₁.halt

  -- The semantics is encoded in the type: ⟦_⟧ just returns the index.

  ⟦_⟧ : forall {t} {f : t} -> Exp' f -> t
  ⟦_⟧ {f = f} _ = f

  -- Correctness of compilation is ensured through the type of comp:
  -- comp e only type-checks if its index is (definitionally)
  -- Step₁.E' e.

  comp : (e : Exp) -> Exp' (Step₁.E' e)
  comp (id I)      = fetch I
  comp [ e₁ + e₂ ] =
    B (K  ε) (comp e₁) (B (K  V  ε) (comp e₂) add)

  -- Note: No proof code necessary.

  -- We can still combine the proofs. The second step is refl because,
  -- by the type of comp, ⟦ comp e ⟧ is Step₁.E' e.

  correct : forall e -> E e   comp e 
  correct e = begin
    E e
      ≡⟨ proof (Step₁.E'P e) 
    Step₁.E' e
      ≡⟨ refl 
     comp e 
      

module Examples (X Y Z : Id) where

  -- Example terms over three identifiers X, Y and Z. The semantic
  -- indices of the Exp' types are written _ and are solved by Agda.

  open Language
  open Step₂′

  -- Fig. 2(b) in Wand's paper: the tree-shaped code that comp
  -- produces for [ [ id X + id Y ] + id Z ], applied to halt. The
  -- local names record each node's position in the tree: l is the left
  -- child of the root, ll and lr are the children of l, and llr is the
  -- right child of ll.

  fig2b : Exp' {C} _
  fig2b = B ε l halt
    where
    llr : Exp' {K -> V -> C} _
    llr = B (K  V  ε) (fetch Y) add

    ll : Exp' {K -> C} _
    ll = B (K  ε) (fetch X) llr

    lr : Exp' {K -> V -> C} _
    lr = B (K  V  ε) (fetch Z) add

    l : Exp' {K -> C} _
    l = B (K  ε) ll lr

  -- Fig. 2(c) in Wand's paper: a right-nested (linear) form of the
  -- program in Fig. 2(b). Reading the nesting from the outside in, the
  -- commands are fetch X, fetch Y, add, fetch Z, add and halt; stepₖ
  -- is the rest of the program starting at the k-th command.

  fig2c : Exp' {C} _
  fig2c = B ε (fetch X) step₂
    where
    step₆ : Exp' {V -> C} _
    step₆ = halt

    step₅ : Exp' {V -> V -> C} _
    step₅ = B ε add step₆

    step₄ : Exp' {V -> C} _
    step₄ = B (V  ε) (fetch Z) step₅

    step₃ : Exp' {V -> V -> C} _
    step₃ = B ε add step₄

    step₂ : Exp' {V -> C} _
    step₂ = B (V  ε) (fetch Y) step₃

module Step₂″ where

  -- Yet another variant of Step₂. Here I have chosen to specialise
  -- the types (in order to make some of Wand's invariants explicit,
  -- and simplify the next step), and to include a version of the
  -- top-level compiler (P).

  open Language

  -- Exp' n f: a tree whose denotation f takes a continuation followed
  -- by n values (add takes two, fetch none). There is no halt
  -- constructor; halt only occurs at the top level, via Bhalt.
  --
  -- In B e₁ e₂, e₁ takes no values and e₂ takes one more value than
  -- the combined term (presumably the value computed by e₁). The
  -- denotation B′ is not defined in this file.

  data Exp' : (n : ) -> (K -> V ^ n  C) -> Set where
    B     : forall {n f g} ->
            Exp' 0 f -> Exp' (suc n) g -> Exp' n (B′ n f g)
    add   : Exp' 2 Step₁.add
    fetch : (I : Id) -> Exp' 0 (Step₁.fetch I)

  -- A complete program: a tree that takes no values, applied to the
  -- halting continuation.

  data ToplevelExp : C -> Set where
    Bhalt : forall {f} -> Exp' 0 f -> ToplevelExp (f Step₁.halt)

  -- Correctness of compilation is ensured through the type of comp:

  comp' : (e : Exp) -> Exp' 0 (Step₁.E' e)
  comp' (id I)      = fetch I
  comp' [ e₁ + e₂ ] = B (comp' e₁) (B (comp' e₂) add)

  -- Note: Now we are using some proof code. See the definition of B′,
  -- which is used in the type of B above.

  -- Top-level compiler:

  comp : (e : Exp) -> ToplevelExp (Step₁.P' e)
  comp e = Bhalt (comp' e)

  -- Correctness: as in Step₂′, the denotation is the index, so the
  -- second step below is refl.

  ⟦_⟧ : forall {f} -> ToplevelExp f -> C
  ⟦_⟧ {f = f} _ = f

  correct : forall e -> P e   comp e 
  correct e = begin
    P e
      ≡⟨ proof (Step₁.P'P e) 
    Step₁.P' e
      ≡⟨ refl 
     comp e 
      

module Step₃ where

  open Language

  -- Let's implement Wand's rotation. The goal is to take the
  -- tree-like structure from step 2 and flatten it.

  -- Single instructions, indexed like Step₂″.Exp' by the number of
  -- values they take and by their denotation.

  data Command : (n : ) -> (K -> V ^ n  C) -> Set where
    add   : Command 2 Step₁.add
    fetch : (I : Id) -> Command 0 (Step₁.fetch I)

  mutual

    -- Cont m f: given linear code k for the rest of the program (which
    -- takes one more value than n), produce linear code for f followed
    -- by k. The resulting denotation B″ m n f g is not defined in this
    -- file.

    Cont : (m : ) -> (K -> V ^ m  C) -> Set
    Cont m f = forall {n g} ->
               LinearExp (suc n) g -> LinearExp (m + n) (B″ m n f g)

    -- LinearExp n f: a sequence of commands ending in halt, whose
    -- denotation f takes n values. Each command is prefixed with _○_,
    -- and halt takes exactly one value, the final result.

    data LinearExp : (n : ) -> V ^ n  C -> Set where
      _○_  : forall {m f} -> Command m f -> Cont m f
      halt : LinearExp 1 Step₁.halt

  -- Wand's rotation implemented in a continuation-passing style to
  -- ensure that it is well-typed and structurally recursive:
  --
  -- A leaf becomes a single command in front of k. For B e₁ e₂, e₂ is
  -- rotated onto k first and then e₁ onto the result; cast₁ transports
  -- along rot-lemma so that the index has the expected form.

  rot : forall {m f} -> Step₂″.Exp' m f -> Cont m f
  rot Step₂″.add                           = \k -> add  k
  rot (Step₂″.fetch I)                     = \k -> fetch I  k
  rot (Step₂″.B {m} {f} {g} e₁ e₂) {n} {h} = \k -> cast₁ (rot e₁ (rot e₂ k))
    where
    cast₁ = subst (LinearExp (m + n)) (rot-lemma m n f g h)

  -- A top-level program is flattened by rotating its tree onto halt.

  toplevelRot : forall {f} -> Step₂″.ToplevelExp f -> LinearExp 0 f
  toplevelRot (Step₂″.Bhalt e) = rot e halt

  -- Check out MultiComposition.agda to see the proofs needed to
  -- verify this step. The proofs look difficult, but the core result
  -- (right-assoc) is very simple; the rest of the code is boring
  -- impedance-matching (ensuring that functions and arguments have
  -- matching types).

module Step₄ where

  -- Let's define a direct interpreter (virtual machine) for these
  -- linear expressions. The interpreter works on a stack of a given
  -- size.

  open Language
  open Step₃

  -- A stack holding exactly n values. The head of the vector is the
  -- top of the stack: exec pushes new values onto the front.

  Stack :  -> Set
  Stack = Vec V

  -- Apply a curried function to the elements of a stack in list order,
  -- head first.

  ⟦_⟧' : forall {a n} -> V ^ n  a -> Stack n -> a
   f ⟧' []      = f
   f ⟧' (x  s) =  f x ⟧' s

  -- Apply a curried function to the elements of a stack bottom first,
  -- so that the top of the stack becomes the last argument.

  ⟦_⟧ : forall {a n} -> V ^ n  a -> Stack n -> a
   f  s =  f ⟧' (reverse s)

  postulate

    -- I cannot be bothered to prove these lemmas. The presence of
    -- "reverse" makes the proofs rather complicated.
    --
    -- NOTE(review): these postulates are unchecked assumptions.
    -- execP, and therefore any correctness result built on it, is only
    -- as trustworthy as they are.

    -- add pops x (the top) and y and pushes y + x: x was pushed last,
    -- so it is the right operand.

    lemma₁ : forall {n} g x y (s : Stack n) σ ->
              B″ 2 n Step₁.add g  (x  y  s) σ   g  (y + x  s) σ

    -- fetch I pushes the value of I in the current store.

    lemma₂ : forall {n} g I (s : Stack n) σ ->
              B″ 0 n (Step₁.fetch I) g  s σ   g  (σ I  s) σ

  -- Note: I have not derived the implementation of exec, I have just
  -- combined it with the correctness proof.

  mutual

    -- Run linear code on a stack and a store. exec is the witness of
    -- execP, so exec e s σ equals ⟦ f ⟧ s σ.

    exec : forall {n f} -> LinearExp n f -> Stack n -> C
    exec e s σ = witness (execP e s σ)

    -- One clause per instruction. Each rewrites the denotation with the
    -- corresponding lemma and then continues with the rest of the code.
    -- On halt the stack holds exactly one value, which is the result.

    execP : forall {n f} (e : LinearExp n f) (s : Stack n) (σ : S) ->
            EqualTo ( f  s σ)
    execP (_○_ add {n} {g} c) (x  y  s) σ =  begin
       B″ 2 n Step₁.add g  (x  y  s) σ
        ≡⟨ lemma₁ g x y s σ 
       g  (y + x  s) σ
        ≡⟨ proof (execP c (y + x  s) σ) 
      exec c (y + x  s) σ
        
    execP (_○_ (fetch I) {n} {g} c) s σ =  begin
       B″ 0 n (Step₁.fetch I) g  s σ
        ≡⟨ lemma₂ g I s σ 
       g  (σ I  s) σ
        ≡⟨ proof (execP c (σ I  s) σ) 
      exec c (σ I  s) σ
        
    execP halt (x  []) σ =  begin x 

module Step₅ where

  open Language
  open Step₃ using () renaming (LinearExp to Code)
  open Step₄ using (exec)

  -- Now we can wrap up the development.

  -- The full compiler: build the tree of Step₂″ and flatten it with
  -- the rotation of Step₃. The index Step₁.P' e records the meaning of
  -- the generated code.

  comp : (e : Exp) -> Code 0 (Step₁.P' e)
  comp e = Step₃.toplevelRot (Step₂″.comp e)

  -- Full correctness: for every expression e and initial store σ,
  -- running the compiled code from the empty stack yields P e σ. The
  -- first step uses the derivation from Step₁; the second uses the
  -- interpreter's correctness proof from Step₄.

  correct : forall e σ -> P e σ  exec (comp e) [] σ
  correct e σ = begin
    P e σ
      ≡⟨ cong (\f -> f σ) (proof (Step₁.P'P e)) 
    Step₁.P' e σ
      ≡⟨ proof (Step₄.execP (comp e) [] σ) 
    exec (comp e) [] σ