(** * Library zoo.common.countable *)

From stdpp Require Export
  countable.

From zoo Require Import
  prelude.
From zoo.ltac2 Require
  Array
  Constr
  Control
  Ind
  List
  String
  Notations.
From zoo Require Import
  options.

Module solve_countable.
  Import Ltac2.Init.
  Import Ltac2.Notations.

  (* Ways in which [solve_countable] can fail. *)
  Ltac2 Type error :=
    [ Goal_invalid
    | Type_not_inductive
    | Type_empty
    | Internal_error
    ].

  (* Message describing an [error]. *)
  Ltac2 error_to_string err :=
    match err with
    | Goal_invalid =>
        "invalid goal"
    | Type_not_inductive =>
        "not an inductive type"
    | Type_empty =>
        "empty inductive type"
    | Internal_error =>
        "internal error"
    end.
  (* Abort through [Control.throw_invalid_argument] with the message of
     [err], prefixed by the tactic name. It is used at many result types
     below, so it cannot return normally. *)
  Ltac2 error err :=
    let err := error_to_string err in
    let err := String.app "solve_countable: " err in
    Control.throw_invalid_argument err.

  (* Everything the generated terms need to know about the target type.
     - [info_type]: the type being made countable, as found in the goal.
     - [info_inductive], [info_data], [info_instance]: its head inductive,
       the corresponding inductive data and universe instance.
     - [info_indices]: the arguments the head inductive is applied to. They
       are passed as leading arguments to every constructor.
     - [info_number_constructor], [info_constructors]: the constructors.
     - [info_fieldss]: for each constructor, the types of its fields.
     - [info_sum_inductive], [info_prod_inductive]: the inductives [sum] and
       [prod], used to destruct values of the encoding type. *)
  Ltac2 Type info :=
    { info_type : constr
    ; info_inductive : inductive
    ; info_data : Ind.data
    ; info_instance : instance
    ; info_indices : constr array
    ; info_number_constructor : int
    ; info_constructors : constructor array
    ; info_fieldss : constr array array
    ; info_sum_inductive : inductive
    ; info_prod_inductive : inductive
    }.

  (* The [i]-th constructor (0-based). *)
  Ltac2 info_constructor info i :=
    Array.get (info.(info_constructors)) i.
  (* The [i]-th constructor as a term, at the universe instance of the
     target type. *)
  Ltac2 info_constructor_constr info i :=
    let ctor := info_constructor info i in
    Constr.Unsafe.make_constructor ctor (info.(info_instance)).
  (* Field types of the [i]-th constructor. *)
  Ltac2 info_fields info i :=
    Array.get (info.(info_fieldss)) i.
  (* Number of fields of the [i]-th constructor. *)
  Ltac2 info_arity info i :=
    Array.length (info_fields info i).

  (* Analyse [ty]. It must be an inductive type, possibly applied to
     arguments, with at least one constructor. Otherwise this fails with
     [Type_not_inductive] or [Type_empty].
     NOTE(review): the arguments of [ty] are applied to the constructors as
     parameters, so inductive families with genuine indices are presumably
     not supported. *)
  Ltac2 info ty :=
    let (ind, indices) := Constr.head_tail ty in
    let (ind, inst) :=
      match Constr.Unsafe.kind ind with
      | Constr.Unsafe.Ind ind inst =>
          ind, inst
      | _ =>
          error Type_not_inductive
      end
    in
    let data := Ind.data ind in
    let num_ctor := Ind.nconstructors data in
    if Int.equal num_ctor 0 then
      error Type_empty
    else
      let ctors :=
        Array.init num_ctor (fun i =>
          Ind.get_constructor data i
        )
      in
      (* Field types of each constructor, read off the type of the
         constructor applied to the arguments of [ty]. *)
      let fldss :=
        Array.map (fun ctor =>
          let ctor :=
            Constr.Unsafe.make_app
              (Constr.Unsafe.make_constructor ctor inst)
              indices
          in
          let ctor_ty := Constr.type ctor in
          let flds := Constr.product_parameters ctor_ty in
          Array.of_list flds
        ) ctors
      in
      (* [sum] and [prod] are inductives; the fallback arms only make the
         matches total. *)
      let sum_ind :=
        match Constr.Unsafe.kind '@sum with
        | Constr.Unsafe.Ind ind _ =>
            ind
        | _ =>
            error Internal_error
        end
      in
      let prod_ind :=
        match Constr.Unsafe.kind '@prod with
        | Constr.Unsafe.Ind ind _ =>
            ind
        | _ =>
            error Internal_error
        end
      in
      { info_type := ty
      ; info_inductive := ind
      ; info_data := data
      ; info_instance := inst
      ; info_indices := indices
      ; info_number_constructor := num_ctor
      ; info_constructors := ctors
      ; info_fieldss := fldss
      ; info_sum_inductive := sum_ind
      ; info_prod_inductive := prod_ind
      }.

  (* Encoding. With [k] constructors C0 ... C(k-1), a value [Ci x1 ... xn]
     is mapped to
       inl (... (inl (inr ((...((tt, x1), x2)...), xn))) ...)
     with [k - 1 - i] applications of [inl]. The encoding type is thus
       sum (... (sum (sum unit P0) P1) ...) P(k-1)
     where [Pi] is the nested product of the field types of [Ci]. The
     innermost [unit] summand is never produced by the encoding. *)

  (* Branch for constructor [i] of the [match] built by [encode_case]: a
     function of the constructor fields returning the encoding above. *)
  Ltac2 encode_branch info i :=
    let flds := info_fields info i in
    let num_fld := Array.length flds in
    let bdrs :=
      Array.map (fun fld =>
        Constr.Binder.make None fld
      ) flds
    in
    let bdrs := Array.to_list bdrs in
    (* Under the [num_fld] binders, step [k] of the fold pairs the
       accumulator with [Rel (num_fld - k)], so the first field ends up
       innermost (assuming [List.init_foldl] visits k = 0, 1, ...). *)
    let body :=
      List.init_foldl (fun acc k =>
        let fld := Constr.Unsafe.make_rel (Int.sub num_fld k) in
        Constr.Unsafe.make_app
          '@pair
          [|'_; '_; acc; fld|]
      ) 'tt num_fld
    in
    (* Constructor 0 fixes the innermost left summand to [unit]; the other
       left summand types are left as holes for unification. *)
    let body :=
      let ty := if Int.equal i 0 then 'unit else '_ in
      Constr.Unsafe.make_app
        '@inr
        [|ty; '_; body|]
    in
    let body :=
      List.init_foldl (fun acc _ =>
        Constr.Unsafe.make_app
          '@inl
          [|'_; '_; acc|]
      ) body (Int.sub (Int.sub (info.(info_number_constructor)) i) 1)
    in
    Constr.Unsafe.make_lambdas bdrs body.
  (* [match] on [Rel 1], the argument bound by [encode], with one branch per
     constructor and a return type left as a hole. *)
  Ltac2 encode_case info :=
    Constr.Unsafe.make_case_simple
      (info.(info_inductive))
      (info.(info_type))
      '_
      (Constr.Unsafe.make_rel 1)
      ( Array.init
          (info.(info_number_constructor))
          (encode_branch info)
      ).
  (* The encoding function, of type [info_type -> encoding type]. *)
  Ltac2 encode info :=
    Constr.Unsafe.make_lambda
      (Constr.Binder.make None (info.(info_type)))
      (encode_case info).

  (* Split an application to exactly two arguments, [f a b], into [(a, b)].
     The head [f] is not checked. Fails with [Internal_error] otherwise. *)
  Ltac2 extract_arguments_2 ty :=
    match Constr.Unsafe.kind ty with
    | Constr.Unsafe.App _ tys =>
        if Bool.neg (Int.equal (Array.length tys) 2) then
          error Internal_error
        else
          Array.get tys 0, Array.get tys 1
    | _ =>
        error Internal_error
    end.
  (* Decode the payload of constructor [i] when [j] pairs remain to be
     destructed. The pair to destruct is always [Rel 2] and has type
     [ty']. Destructing it binds its left component (recursed on, which
     becomes the new [Rel 2]) and its right component (a field, [Rel 1]).
     Once [j = 0], field [k] (0-based) of constructor [i] is
     [Rel (2 k + 1)], and the result is [Some (Ci params fields)]. *)
  Ltac2 rec decode_branch' info i j ty' :=
    if Int.equal j 0 then
      Constr.Unsafe.make_app
        '@Some
        [|info.(info_type)
        ; Constr.Unsafe.make_app
            (info_constructor_constr info i)
            ( Array.append
                (info.(info_indices))
                ( Array.init (info_arity info i) (fun k =>
                    Constr.Unsafe.make_rel (Int.add (Int.mul 2 k) 1)
                  )
                )
            )
        |]
    else
      let ty := info.(info_type) in
      let (ty'_1, ty'_2) := extract_arguments_2 ty' in
      let ty' := Constr.Unsafe.make_app '@prod [|ty'_1; ty'_2|] in
      Constr.Unsafe.make_case_simple
        (info.(info_prod_inductive))
        ty'
        '(option $ty)
        (Constr.Unsafe.make_rel 2)
        [|Constr.Unsafe.make_lambdas
            [ Constr.Binder.make None ty'_1
            ; Constr.Binder.make None ty'_2
            ]
            (decode_branch' info i (Int.sub j 1) ty'_1)
        |].
  (* Decoder for the [inr] payload of constructor [i], of type [ty']. A
     dummy [let _ : unit := tt] is inserted so that, as [decode_branch']
     expects, the payload to destruct is [Rel 2] from the start. *)
  Ltac2 decode_branch info i ty' :=
    Constr.Unsafe.make_lambda
      (Constr.Binder.make None ty')
      ( Constr.Unsafe.make_let
          (Constr.Binder.make None 'unit)
          'tt
          (decode_branch' info i (info_arity info i) ty')
      ).
  (* Decode [Rel 1], whose sum type [ty'] covers constructors [0 .. i]. The
     [inr] case is constructor [i]; the [inl] case recurses on constructors
     [0 .. i - 1]. When [i = -1], only the [unit] summand is left and the
     result is [None]. *)
  Ltac2 rec decode_case' info i ty' :=
    let ty := info.(info_type) in
    if Int.equal i (-1) then
      '(@None $ty)
    else
      let (ty'_l, ty'_r) := extract_arguments_2 ty' in
      let ty' := Constr.Unsafe.make_app '@sum [|ty'_l; ty'_r|] in
      Constr.Unsafe.make_case_simple
        (info.(info_sum_inductive))
        ty'
        '(option $ty)
        (Constr.Unsafe.make_rel 1)
        [|Constr.Unsafe.make_lambda
            (Constr.Binder.make None ty'_l)
            (decode_case' info (Int.sub i 1) ty'_l)
        ; decode_branch info i ty'_r
        |].
  (* Decode starting from the last constructor. *)
  Ltac2 decode_case info ty' :=
    decode_case'
      info
      (Int.sub (info.(info_number_constructor)) 1)
      ty'.
  (* The decoding function, of type [ty' -> option info_type], where [ty']
     is the encoding type. *)
  Ltac2 decode info ty' :=
    Constr.Unsafe.make_lambda
      (Constr.Binder.make None ty')
      (decode_case info ty').

  (* Build [encode] and [decode] for the type described by [info], then
     refine the goal with [inj_countable] applied to them.
     - The first two goals left by [refine] are discharged by typeclass
       search ([apply _]).
     - The remaining goal, presumably the round-trip obligation of
       [inj_countable], is solved by case analysis on the argument followed
       by [auto].
     NOTE(review): assumes [refine] leaves exactly two instance goals
     first. *)
  Ltac2 main' info :=
    let encode := encode info in
    (* The encoding type is the codomain of [encode]. *)
    let ty' := Constr.product_result (Constr.type encode) in
    let decode := decode info ty' in
    (* Result discarded; presumably computed to check [decode] early. *)
    let _ := Constr.type decode in
    refine (
      Constr.Unsafe.make_app
        '@inj_countable
        [|'_; '_; '_; '_; '_; encode; decode; '_|]
    ) ;
    Control.focus 1 2 (fun () => apply _);
    solve [intros []; auto].
  (* Entry point: solve a goal [Countable ty]. Fails with [Goal_invalid] on
     any other goal. *)
  Ltac2 main () :=
    lazy_match! goal with
    | [|- Countable ?ty] =>
        main' (info ty)
    | [|- _] =>
        error Goal_invalid
    end.
End solve_countable.

(* Ltac1 wrapper around the Ltac2 tactic [solve_countable.main]: solves a
   goal [Countable T] where [T] is an inductive type with at least one
   constructor. *)
Ltac solve_countable := ltac2:(solve_countable.main ()).

Module tests.
  (* Smoke tests: each [Countable] instance below is proved by
     [solve_countable]. They cover records and variants, with and without
     parameters, and constructors with 0 to 3 fields. *)
  Record test_1 :=
    { test_1_1 : nat
    }.
  #[local] Instance test_1_eq_dec : EqDecision test_1 :=
    ltac:(solve_decision).
  #[local] Instance test_1_countable :
    Countable test_1.
  Proof.
    solve_countable.
  Qed.

  Record test_2 A1 A2 A3 :=
    { test_2_1 : A1
    ; test_2_2 : A2
    ; test_2_3 : A3
    }.
  #[local] Instance test_2_eq_dec `{!EqDecision A1, !EqDecision A2, !EqDecision A3} : EqDecision (test_2 A1 A2 A3) :=
    ltac:(solve_decision).
  #[local] Instance test_2_countable `{Countable A1, Countable A2, Countable A3} :
    Countable (test_2 A1 A2 A3).
  Proof.
    solve_countable.
  Qed.

  Variant test_3 :=
    | Test31 : test_3
    | Test32 : nat -> test_3
    | Test33 : nat -> bool -> test_3.
  #[local] Instance test_3_eq_dec : EqDecision test_3 :=
    ltac:(solve_decision).
  #[local] Instance test_3_countable :
    Countable test_3.
  Proof.
    solve_countable.
  Qed.

  Variant test_4 A1 A2 A3 :=
    | Test41 : test_4 A1 A2 A3
    | Test42 : nat -> test_4 A1 A2 A3
    | Test43 : nat -> bool -> test_4 A1 A2 A3
    | Test44 : A1 -> test_4 A1 A2 A3
    | Test45 : A2 -> test_4 A1 A2 A3
    | Test46 : A3 -> test_4 A1 A2 A3
    | Test47 : A1 -> A2 -> test_4 A1 A2 A3
    | Test48 : A1 -> A3 -> test_4 A1 A2 A3
    | Test49 : A2 -> A3 -> test_4 A1 A2 A3
    | Test410 : A1 -> A2 -> A3 -> test_4 A1 A2 A3.
  #[local] Instance test_4_eq_dec `{!EqDecision A1, !EqDecision A2, !EqDecision A3} : EqDecision (test_4 A1 A2 A3) :=
    ltac:(solve_decision).
  #[local] Instance test_4_countable `{Countable A1, Countable A2, Countable A3} :
    Countable (test_4 A1 A2 A3).
  Proof.
    solve_countable.
  Qed.
End tests.