(** * StreamCalculus.conv_sqrt: square roots of streams under the convolution product *)

Require Import MathClasses.interfaces.abstract_algebra.
Require Import StreamCalculus.conv_division.
Require Import StreamCalculus.sde.
Require Import StreamCalculus.special_streams.
Require Import StreamCalculus.streamscommon.
Require Import StreamCalculus.stream_addition.
Require Import StreamCalculus.conv_ring.
Require Import StreamCalculus.stream_equality_up_to.
Require Import StreamCalculus.stream_equality.
Require Import StreamCalculus.stream_equalities.
Require Import StreamCalculus.conv_ring_ring.

Require Import ArithRing Ring.
Set Implicit Arguments.

(** Square root of streams of the form [1 + X * s] with respect to the
    convolution product: [conv_sqrt s] satisfies
    [conv_sqrt s * conv_sqrt s = 1 + X * s] (lemma [sqrt_cor]).  The
    construction multiplies by [inv2], so [T] must provide an [Inverse] of [2]. *)
Section ConvRoot.
  Variable (T: Type).
  Context `{R: Ring T}.

  (** Evidence that [2] is invertible in [T]. *)
  Context `{i2: @Inverse T Ae Amult Aone 2}.
  (** [inv2] is the witness packaged inside [i2], i.e. the element that [i2]
      certifies to be an inverse of [2] (the precise equation is fixed by the
      [Inverse] class, defined elsewhere).  It is passed to [conv_apreinverse]
      as the inverse of the head of [1 + conv_sqrt s]. *)
  Definition inv2:=let (x):=i2 in
                   let (x,_):=x in
                   x.

  (** Tail generator for the corecursive definition of [sqrt_tail].  The input
      stream is passed as a thunk [s : unit -> Stream T], and [sqrt_tail] is the
      stream under construction, applied to that same thunk.  The result is
      [inv2 * derive (s ()) - inv2 * inv2 * s () * sqrt_tail s
       * conv_inverse (inv2 * sqrt_tail s)].
      Judging from [corec_tail] in [sqrt_tail_cor], [corec] uses it as the
      derivative of [sqrt_tail s]. *)
  Definition conv_sqrt_tail_tail (s: unit -> Stream T) (sqrt_tail: (unit -> Stream T) -> Stream T): Stream T:=
    [inv2] * derive (s ()) - [inv2] * [inv2] * (s ()) * sqrt_tail s * conv_inverse ([inv2] * sqrt_tail s).

  (** [conv_sqrt_tail_tail] is causal in both arguments (see [causal2]).  This
      is the side condition that [corec_causal] and [corec_tail] require below. *)
  Lemma conv_sqrt_tail_tail_causal:
    causal2 conv_sqrt_tail_tail.
  Proof.
    intros in1 in2 a1 a2 n Hi Ha.
    unfold conv_sqrt_tail_tail.
    specialize (eqUpTo_inverse_tail (Ha ())). intros.
    rewrite H.
    (* Both conv_inverse factors agree up to n. *)
    assert (StreamEqualUpTo n (conv_inverse ([inv2] * in1 a1)) (conv_inverse ([inv2] * in2 a2))).
    { apply eqUpTo_back. rewrite Hi.
      - reflexivity.
      - intros. now apply eqUpTo_back.
    }
    rewrite H0.
    rewrite Hi.
    - rewrite (eqUpTo_back (Ha ())).
      reflexivity.
    - intros. now apply eqUpTo_back.
  Qed.

  (** [sqrt_tail s] is the derivative of [conv_sqrt s].  It is defined by
      corecursion: its head is [s 0 * inv2], and its derivative is produced by
      [conv_sqrt_tail_tail] applied to the constant thunk [fun _ => s]. *)
  Definition sqrt_tail s:=corec (fun inp=>inp () * inv2) conv_sqrt_tail_tail (fun a=>s).

  (** Unfolding lemma for [sqrt_tail].  The right-hand side is [sqrt_tail s]
      with [conv_sqrt_tail_tail] inlined.  This lets [sqrt_tail_cor] fold the
      term back with [rewrite <- sqrt_tail_rewrite] after
      [unfold conv_sqrt_tail_tail].  The binders are typed with [unit], as in
      [conv_sqrt_tail_tail]. *)
  Lemma sqrt_tail_rewrite s:
    (sqrt_tail s) =
      (corec (fun (inp : unit -> T) => inp () * inv2)
             (fun (s0 : unit -> Stream T)
                  (rec : (unit -> Stream T) -> Stream T) =>
                [inv2] * derive (s0 ()) -
                [inv2] * [inv2] * s0 () * rec s0 *
                conv_inverse ([inv2] * rec s0))
             (fun _ : unit => s)).
  Proof.
    reflexivity.
  Qed.

  (** Square root of [1 + X * s]: head [1], derivative [sqrt_tail s]. *)
  Definition conv_sqrt s:=1 + X * sqrt_tail s.

  (** Derivative of [1 + conv_sqrt s]: both constant [1] summands vanish under
      [derive], leaving [sqrt_tail s]. *)
  Lemma sqrt_derive_sqrt_tail s:
    derive(1 + conv_sqrt s) = sqrt_tail s.
  Proof.
    rewrite addition_derive.
    unfold one. rewrite scalar_derive_null, left_identity.
    unfold conv_sqrt.
    rewrite addition_derive.
    unfold one. rewrite scalar_derive_null, left_identity.
    now rewrite derive_X_id.
  Qed.

  (** [conv_sqrt] maps streams equal up to [n] to streams equal up to [n], so
      setoid rewriting with [StreamEqualUpTo n] can pass through it.
      NOTE(review): [n] is not bound in the statement.  The proof's leading
      [intros] and its [destruct n] rely on [n] being implicitly generalized,
      presumably through MathClasses' generalizable-variables setting; confirm. *)
  Global Instance sqrt_proper: Proper (StreamEqualUpTo n ==> StreamEqualUpTo n) conv_sqrt.
  Proof.
    intros.
    intros s1 s2 H.
    unfold conv_sqrt, sqrt_tail.
    destruct n; constructor.
    - repeat rewrite addition_derive_up_to.
      repeat rewrite derive_X_upto.
      rewrite corec_causal.
      + reflexivity.
      + now destruct R, ring_group, abgroup_group, group_monoid, monoid_semigroup.
      + cbn. intros.
        now apply eqUpTo_back.
      + intros in1 in2 H'.
        now rewrite H'.
      + exact conv_sqrt_tail_tail_causal.
    - repeat rewrite addition_pointwise.
      cbn. now repeat rewrite left_absorb.
  Qed.

  (** Derivative of [conv_sqrt s]: it is [s] times
      [conv_apreinverse (1 + conv_sqrt s) inv2], where [inv2] inverts the head
      [1 + 1] of [1 + conv_sqrt s]. *)
  Lemma sqrt_tail_cor s:
    (derive (conv_sqrt s)) = (s * conv_apreinverse (1 + conv_sqrt s) inv2).
  Proof.
    unfold conv_sqrt at 1, sqrt_tail at 1.
    rewrite addition_derive.
    unfold one at 1. rewrite scalar_derive_null, left_identity.
    rewrite derive_X_id.
    rewrite (commutativity s).
    (* Compare heads first, then derivatives. *)
    apply recursive_equality. split.
    - cbn. unfold conv_inverse.
      rewrite addition_pointwise. cbn.
      now rewrite left_absorb, right_identity, left_identity, commutativity.
    - rewrite corec_tail.
      + rewrite cleaner_mult_derive.
        cbn. unfold conv_inverse. rewrite addition_pointwise. cbn.
        assert ([(1 + 0 * - (derive (1 + conv_sqrt s) 0%nat * inv2)) * inv2] = [inv2]).
        { now rewrite left_absorb, right_identity, left_identity. }
        rewrite H.
        unfold conv_apreinverse.
        rewrite cleaner_mult_derive.
        rewrite scalar_derive_null, right_absorb, right_identity.
        rewrite conv_inverse_tail_cor.
        unfold conv_sqrt_tail_tail.
        (* Fold the inlined generator back into sqrt_tail s. *)
        rewrite <- sqrt_tail_rewrite.
        rewrite sqrt_derive_sqrt_tail.
        rewrite commutativity.
        rewrite minus_eq_minusOne at 1.
        rewrite (minus_eq_minusOne (sqrt_tail s * [inv2])).
        rewrite associativity.
        rewrite (commutativity ([inv2] * [inv2] * s)).
        rewrite associativity.
        rewrite <- (commutativity (conv_inverse ([inv2] * sqrt_tail s))).
        rewrite associativity.
        rewrite associativity.
        rewrite associativity.
        rewrite associativity.
        rewrite associativity.
        rewrite (commutativity (-1 * sqrt_tail s * [inv2])).
        repeat rewrite associativity.
        now rewrite (commutativity ([inv2])).
        exact R.
        exact R.
      + now destruct R, ring_group, abgroup_group, group_monoid, monoid_semigroup.
      + intros in1 in2 H.
        now rewrite H.
      + exact conv_sqrt_tail_tail_causal.
    Qed.

  (** Main result: [conv_sqrt s] is a square root of [1 + X * s]. *)
  Lemma sqrt_cor s:
    (conv_sqrt s) * (conv_sqrt s) = (1 + X * s).
  Proof.
    (* Compare heads first, then derivatives. *)
    apply recursive_equality. split.
    - unfold conv_sqrt.
      rewrite addition_pointwise. cbn.
      rewrite addition_pointwise. cbn.
      repeat rewrite left_absorb.
      now repeat rewrite right_identity.
    - rewrite addition_derive, derive_X_id.
      unfold one, streamOne.
      rewrite scalar_derive_null.
      rewrite left_identity.
      rewrite cleaner_mult_derive.
      rewrite sqrt_tail_cor.
      unfold conv_sqrt at 3.
      rewrite addition_pointwise. cbn.
      assert ([1 + 0 * (s 0%nat * inv2)] = 1).
      {now rewrite left_absorb, right_identity. }
      rewrite H, left_identity.
      assert (s * conv_apreinverse (1 + conv_sqrt s) inv2 = s * conv_apreinverse (1 + conv_sqrt s) inv2 * 1).
      { now rewrite right_identity. }
      rewrite H0 at 2.
      rewrite <- (distribute_l (s * conv_apreinverse (1 + conv_sqrt s) inv2)).
      rewrite (commutativity (conv_sqrt s)).
      rewrite <- associativity.
      rewrite (commutativity (conv_apreinverse (1 + conv_sqrt s) inv2)).
      rewrite conv_apreinverse_cor.
      * now rewrite right_identity.
      * exact R.
      (* Side condition: inv2 inverts the head of 1 + conv_sqrt s. *)
      * unfold conv_sqrt.
        repeat rewrite addition_pointwise. cbn.
        rewrite left_absorb, right_identity.
        unfold inv2.
        now destruct i2, ex_inv.
    Qed.
End ConvRoot.

(** Uniqueness up to sign: when [T] has no zero divisors, every stream [r] with
    [r * r = 1 + X * q] equals [conv_sqrt q] or [- conv_sqrt q]. *)
Section SqrtUniqueness.
  Variable (T: Type).
  Context `{R: Ring T}.
  Context `{i2: @Inverse T Ae Amult Aone 2}.

  (** [T] has no zero divisors. *)
  Context `{null_division_free: forall a b, a * b = 0 -> a = 0 \/ b = 0}.

  (** In a ring without zero divisors, [1] and [-1] are the only square roots
      of [1].  The proof factors [s * s - 1] as [(s - 1) * (s + 1)]. *)
  Lemma sqrt1 s:
    s * s = 1 -> s = 1 \/ s = -1.
  Proof.
    intros.
    assert (0 = -1 * s + 1 * s).
      { rewrite <- distribute_r, negate_l, left_absorb.
       - reflexivity.
       - now destruct R, ring_group. }
    (* The factorisation is zero because s * s = 1. *)
    assert ((s-1)*(s+1) = 0).
    { rewrite distribute_l, distribute_r, right_identity, H.
      rewrite associativity.
      rewrite commutativity.
      repeat rewrite associativity.
      rewrite negate_l, left_identity.
      now rewrite H0, left_identity.
      now destruct R, ring_group.
    }
    assert (1+0=1). { now rewrite right_identity. }
    assert (-1+0= -1). { now rewrite right_identity. }
    (* One factor is zero; each case yields one disjunct. *)
    destruct (null_division_free H1);
    [ left; rewrite <- H2| right; rewrite <- H3]
    ;rewrite <- H4
    ;rewrite commutativity, <- associativity
    ;[rewrite negate_l | rewrite negate_r]
    ;[now rewrite right_identity | now destruct R, ring_group| now rewrite right_identity |now destruct R, ring_group ].
  Qed.

  (** Stream equality implies equality up to every bound [n]. *)
  Lemma eqUpTo_fromPointwise' s1 s2:
    s1 = s2 -> (forall n, StreamEqualUpTo n s1 s2).
  Proof.
    apply streamEqualUpTo_streamEquality.
  Qed.

  (** Multiplying [- streamOne] by [-1] gives [1].  Proved with
      [minus_eq_minusOne] and [double_negation]. *)
  Lemma minusOne_squared':
    -1 * -streamOne = 1.
  Proof.
    specialize (minus_eq_minusOne (- streamOne)). intros.
    rewrite <- H.
    rewrite double_negation. reflexivity.
  Qed.

  (** Any [r] with [r * r = 1 + X * q] agrees, up to every [n], either with
      [conv_sqrt q] or with [- conv_sqrt q].  The head [r 0] squares to [1], so
      by [sqrt1] it is [1] or [-1].  Each case is then an induction on [n] that
      compares derivatives using [sqrt_tail_cor]. *)
  Lemma sqrt_uniqueness_up_to q r:
    (r * r = 1 + X * q) -> (forall n, (StreamEqualUpTo n r (conv_sqrt q))) \/ (forall n, StreamEqualUpTo n r (-conv_sqrt q)).
  Proof.
    intros.
    (* The head of r * r is the head of 1 + X * q, namely 1. *)
    assert ( (r * r) 0%nat = 1).
    { rewrite (H 0%nat).
      rewrite addition_pointwise. cbn.
      now rewrite left_absorb, right_identity.
    }
    assert (r 0%nat * r 0%nat = 1).
    { rewrite <- H0.
      now cbn.
    }
    destruct (sqrt1 H1).
    (* Case r 0 = 1: r agrees with conv_sqrt q. *)
    - left.
      induction n; constructor.
      + enough (StreamEqualUpTo n (derive (1 + X * q)) (derive r * (1 + r))).
        { rewrite (eqUpTo_fromPointwise' (sqrt_tail_cor q)).
          rewrite addition_derive_up_to in H3.
          unfold one at 1 in H3.
          rewrite scalar_derive_upto, plus_left_absorb in H3.
          rewrite derive_X_upto in H3.
          rewrite H3 at 1.
          rewrite IHn at 3.
          apply streamEqualUpTo_streamEquality.
          rewrite <- associativity.
          rewrite (conv_apreinverse_cor (1+ conv_sqrt q)), right_identity.
          * reflexivity.
          * unfold conv_sqrt.
            repeat rewrite addition_pointwise. cbn.
            rewrite left_absorb, right_identity.
            unfold inv2. now destruct i2, ex_inv.
          * exact R.
        }
        apply streamEqualUpTo_streamEquality.
        rewrite <- H.
        rewrite cleaner_mult_derive.
        assert ([r 0%nat] = 1). { now rewrite H2. }
        now rewrite H3, (commutativity 1), <- distribute_l, (commutativity 1).
      + unfold conv_sqrt.
        rewrite addition_pointwise. cbn.
        now rewrite H2, left_absorb, right_identity.
    (* Case r 0 = -1: r agrees with - conv_sqrt q. *)
    - right.
      induction n; constructor.
      + enough (StreamEqualUpTo n (derive (1 + X * q)) (- derive r * (1 - r))).
        { rewrite derive_minus_upto.
          rewrite (eqUpTo_fromPointwise' (sqrt_tail_cor q)).
          rewrite addition_derive_up_to in H3.
          unfold one at 1 in H3.
          rewrite scalar_derive_upto, plus_left_absorb in H3.
          rewrite derive_X_upto in H3.
          rewrite H3 at 1.
          rewrite IHn at 3.
          apply streamEqualUpTo_streamEquality.
          rewrite <- associativity.
          rewrite double_negation.
          rewrite (conv_apreinverse_cor (1+ conv_sqrt q)), right_identity.
          * now rewrite double_negation.
          * unfold conv_sqrt.
            repeat rewrite addition_pointwise. cbn.
            rewrite left_absorb, right_identity.
            unfold inv2. now destruct i2, ex_inv.
          * exact R.
        }
        apply streamEqualUpTo_streamEquality.
        rewrite <- H.
        rewrite cleaner_mult_derive.
        (* The constant stream of the head is - 1. *)
        assert ([r 0%nat] = - 1).
        { rewrite H2.
          apply recursive_equality. split.
          - unfold one, streamOne, negate, streamNegate. now cbn.
          - rewrite <- minus_derive.
            unfold one.
            repeat rewrite scalar_derive_null.
            intros n'.
            unfold negate, streamNegate, zero, streamZero, nullStream.
            specialize minus_0. intros.
            unfold zero in H3.
            rewrite H3. reflexivity.
            exact R.
        }
        rewrite H3, (commutativity (derive r)), <- distribute_r, <- (commutativity (derive r)).
        assert ((1 - r) = - 1 * (r - 1)).
        { rewrite minus_eq_minusOne.
          rewrite <- minusOne_squared' at 1.
          rewrite distribute_l.
          rewrite commutativity.
          reflexivity. exact R. }
        rewrite H4.
        rewrite associativity.
        rewrite (commutativity (-derive r)).
        rewrite <- minus_eq_minusOne.
        now rewrite double_negation. exact R.
      + unfold conv_sqrt.
        unfold negate, streamNegate.
        rewrite addition_pointwise. cbn.
        now rewrite H2, left_absorb, right_identity.
  Qed.

  (** Uniqueness of the square root up to sign, stated with stream equality.
      It follows from [sqrt_uniqueness_up_to]. *)
  Lemma sqrt_uniqueness q r:
    (r * r = 1 + X * q) -> r = (conv_sqrt q) \/ r = (-conv_sqrt q).
  Proof.
    intros H.
    destruct (sqrt_uniqueness_up_to q r H);
    [left | right];
    now apply streamEqualUpTo_streamEquality.
  Qed.
End SqrtUniqueness.