From Coq Require Import Lia List Init.Nat.
From PCF2.Autosubst Require Import pcf_2.
From PCF2 Require Import CE SATIS pcf_2_system pcf_2_utils preliminaries.
From PCF2.external Require Import SR.
From PCF2.external.Synthetic Require Import Definitions DecidabilityFacts.
Set Default Goal Selector "!".

(** [T a] coincides with [list_to_fun] applied to a base context of length
    [2 * length a + 2] and the result type [Base].
    Proved by induction on the word [a]: in the cons case the context length
    is rewritten as two successors of the shorter one, after which the
    definitions are unfolded and [iter_succ_l] is applied four times. *)
Lemma T_list_to_fun a:
    T a = list_to_fun (base_context (2 * length a + 2)) Base.
Proof.
    induction a as [| a' a IH].
    1: easy. unfold T. assert (H: 2 * length (a' :: a) + 2 = S( S(2 * length a + 2))) by now cbn; lia.
    rewrite H. unfold base_fun, base_context. rewrite iter_succ_l, iter_succ_l, iter_succ_l, iter_succ_l.
    (* NOTE(review): [fun T0 : ty => Base T0] applies [Base] to [T0], while
       [Base] is used as a plain type elsewhere in this lemma; a notation
       symbol (perhaps an arrow between [Base] and [T0]) may have been lost
       here. Confirm against the upstream file. *)
    assert (eq: iter (2 * length a + 2) (fun T0 : ty => Base T0) Base = T a) by easy.
    now rewrite eq, IH.
Qed.

(** [make_rule_types] preserves list length: after unfolding, it is a
    [map] over [R], so the statement is an instance of [map_length]. *)
Lemma make_rule_types_length R:
    length (make_rule_types R) = length R.
Proof.
    unfold make_rule_types; apply map_length.
Qed.

(** For an in-bounds index [n], [R] has some rule [r] at position [n], and
    [make_rule_types R] has [rule_type r] at that same position. *)
Lemma rule_some n R:
    n < length R -> exists r, nth_error R n = Some r /\ nth_error (make_rule_types R) n = Some (rule_type r).
Proof.
    (* Turn the bound into [nth_error R n <> None]; the [None] case of the
       destruct is then contradictory. *)
    intros H. rewrite <- nth_error_Some in H. destruct (nth_error R n) eqn: e.
    2: easy. exists p. split. 1: easy. erewrite nth_error_nth' with (d := rule_type (nil, nil)).
    (* First goal: compute the element through [nth] with a dummy default.
       Second goal: the side condition on [make_rule_types R] is reduced to
       the bound on [R] via [make_rule_types_length]. *)
    - unfold make_rule_types. rewrite map_nth. f_equal. f_equal.
        erewrite nth_error_nth' with (d := (nil, nil) : str* str) in e.
        1: now injection e.
        rewrite <- nth_error_Some. intros H1. now rewrite H1 in e.
    - rewrite make_rule_types_length. rewrite <- nth_error_Some. intros H1. now rewrite H1 in e.
Qed.

(** If [enc] is a word encoding (for some [v]), then [enc a] is typed at
    [T a] in the empty context, for every word [a]. The fact is obtained by
    destructing [word_encoding enc v] instantiated at [a].
    NOTE(review): the statement [nil enc a : T a] looks like a typing
    judgement whose notation symbol (perhaps a turnstile between [nil] and
    [enc a]) was lost; confirm against the upstream file. *)
Lemma enc_typed enc v a:
    word_encoding enc v -> nil enc a : T a.
Proof.
    intros enc_val. now destruct (enc_val a).
Qed.

(** Any two terms [F1] and [F2] that both satisfy [rule_encoding enc r]
    are related at type [rule_type r] in the empty context. The goal splits
    into two parts; each one follows from one encoding hypothesis applied to
    a component of the other.
    NOTE(review): the statement [nil F1 F2: rule_type r] seems to have lost
    notation symbols (presumably a judgement relating [F1] and [F2]);
    confirm against the upstream file. *)
Lemma rule_enc_equiv (enc: str -> tm) (r: str * str) (F1 F2: tm):
    rule_encoding enc r F1 -> rule_encoding enc r F2 -> nil F1 F2: rule_type r.
Proof.
    intros H1 H2. split.
    - apply H1; now destruct H2.
    - apply H2; now destruct H1.
Qed.

(** Given that [enc] is a word encoding and [f] encodes every rule of [R]
    (in the sense of [fun_rule_encoding]), every term of
    [map f R ++ enc a :: nil] is closed. In both cases closedness comes from
    a typing in the empty context via [typed_empty_closed]. *)
Lemma rule_enc_word_enc_closed enc f R a v:
    word_encoding enc v -> fun_rule_encoding enc R f -> Forall (fun t0 : tm => closed t0) (map f R ++ enc a :: nil).
Proof.
    intros enc_val f_enc.
    rewrite Forall_forall.
    intros t Ht. specialize (in_app_or _ _ _ Ht).
    intros [H1 | H1].
    (* [t] comes from [map f R]: locate it by index, obtain the rule [r]
       at that index, and use that [f r] encodes [r]. *)
    - destruct (In_nth_error _ _ H1) as [n H2].
        assert (H3: n < length R).
        { rewrite <- nth_error_Some. intros H3. rewrite nth_error_map in H2. now rewrite H3 in H2. }
        destruct (rule_some _ _ H3) as [r [Hr1 Hr2]]. rewrite nth_error_map, Hr1 in H2.
        injection H2; intros ?; subst. enough (H4: rule_encoding enc r (f r)).
        { eapply typed_empty_closed. destruct H4; eauto. }
        unfold fun_rule_encoding in f_enc. rewrite Forall_forall in f_enc. apply f_enc.
        eapply nth_error_In. eauto.
    (* [t] is [enc a]: it is typed in the empty context by [enc_typed]. *)
    - firstorder; subst. eapply typed_empty_closed. eapply enc_typed. eauto.
Qed.

(** Variant of [rule_enc_word_enc_closed] for [enc a :: map f (rev R)],
    which is the reverse of [map f R ++ enc a :: nil]. Closedness of every
    element is transported along [rev] by [Forall_rev]. *)
Lemma rule_enc_word_enc_closed' enc f R a v:
     word_encoding enc v -> fun_rule_encoding enc R f -> Forall (fun t : tm => closed t) (enc a :: map f (rev R)).
Proof.
    (* Rewrite the list as the reverse of the one handled by
       [rule_enc_word_enc_closed]. *)
    intros enc_v enc_f. assert (H: enc a :: map f (rev R) = rev (map f R ++ enc a :: nil)).
    1: simpl_list; now rewrite map_rev.
    rewrite H. apply Forall_rev. now eapply rule_enc_word_enc_closed.
Qed.

(** For each index [n] whose type in
    [base_context (2 * length b + 2) ++ make_rule_types R ++ T a :: nil] is
    [A], the term [t_subst enc R f a b n] has type [A] in
    [base_context (2 * length b + 2)]. The proof follows the range tests in
    [t_subst]:
    - indices inside the base context are typed directly;
    - rule indices are typed through the rule encodings [f r];
    - the last index is typed by [enc_typed];
    - indices past the end contradict the hypothesis on [nth_error].
    NOTE(review): the statement appears to have lost the notation symbols of
    its typing judgement; confirm against the upstream file.
    Hypotheses introduced with [assert] are named explicitly so the proof
    does not depend on Coq's automatically generated names. *)
Lemma t_subst_well_typed enc R f a b v:
    word_encoding enc v-> fun_rule_encoding enc R f -> base_context (2 * length b + 2) (t_subst enc R f a b) [(base_context (2*(length b) + 2)) ++ ((make_rule_types R)) ++ (cons (T a) nil)].
Proof.
    intros enc_val enc_f n A H. unfold t_subst.
    destruct (leb n (2 * length b + 1)) eqn: e1.
    (* Variable range: the index lies in the base context itself. *)
    - rewrite leb_le_true in e1. rewrite nth_error_app1 in H.
        + now constructor.
        + rewrite length_base_context. lia.
    - assert (H1: base_context (2 * length b + 2) = nil ++ base_context (2 * length b + 2)) by easy.
        rewrite leb_le_false in e1. destruct (leb n (2 * length b + 1 + length R)) eqn: e2.
        (* Rule range: [t_subst] returns [f p] for the rule [p] at offset
           [n - (2 * length b + 2)]; it is typed in the empty context and
           then weakened. *)
        + rewrite leb_le_true in e2. destruct (nth_error R (n - (2 * length b + 2))) eqn: e3.
        * rewrite H1. apply type_weakening with (Γ := nil).
            rewrite nth_error_app2 in H.
            2: rewrite length_base_context; lia.
            destruct (arith_technical _ _ _ e1 e2) as [n' [eq1 eq2]].
            rewrite length_base_context in H.
            rewrite nth_error_app1 in H. 2: rewrite make_rule_types_length; lia.
            rewrite eq2 in H. assert (eq3: (2 * length b + 1 + 1 + n' - (2 * length b + 2)) = n') by lia.
            rewrite eq3 in H. destruct (rule_some n' R eq1) as [r [Hr1 Hr2]].
            rewrite Hr2 in H. injection H. intros ?; subst.
            rewrite eq3 in e3. rewrite Hr1 in e3. injection e3. intros ?; subst.
            unfold fun_rule_encoding in enc_f. rewrite Forall_forall in enc_f. specialize (enc_f p).
            assert (H3: In p R).
            { eapply nth_error_In; now eauto. }
            firstorder.
        (* Offset past the end of [R]: contradicts the range test [e2]. *)
        * enough (length R <= n - (2 * length b + 2)) by lia.
            now rewrite <- nth_error_None.
        + rewrite leb_le_false in e2. destruct (leb n (2 * length b + 2 + length R)) eqn: e3.
        (* Last index of the context: its type is [T a], matched by [enc a]. *)
        * rewrite leb_le_true in e3.
            rewrite nth_error_app2, nth_error_app2 in H.
            2: rewrite length_base_context, make_rule_types_length; lia.
            2: rewrite length_base_context; lia.
            assert (eq1: n = 2 * length b + 2 + length R) by lia.
            assert (eq2: 2 * length b + 2 + length R - (2 * length b + 2) - length R = 0) by lia.
            rewrite eq1, length_base_context, make_rule_types_length, eq2 in H. cbn in H. assert (HA: A = T a).
            { cbn in *. now injection H. } subst.
            rewrite H1. apply type_weakening. eapply enc_typed; eauto.
        (* Past the end of the context: [H] cannot hold, since [n] would be
           out of bounds. *)
        * rewrite leb_le_false in e3.
            assert (Hlt: n < length (base_context (2 * length b + 2) ++ make_rule_types R ++ T a :: nil)).
            { rewrite <- nth_error_Some. intros H2. now rewrite H2 in H. }
            rewrite app_length, app_length, length_base_context, make_rule_types_length in Hlt. cbn in Hlt.
            lia.
Qed.

(** [t_subst enc R f a b] acts as the identity on the variables
    [var_seq 0 (2 * length b + 2)]: substituting into each of them yields
    the same variable. The index-bound assertion is named explicitly
    instead of relying on Coq's automatically generated name. *)
Lemma t_subst_map_var_id (enc: str -> tm) (R: list (str * str)) (f: str * str -> tm) (a b: str):
    map (Subst_tm (t_subst enc R f a b)) (var_seq 0 (2 * length b + 2)) = (var_seq 0 (2 * length b + 2)).
Proof.
    apply list_eq.
    1: now simpl_list.
    intros n s t Hs Ht.
    (* Any index at which [var_seq] has an element is in bounds. *)
    assert (Hn: n < length (var_seq 0 (2 * length b + 2))).
    { apply nth_error_Some. now destruct (nth_error (var_seq 0 (2 * length b + 2)) n). }
    specialize (var_seq_map_nth n 0 _ (t_subst enc R f a b) Hn). intros eq. rewrite length_var_seq in eq. rewrite eq in Hs.
    injection Hs. intros eq'. subst. rewrite length_var_seq in Hn. unfold t_subst.
    (* In bounds, the test [leb n (2 * length b + 1)] succeeds; the other
       case contradicts [Hn]. *)
    destruct (leb n (2 * length b + 1)) eqn: e.
    - rewrite var_seq_nth in Ht. 2: easy.
        now injection Ht.
    - rewrite leb_le_false in e. lia.
Qed.

(** Reversed form of [t_subst_map_var_id]: [t_subst enc R f a b] is also
    the identity on [rev (var_seq 0 (2 * length b + 2))]. After pushing
    [map] under [rev], the goal is [t_subst_map_var_id] with [rev] applied
    to both sides. *)
Lemma t_subst_map_var_id' (enc: str -> tm) (R: list (str * str)) (f: str * str -> tm) (a b: str):
    map (Subst_tm (t_subst enc R f a b)) (rev (var_seq 0 (2 * length b + 2))) = rev (var_seq 0 (2 * length b + 2)).
Proof.
    rewrite map_rev.
    exact (f_equal (@rev _) (t_subst_map_var_id enc R f a b)).
Qed.

(** [t_subst enc R f a b] maps the [length R + 1] variables starting at
    index [2 * length b + 2] to the rule encodings [map f R] followed by
    [enc a]. Hypotheses introduced with [assert] are named explicitly so
    the proof does not depend on Coq's automatically generated names. *)
Lemma t_subst_map_var_rules (enc: str -> tm) (R: list (str * str)) (f: str * str -> tm) (a b: str):
  map (Subst_tm (t_subst enc R f a b)) (var_seq (2 * length b + 2) (length R + 1)) = ((map f R) ++ (enc a :: nil)).
Proof.
  apply list_eq.
  - simpl_list. now rewrite length_var_seq.
  - intros n s t Hs Ht. rewrite nth_error_map in Hs.
    (* Any index at which the variable list has an element is in bounds. *)
    assert (Hn: n < length (var_seq (2 * length b + 2) (length R + 1))).
    {apply nth_error_Some. now destruct (nth_error (var_seq (2 * length b + 2) (length R + 1)) n). }
    rewrite length_var_seq in Hn.
    rewrite (var_seq_nth _ _ _ Hn) in Hs. cbn in Hs. assert (eq: length b + (length b + 0) + 2 = 2 *length b + 2) by lia. injection Hs. intros eq'. rewrite eq in eq'. rewrite <- eq'.
    (* Follow the range tests performed by [t_subst]; the variable range
       is impossible for these indices. *)
    unfold t_subst. destruct (leb (2 * length b + 2 + n) (2 * length b + 1)) eqn: e.
    1: rewrite leb_le_true in e; lia.
    destruct (leb (2*length b + 2 + n) (2 * length b + 1 + length R)) eqn: e'.
    (* Rule range: index [n] selects the [n]-th rule of [R]. *)
    + assert (eq'': 2 * length b + 2 + n - (2 * length b + 2) = n) by lia. rewrite eq''.
      rewrite leb_le_true in e'. destruct (nth_error R n) eqn: e''.
      * rewrite nth_error_app1 in Ht. 2: simpl_list; lia.
        rewrite nth_error_map in Ht. rewrite e'' in Ht. cbn in Ht. now injection Ht.
      * rewrite nth_error_None in e''. lia.
    (* Last slot: index [n = length R] selects [enc a]. *)
    + destruct (leb (2 * length b + 2 + n) (2 * length b + 2 + length R)) eqn: e''.
      * rewrite leb_le_true in e''. rewrite leb_le_false in e'. assert (HnR: n = length R) by lia.
        rewrite nth_error_app2 in Ht. 2: simpl_list; lia.
        destruct (n - length (map f R)) eqn: e'''.
        -- now injection Ht.
        -- rewrite map_length in e'''. lia.
      * rewrite leb_le_false in e''. lia.
Qed.

(** Reversed form of [t_subst_map_var_rules]: substituting into the
    reversed variable block yields [enc a] followed by [map f (rev R)]. *)
Lemma t_subst_map_var_rules' (enc: str -> tm) (R: list (str * str)) (f: str * str -> tm) (a b: str):
  map (Subst_tm (t_subst enc R f a b)) (rev (var_seq (2 * length b + 2) (length R + 1))) = (enc a :: (map f (rev R))).
Proof.
  (* Push [map] under [rev], apply the forward lemma, normalise with
     [simpl_list], and close the remaining difference by rewriting with
     [map_rev] right-to-left. *)
  rewrite map_rev, t_subst_map_var_rules. simpl_list. now rewrite <- map_rev.
Qed.

(** Index [2 * length W + length R + 2] is sent by [t_subst] to [enc W0]:
    it fails the first two range tests inside [t_subst] and passes the
    third one. *)
Lemma t_subst_initial enc R f W0 W:
  (t_subst enc R f W0 W)(2 * length W + length R + 2) = enc W0.
Proof.
  unfold t_subst.
  (* Not in the range of the first test. *)
  destruct (leb (2 * length W + length R + 2) (2 * length W + 1)) eqn: Hvar;
    [rewrite leb_le_true in Hvar; lia |].
  (* Not in the range of the second test. *)
  destruct (leb (2 * length W + length R + 2) (2 * length W + 1 + length R)) eqn: Hrule;
    [rewrite leb_le_true in Hrule; lia |].
  (* Exactly within the bound of the third test, whose branch is [enc W0]. *)
  destruct (leb (2 * length W + length R + 2) (2 * length W + 2 + length R)) eqn: Hlast;
    [easy | rewrite leb_le_false in Hlast; lia].
Qed.