Library zoo.program_logic.structural_equality

From zoo Require Import
  prelude.
From zoo.common Require Import
  list.
From zoo.language Require Import
  notations.
From zoo.program_logic Require Export
  wp.
From zoo.diaframe Require Import
  diaframe.
From zoo Require Import
  options.

(** Default types for binder names used throughout this file: a binder
    named [b] is a [bool], [tag] a [nat], [n] a [Z], [l] a [location],
    [gen] a [generativity], [v] or [w] a [val], [vs] a [list val], and
    [lv] a [lowval]. *)
Implicit Types b : bool.
Implicit Types tag : nat.
Implicit Types n : Z.
Implicit Types l : location.
Implicit Types gen : generativity.
Implicit Types v w : val.
Implicit Types vs : list val.
Implicit Types lv : lowval.

(** Zoo implementation of structural equality as two mutually recursive
    functions.

    - [structeq v1 v2]: when both arguments are immediate, compares them
      with [==]; when exactly one of them is immediate, returns [false];
      otherwise it requires equal tags (via [GetTag]) and equal sizes
      (via [GetSize]) and then compares the fields with [structeq_aux].
    - [structeq_aux v1 v2 i]: returns [true] when [i] is [0]; otherwise
      compares the fields at index [i - 1] with [structeq] and continues
      with [i - 1], so fields are visited from the last one down to
      index [0]. The partial results are combined with [and].

    NOTE(review): in the copy reviewed here the [recs:] and [and:] headers
    show no separator between the parameters and the body; confirm the
    checked-in file matches the zoo [recs:] notation. *)
#[local] Definition __zoo_recs := (
  recs: "structeq" "v1" "v2"
    if: IsImmediate "v1" then
      if: IsImmediate "v2" then
        "v1" == "v2"
      else
        false
    else if: IsImmediate "v2" then
      false
    else (
      GetTag "v1" == GetTag "v2" and
      let: "sz" := GetSize "v1" in
      "sz" == GetSize "v2" and
      "structeq_aux" "v1" "v2" "sz"
    )
  and: "structeq_aux" "v1" "v2" "i"
    if: "i" == 0 then
      true
    else
      let: "i" := "i" - 1 in
      "structeq" (Load "v1" "i") (Load "v2" "i") and
      "structeq_aux" "v1" "v2" "i"
)%zoo_recs.
(** Structural equality as a zoo value: entry [0] of the recursive bundle
    [__zoo_recs]. *)
Definition structeq :=
  ValRecs 0 __zoo_recs.
(** Field-by-field helper of [structeq]: entry [1] of [__zoo_recs].
    Declared [#[local]]. *)
#[local] Definition structeq_aux :=
  ValRecs 1 __zoo_recs.
(** Registers [structeq] as entry [0] of [__zoo_recs], whose entries are
    [structeq] and [structeq_aux], with the [AsValRecs'] class (presumably
    used to unfold recursive calls while proving specifications). *)
#[global] Instance :
  AsValRecs' structeq 0 __zoo_recs [
    structeq ;
    structeq_aux
  ].
(** Registers [structeq_aux] as entry [1] of [__zoo_recs], whose entries
    are [structeq] and [structeq_aux], with the [AsValRecs'] class. *)
#[global] Instance :
  AsValRecs' structeq_aux 1 __zoo_recs [
    structeq ;
    structeq_aux
  ].

(** In [expr_scope], [e1 = e2] applies [structeq] to [e1] and then [e2].
    Inside zoo expressions [=] therefore denotes structural equality, not
    Coq propositional equality. *)
Notation "e1 = e2" := (
  App (App (Val structeq) e1%E) e2%E
)(at level 70,
  no associativity
) : expr_scope.
(** In [expr_scope], [e1 ≠ e2] is the [UnopNeg] negation of the structural
    equality [structeq e1 e2]. *)
Notation "e1 ≠ e2" := (
  Unop UnopNeg (App (App (Val structeq) e1%E) e2%E)
)(at level 70,
  no associativity
) : expr_scope.

(** Logical description of one field of a heap block, as used by the
    structural-equality specifications:
    - [structeq_field_dfrac]: the fraction at which the field's points-to
      is owned in [structeq_footprint];
    - [structeq_field_val]: the value stored in the field. *)
Record structeq_field := StructeqField
  { structeq_field_dfrac : dfrac
  ; structeq_field_val : val
  }.
Add Printing Constructor structeq_field.
Implicit Types fld : structeq_field.

(** [structeq_field] is inhabited: build a field from the default fraction
    and the default value. *)
#[global] Instance structeq_field_inhabited : Inhabited structeq_field :=
  populate (StructeqField inhabitant inhabitant).

(** Logical description of a heap block, as used by the
    structural-equality specifications:
    - [structeq_block_tag]: the block tag recorded in its header;
    - [structeq_block_fields]: its fields, whose count is the block size
      recorded in the header.
    A [footprint] maps each location to the block stored there. *)
Record structeq_block := StructeqBlock
  { structeq_block_tag : nat
  ; structeq_block_fields : list structeq_field
  }.
Add Printing Constructor structeq_block.
Implicit Types blk : structeq_block.
Implicit Types footprint : gmap location structeq_block.

(** [structeq_block] is inhabited: build a block from the default tag and
    the default field list. *)
#[global] Instance structeq_block_inhabited : Inhabited structeq_block :=
  populate (StructeqBlock inhabitant inhabitant).

(** [val_traversable footprint v]: every location occurring in [v] is
    covered by [footprint].
    - Booleans and integers are always traversable.
    - A location is traversable when it is in [dom footprint].
    - A [ValBlock _ _ vs] is traversable when all of its fields are.
    - Every other value is not traversable.
    This is a shallow check: it does not follow locations. The fields of
    heap blocks are covered by [structeq_footprint] instead.
    NOTE(review): in the copy reviewed here the match arms show no arrows,
    and the location case shows no membership operator between [l] and
    [dom footprint]; confirm against the checked-in file. *)
Fixpoint val_traversable footprint v :=
  match v with
  | ValBool _
  | ValInt _
      True
  | ValLoc l
      l dom footprint
  | ValBlock _ _ vs
      Forall' (val_traversable footprint) vs
  | _
      False
  end.
#[global] Arguments val_traversable _ !_ / : assert.

(** Ownership of the heap blocks described by a footprint, and WP
    specifications of the block primitives ([GetTag], [GetSize], [Load])
    used by [structeq]. *)
Section zoo_G.
  Context `{zoo_G : !ZooG Σ}.

  (** For every entry [l], [blk] of [footprint], this owns:
      - the header of [l], with tag [blk.(structeq_block_tag)] and size
        [length blk.(structeq_block_fields)];
      - for every field [fld] at index [i], the points-to of [l +ₗ i] at
        fraction [fld.(structeq_field_dfrac)] holding
        [fld.(structeq_field_val)];
      - the pure fact that each such field value is traversable in the
        same footprint. *)
  Definition structeq_footprint footprint : iProp Σ :=
    [∗ map] l blk footprint,
      l ↦ₕ Header blk.(structeq_block_tag) (length blk.(structeq_block_fields))
      [∗ list] i fld blk.(structeq_block_fields),
        (l +ₗ i) {fld.(structeq_field_dfrac)} fld.(structeq_field_val)
        val_traversable footprint fld.(structeq_field_val).

  (** The empty footprint can always be produced. *)
  Lemma structeq_footprint_empty :
     structeq_footprint .

  (** The footprint yields the header of any location it maps. *)
  Lemma structeq_footprint_header {footprint} l blk :
    footprint !! l = Some blk
    structeq_footprint footprint
    l ↦ₕ Header blk.(structeq_block_tag) (length blk.(structeq_block_fields)).

  (** Extracts, for a known field [fld] at index [i] of block [l], the
      points-to of [l +ₗ i] and the traversability of its value. The
      returned wand rebuilds the footprint once that points-to is given
      back. *)
  Lemma structeq_footprint_lookup {footprint} l blk (i : nat) fld :
    footprint !! l = Some blk
    blk.(structeq_block_fields) !! i = Some fld
    structeq_footprint footprint
      (l +ₗ i) {fld.(structeq_field_dfrac)} fld.(structeq_field_val)
      val_traversable footprint fld.(structeq_field_val)
      ( (l +ₗ i) {fld.(structeq_field_dfrac)} fld.(structeq_field_val) -∗
        structeq_footprint footprint
      ).
  (** Variant of [structeq_footprint_lookup] for an index [i] bounded by
      the number of fields. The field [fld] is produced existentially,
      together with the fact that it is stored at index [i]. *)
  Lemma structeq_footprint_lookup' {footprint} l blk i :
    footprint !! l = Some blk
    i < length blk.(structeq_block_fields)
    structeq_footprint footprint
       fld,
      blk.(structeq_block_fields) !! i = Some fld
      (l +ₗ i) {fld.(structeq_field_dfrac)} fld.(structeq_field_val)
      val_traversable footprint fld.(structeq_field_val)
      ( (l +ₗ i) {fld.(structeq_field_dfrac)} fld.(structeq_field_val) -∗
        structeq_footprint footprint
      ).

  (** [GetTag #l] on a block of the footprint returns
      [#(encode_tag blk.(structeq_block_tag))] and gives the footprint
      back. *)
  Lemma structeq_footprint_wp_tag {footprint} l blk :
    footprint !! l = Some blk
    {{{
      structeq_footprint footprint
    }}}
      GetTag #l
    {{{
      RET #(encode_tag blk.(structeq_block_tag));
      structeq_footprint footprint
    }}}.
  (** [GetSize #l] on a block of the footprint returns its number of
      fields and gives the footprint back. *)
  Lemma structeq_footprint_wp_size {footprint} l blk :
    footprint !! l = Some blk
    {{{
      structeq_footprint footprint
    }}}
      GetSize #l
    {{{
      RET #(length blk.(structeq_block_fields));
      structeq_footprint footprint
    }}}.

  (** [Load #l #i] on a known field [fld] returns its value, which is
      traversable, and gives the footprint back. *)
  Lemma structeq_footprint_wp_load {footprint} l blk (i : nat) fld :
    footprint !! l = Some blk
    blk.(structeq_block_fields) !! i = Some fld
    {{{
      structeq_footprint footprint
    }}}
      Load #l #i
    {{{
      RET fld.(structeq_field_val);
      val_traversable footprint fld.(structeq_field_val)
      structeq_footprint footprint
    }}}.
  (** Variant of [structeq_footprint_wp_load] for an index [i] bounded by
      the number of fields. The loaded field [fld] is returned
      existentially, with the fact that it is stored at index [i]. *)
  Lemma structeq_footprint_wp_load' {footprint} l blk i :
    footprint !! l = Some blk
    i < length blk.(structeq_block_fields)
    {{{
      structeq_footprint footprint
    }}}
      Load #l #i
    {{{
      fld
    , RET fld.(structeq_field_val);
      blk.(structeq_block_fields) !! i = Some fld
      val_traversable footprint fld.(structeq_field_val)
      structeq_footprint footprint
    }}}.
End zoo_G.

(** [val_reachable footprint src path dst]: following the field indices in
    [path] from [src] leads to [dst].
    - With an empty path, [src] must equal [dst].
    - Otherwise the first index [i] steps:
      - from a location [l] into field [i] of the block [footprint !! l];
      - from a [ValBlock _ _ vs] into [vs !! i].
    - The relation is [False] when the location is not in [footprint], the
      index is out of range, or [src] is any other kind of value. *)
Fixpoint val_reachable footprint src path dst :=
  match path with
  | []
      src = dst
  | i :: path
      match src with
      | ValLoc l
          match footprint !! l with
          | None
              False
          | Some blk
              match blk.(structeq_block_fields) !! i with
              | None
                  False
              | Some fld
                  val_reachable footprint fld.(structeq_field_val) path dst
              end
          end
      | ValBlock _ _ vs
          match vs !! i with
          | None
              False
          | Some src
              val_reachable footprint src path dst
          end
      | _
          False
      end
  end.
#[global] Arguments val_reachable _ !_ !_ / _ : assert.

(** Reachability along a given path is decidable. *)
#[global] Instance val_reachable_dec footprint src path dst :
  Decision (val_reachable footprint src path dst).

(** [lowval_compatible footprint lv1 lv2]: the per-node check, on low
    values, behind [val_structeq] and [val_structneq].
    - A location literal paired with a location or a value block compares
      the tags and the field counts; heap blocks are looked up in
      [footprint].
    - Any other literal is compatible only with the same literal.
    - [LowvalRecs] is compatible only with [LowvalRecs].
    - A value block paired with a location or another value block compares
      the tags and the field counts.
    - Every other pairing is incompatible.
    NOTE(review): [!!!] is a total lookup, so a location missing from
    [footprint] is silently compared against a default block. In the copy
    reviewed here the operators between the compared tags and lengths,
    and the match arrows, are not visible; confirm against the checked-in
    file. *)
Definition lowval_compatible footprint lv1 lv2 :=
  match lv1 with
  | LowvalLit lit1
      match lit1 with
      | LowlitLoc l1
          match lv2 with
          | LowvalLoc l2
              let blk1 := footprint !!! l1 in
              let blk2 := footprint !!! l2 in
              blk1.(structeq_block_tag) blk2.(structeq_block_tag) &&
              length blk1.(structeq_block_fields) length blk2.(structeq_block_fields)
          | LowvalBlock _ tag2 vs2 _
              let blk1 := footprint !!! l1 in
              blk1.(structeq_block_tag) tag2 &&
              length blk1.(structeq_block_fields) length vs2
          | _
              false
          end
      | _
          bool_decide (lv2 = LowvalLit lit1)
      end
  | LowvalRecs
      bool_decide (lv2 = LowvalRecs)
  | LowvalBlock _ tag1 vs1 _
      match lv2 with
      | LowvalLoc l2
          let blk2 := footprint !!! l2 in
          tag1 blk2.(structeq_block_tag) &&
          length vs1 length blk2.(structeq_block_fields)
      | LowvalBlock _ tag2 vs2 _
          tag1 tag2 &&
          length vs1 length vs2
      | _
          false
      end
  end.
#[global] Arguments lowval_compatible _ !_ !_ / : assert.

(** Lifts [lowval_compatible] to values through [val_to_low]. *)
Definition val_compatible footprint v1 v2 :=
  lowval_compatible footprint (val_to_low v1) (val_to_low v2).

(** Structural equality relative to [footprint]: for any path, the values
    [v1'] and [v2'] reached from [v1] and [v2] along that path are
    compatible ([val_compatible] returns [true]).
    NOTE(review): the binder and connectives are not visible in the copy
    reviewed here; presumably a universal quantifier with implications. *)
Definition val_structeq footprint v1 v2 :=
   path v1' v2',
  val_reachable footprint v1 path v1'
  val_reachable footprint v2 path v2'
  val_compatible footprint v1' v2' = true.

(** Structural disequality relative to [footprint]: some path reaches
    values [v1'] and [v2'] from [v1] and [v2] that are incompatible
    ([val_compatible] returns [false]).
    NOTE(review): the binder and connectives are not visible in the copy
    reviewed here; presumably an existential quantifier with
    conjunctions. *)
Definition val_structneq footprint v1 v2 :=
   path v1' v2',
  val_reachable footprint v1 path v1'
  val_reachable footprint v2 path v2'
  val_compatible footprint v1' v2' = false.

(** Related immediate values are structurally equal in any footprint.
    NOTE(review): the relation between [v1] and [v2] in the third
    hypothesis is not visible in the copy reviewed here (presumably
    [v1 ≈ v2]). *)
Lemma val_immediate_structeq footprint v1 v2 :
  val_immediate v1
  val_immediate v2
  v1 v2
  val_structeq footprint v1 v2.
(** Unrelated immediate values are structurally different in any
    footprint.
    NOTE(review): the relation between [v1] and [v2] in the third
    hypothesis is not visible in the copy reviewed here (presumably
    [v1 ≉ v2]). *)
Lemma val_immediate_structneq footprint v1 v2 :
  val_immediate v1
  val_immediate v2
  v1 v2
  val_structneq footprint v1 v2.

(** An immediate value is structurally equal to itself in any footprint. *)
Lemma val_structeq_refl footprint v :
  val_immediate v
  val_structeq footprint v v.
(** Variant of [val_structeq_refl] that takes the equality [v1 = v2] as a
    hypothesis, which is convenient when the two sides are not
    syntactically identical. *)
Lemma val_structeq_refl' footprint v1 v2 :
  v1 = v2
  val_immediate v1
  val_structeq footprint v1 v2.

(** Hoare-triple specifications of [structeq]. *)
Section zoo_G.
  Context `{zoo_G : !ZooG Σ}.

  (** Mutual specification used to prove [structeq𑁒spec]. It packs five
      triples:
      - [v1 = v2], i.e. [structeq], on values traversable in [footprint];
      - [structeq_aux] on two heap blocks [#l1] and [#l2];
      - [structeq_aux] on a heap block [#l1] and a value block [v2];
      - [structeq_aux] on a value block [v1] and a heap block [#l2];
      - [structeq_aux] on two value blocks [v1] and [v2].

      Every [structeq_aux] triple assumes:
      - equal tags and equal lengths;
      - for value blocks, a non-zero length and traversability;
      - an index [i] ranging over [0] to the length;
      - a hypothesis relating [i] to every index [j] at which the fields
        are already known to be structurally equal. Since [structeq_aux]
        walks from index [i - 1] down to [0], this presumably covers the
        indices [j] from [i] upwards.

      Every triple returns a boolean [b] with [val_structeq] when [b] is
      [true] and [val_structneq] otherwise, and gives the footprint back.

      NOTE(review): the relational operators and connectives are not
      visible in the copy reviewed here. In the last three triples the
      quantified field variable reuses the name [v2] (resp. [v1]) of the
      enclosing [let], shadowing it inside the quantifier; a distinct name
      would be clearer. *)
  #[local] Lemma structeq𑁒spec_aux :
     (
       v1 v2 footprint,
      {{{
        val_traversable footprint v1
        val_traversable footprint v2
        structeq_footprint footprint
      }}}
        v1 = v2
      {{{
        b
      , RET #b;
        (if b then val_structeq else val_structneq) footprint v1 v2
        structeq_footprint footprint
      }}}
    ) (
       l1 blk1 l2 blk2 footprint i,
      {{{
        0 i length blk1.(structeq_block_fields)%Z
        footprint !! l1 = Some blk1
        footprint !! l2 = Some blk2
        blk1.(structeq_block_tag) = blk2.(structeq_block_tag)
        length blk1.(structeq_block_fields) = length blk2.(structeq_block_fields)
        structeq_footprint footprint
         j fld1 fld2,
          blk1.(structeq_block_fields) !! j = Some fld1
          blk2.(structeq_block_fields) !! j = Some fld2
          i j
          val_structeq footprint fld1.(structeq_field_val) fld2.(structeq_field_val)
        
      }}}
        structeq_aux #l1 #l2 #i
      {{{
        b
      , RET #b;
        (if b then val_structeq else val_structneq) footprint #l1 #l2
        structeq_footprint footprint
      }}}
    ) (
       l1 blk1 gen2 tag2 vs2 footprint i,
      let v2 := ValBlock gen2 tag2 vs2 in
      {{{
        0 i length vs2%Z
        footprint !! l1 = Some blk1
        blk1.(structeq_block_tag) = tag2
        length blk1.(structeq_block_fields) = length vs2
        0 < length vs2
        val_traversable footprint v2
        structeq_footprint footprint
         j fld1 v2,
          blk1.(structeq_block_fields) !! j = Some fld1
          vs2 !! j = Some v2
          i j
          val_structeq footprint fld1.(structeq_field_val) v2
        
      }}}
        structeq_aux #l1 v2 #i
      {{{
        b
      , RET #b;
        (if b then val_structeq else val_structneq) footprint #l1 v2
        structeq_footprint footprint
      }}}
    ) (
       gen1 tag1 vs1 l2 blk2 footprint i,
      let v1 := ValBlock gen1 tag1 vs1 in
      {{{
        0 i length vs1%Z
        footprint !! l2 = Some blk2
        tag1 = blk2.(structeq_block_tag)
        length vs1 = length blk2.(structeq_block_fields)
        0 < length vs1
        val_traversable footprint v1
        structeq_footprint footprint
         j v1 fld2,
          vs1 !! j = Some v1
          blk2.(structeq_block_fields) !! j = Some fld2
          i j
          val_structeq footprint v1 fld2.(structeq_field_val)
        
      }}}
        structeq_aux v1 #l2 #i
      {{{
        b
      , RET #b;
        (if b then val_structeq else val_structneq) footprint v1 #l2
        structeq_footprint footprint
      }}}
    ) (
       gen1 tag1 vs1 gen2 tag2 vs2 footprint i,
      let v1 := ValBlock gen1 tag1 vs1 in
      let v2 := ValBlock gen2 tag2 vs2 in
      {{{
        0 i length vs1%Z
        tag1 = tag2
        length vs1 = length vs2
        0 < length vs1
        val_traversable footprint v1
        val_traversable footprint v2
        structeq_footprint footprint
         j v1 v2,
          vs1 !! j = Some v1
          vs2 !! j = Some v2
          i j
          val_structeq footprint v1 v2
        
      }}}
        structeq_aux v1 v2 #i
      {{{
        b
      , RET #b;
        (if b then val_structeq else val_structneq) footprint v1 v2
        structeq_footprint footprint
      }}}
    ).
  (** Main specification of structural equality. For [v1] and [v2]
      traversable in [footprint], [v1 = v2] returns a boolean [b]:
      [val_structeq footprint v1 v2] holds when [b] is [true], and
      [val_structneq footprint v1 v2] otherwise. The footprint is given
      back. *)
  Lemma structeq𑁒spec {v1 v2} footprint :
    val_traversable footprint v1
    val_traversable footprint v2
    {{{
      structeq_footprint footprint
    }}}
      v1 = v2
    {{{
      b
    , RET #b;
      (if b then val_structeq else val_structneq) footprint v1 v2
      structeq_footprint footprint
    }}}.
End zoo_G.

(** From here on, reduction tactics do not unfold [structeq]; the lemmas
    below reason about it through [structeq𑁒spec]. *)
#[global] Opaque structeq.


(** [val_abstract v]: [v] is built only from booleans, integers and
    [Nongenerative] blocks whose fields are themselves abstract. In
    particular, such a value contains no location. *)
Fixpoint val_abstract v :=
  match v with
  | ValBool _
  | ValInt _
      True
  | ValBlock Nongenerative _ vs
      Forall' val_abstract vs
  | _
      False
  end.
#[global] Arguments val_abstract !_ / : assert.

(** Abstract values are traversable.
    NOTE(review): [val_traversable] takes a footprint before the value, but
    none is visible in the copy reviewed here (presumably the empty
    footprint); confirm against the checked-in file. *)
Lemma val_abstract_traversable v :
  val_abstract v
  val_traversable v.

(** Related abstract values are compatible in any footprint.
    NOTE(review): the relation in the third hypothesis is not visible in
    the copy reviewed here (presumably [v1 ≈ v2]). *)
Lemma val_compatible_refl_abstract footprint v1 v2 :
  val_abstract v1
  val_abstract v2
  v1 v2
  val_compatible footprint v1 v2 = true.

(** For abstract values, structural equality in any footprint implies that
    [v1] and [v2] are related.
    NOTE(review): the relation in the conclusion is not visible in the
    copy reviewed here (presumably [v1 ≈ v2]). *)
Lemma val_structeq_abstract_1 footprint v1 v2 :
  val_abstract v1
  val_abstract v2
  val_structeq footprint v1 v2
  v1 v2.
(** Converse of [val_structeq_abstract_1]: related abstract values are
    structurally equal.
    NOTE(review): [val_structeq] takes a footprint before the values, but
    none is visible in the copy reviewed here (presumably the empty
    footprint), and neither is the relation in the third hypothesis. *)
Lemma val_structeq_abstract_2 v1 v2 :
  val_abstract v1
  val_abstract v2
  v1 v2
  val_structeq v1 v2.
(** For abstract values, structural equality is equivalent to the
    relation in the conclusion.
    NOTE(review): the footprint argument of [val_structeq] and the
    connective and relation in the conclusion are not visible in the copy
    reviewed here; confirm against the checked-in file. *)
Lemma val_structeq_abstract v1 v2 :
  val_abstract v1
  val_abstract v2
  val_structeq v1 v2
  v1 v2.

(** For abstract values, structural disequality implies that [v1] and
    [v2] are not related.
    NOTE(review): the footprint argument of [val_structneq] and the
    relation in the conclusion (presumably [v1 ≉ v2]) are not visible in
    the copy reviewed here. *)
Lemma val_structneq_abstract v1 v2 :
  val_abstract v1
  val_abstract v2
  val_structneq v1 v2
  v1 v2.

(** Specification of [structeq] on abstract values. It needs no ownership
    (precondition [True]) and returns a boolean [b] such that
    [v1 ≈ v2] holds when [b] is [true], and [v1 ≉ v2] otherwise. *)
Lemma structeq𑁒spec_abstract `{zoo_G : !ZooG Σ} {v1 v2} :
  val_abstract v1
  val_abstract v2
  {{{
    True
  }}}
    v1 = v2
  {{{
    b
  , RET #b;
    (if b then (≈) else (≉)) v1 v2
  }}}.