(ns src.math)
;(alias 'cl 'clojure.core)

(defn neg-index
  "Resolves a possibly negative index `x` against a size `cnt`.
  A negative `x` has `cnt` added to it (so -1 becomes cnt-1); any other
  value is returned unchanged."
  [cnt x]
  (cond-> x
    (neg? x) (+ cnt)))

(defn swap
  "Returns `s` as a vector in which the elements at positions `i` and `j`
  have traded places. Both positions are first passed through `neg-index`
  with `(count s)`, so negative positions count from the end."
  [i j s]
  (let [fix-idx (partial neg-index (count s))
        a       (fix-idx i)
        b       (fix-idx j)
        items   (vec s)
        at-b    (nth s b)
        at-a    (nth s a)]
    (-> items
        (assoc a at-b)
        (assoc b at-a))))

(defn search-non-zero
  "Finds the first non-zero entry of `mat` at or beyond position `pivot`.

  Columns pivot .. width-2 are scanned from left to right (the last column
  is never inspected); inside each column rows pivot .. height-1 are scanned
  from top to bottom. Returns `[row col]` of the first non-zero entry, or nil
  when every scanned entry is zero.

  Entries are tested with `zero?`, so integer 0, 0.0 and 0N all count as
  zero."
  [mat pivot]
  (let [width  (count (first mat))
        height (count mat)]
    (first (for [j (range pivot (dec width))
                 i (range pivot height)
                 ;; `zero?` rather than `(not= 0 ..)`: `(= 0 0.0)` is false in
                 ;; Clojure, which made a floating-point zero look non-zero.
                 :when (not (zero? (nth (nth mat i) j)))]
             [i j]))))

(defn sweep1
  "Performs one elimination step around the entry at (`pivot`, `pivot`).

  The pivot row is divided by the pivot entry; from every other row the
  pivot row scaled by (row[pivot] / pivot entry) is subtracted, which zeroes
  that row's pivot column. Returns a lazy sequence of lazy rows."
  [mat pivot]
  (let [pivot-row (nth mat pivot)
        pivot-val (nth pivot-row pivot)]
    (map-indexed
      (fn [idx row]
        (if (= idx pivot)
          (map #(/ % pivot-val) row)
          (let [factor (/ (nth row pivot) pivot-val)]
            (map (fn [entry p] (- entry (* p factor)))
                 row pivot-row))))
      mat)))

(defn lin-solve
  "Reduces the augmented matrix `mat` by Gauss-Jordan elimination.

  For pivot = 0, 1, ... the first non-zero entry at or beyond the pivot is
  located with `search-non-zero` in the CURRENT, partially reduced matrix,
  its row is swapped into the pivot position and `sweep1` clears the pivot
  column. The last column is never used as a pivot column (see
  `search-non-zero`).

  Returns the reduced matrix as a vector of row vectors. Elimination stops
  early, returning the matrix reduced so far, when no non-zero entry is left
  or when the first non-zero entry is not in the pivot column."
  [mat]
  (let [width  (count (first mat))
        height (count mat)
        ;; No pivot can exist past either dimension; beyond this bound the
        ;; search would only ever come back empty.
        limit  (min (dec width) height)]
    (loop [pivot 0
           m     (mapv vec mat)]
      (if (>= pivot limit)
        m
        ;; Search the working matrix `m`, not the untouched input: earlier
        ;; sweeps and swaps change which entries are zero.
        (let [[row col] (search-non-zero m pivot)]
          (cond
            (nil? row)       m
            (not= col pivot) m
            ;; Realise each step as vectors so lazy sweeps do not pile up on
            ;; top of each other across iterations.
            :else (recur (inc pivot)
                         (mapv vec (sweep1 (swap pivot row m) pivot)))))))))

(defn m*v
  "Multiplies matrix `m` (a sequence of rows) by vector `v`.
  Yields a lazy sequence with one entry per row: the sum of the element-wise
  products of that row and `v`."
  [m v]
  (for [row m]
    (apply + (map * row v))))

(defn s*m
  "Scales every entry of matrix `m` (a sequence of rows) by the scalar `s`.
  Returns a lazy sequence of lazy rows."
  [s m]
  (for [row m]
    (for [entry row]
      (* s entry))))

(defn m-m
  "Element-wise difference `m0` - `m1` of two matrices given as sequences of
  rows. Rows and entries are paired positionally; the result is a lazy
  sequence of lazy rows."
  [m0 m1]
  (map (partial map -) m0 m1))

(defn i-mat
  "Builds the `size` x `size` identity matrix as a sequence of rows, with 1 on
  the diagonal and 0 elsewhere. A `size` of zero or less yields an empty
  vector."
  [size]
  (if (pos? size)
    (for [r (range size)]
      (for [c (range size)]
        (if (= r c) 1 0)))
    []))

(defn inv-mat
  "Inverts the square matrix `m` by appending an identity matrix of the same
  size to its rows, running `lin-solve` on the result, and keeping only the
  columns that follow the first `(count m)` of each reduced row."
  [m]
  (let [n         (count m)
        augmented (map concat m (i-mat n))]
    (for [row (lin-solve augmented)]
      (drop n row))))

(defn tref
  "Looks up a nested component of a tagged tuple.
  Each index in `idxs` is 0-based over the components, i.e. it skips the tag
  stored in the first slot of every tuple. With no indices the tuple itself
  is returned."
  [tuple & idxs]
  (loop [node tuple
         path idxs]
    (if (seq path)
      (recur (nth node (inc (first path)))
             (rest path))
      node)))

(defn tfassoc
  "Returns a copy of the tagged tuple `x` in which the component addressed by
  `idxs` (0-based indices that skip each tuple's tag, as in `tref`) is
  replaced by `func` applied to it. Every tuple along the path is rebuilt as
  a vector. With empty `idxs` the result is simply `(func x)`."
  [x idxs func]
  (if-let [[i & deeper] (seq idxs)]
    (let [slot (inc i)]
      (assoc (vec x)
             slot
             (tfassoc (nth x slot) deeper func)))
    (func x)))

(defn t+2
  "Adds two values that are either plain numbers or tagged tuples.
  Two non-collections are added with `+`. Two tuples carrying the same tag
  are added component by component (recursively), giving a vector with that
  tag. Tuples whose tags differ give nil."
  [x y]
  (if (or (coll? x) (coll? y))
    (when (= (first x) (first y))
      (into [(first x)]
            (map t+2 (rest x) (rest y))))
    (+ x y)))

(defn t+
  "Sums one or more numbers / tagged tuples from left to right using `t+2`.
  A single argument is returned as is."
  [x & xs]
  (loop [total   x
         pending xs]
    (if (seq pending)
      (recur (t+2 total (first pending))
             (rest pending))
      total)))

(defn t-2
  "Subtracts `y` from `x`, where both are either plain numbers or tagged
  tuples. Two non-collections are subtracted with `-`. Two tuples carrying
  the same tag are subtracted component by component (recursively), giving a
  vector with that tag. Tuples whose tags differ give nil."
  [x y]
  (if (or (coll? x) (coll? y))
    (when (= (first x) (first y))
      (into [(first x)]
            (map t-2 (rest x) (rest y))))
    (- x y)))

(defn t-
  "Negation / subtraction for numbers and tagged tuples.
  With one argument: a number is negated; a tuple is rebuilt as a vector with
  the same tag and every component negated recursively.
  With several arguments: the remaining arguments are subtracted from the
  first, left to right, using `t-2`."
  ([x]
   (if-not (coll? x)
     (- x)
     (into [(first x)]
           (map t- (rest x)))))
  ([x & xs]
   (reduce t-2 x xs)))

(defn t*2
  "Multiplies two values that are either plain numbers or tagged tuples.

  - number * number: ordinary product.
  - tuple `x` * number `y`: every component of `x` is multiplied by `y`; the
    tag of `x` is kept.
  - an up tuple with a down tuple (either order): the components are
    multiplied pairwise and the products summed with `t+`.
  - any other case where `y` is a tuple: `x` is multiplied into every
    component of `y`; the tag of `y` is kept."
  [x y]
  (let [x-tuple? (coll? x)
        y-tuple? (coll? y)]
    (cond
      (not (or x-tuple? y-tuple?))
      (* x y)

      (not y-tuple?)
      (into [(first x)]
            (map #(t*2 % y) (rest x)))

      (and x-tuple?
           (contains? '#{[up down] [down up]} [(first x) (first y)]))
      (apply t+ (map t*2 (rest x) (rest y)))

      :else
      (into [(first y)]
            (map #(t*2 x %) (rest y))))))

; (t* a b c d) --> (t*2 a (t*2 b (t*2 c d))) ; right associative
(defn t*
  "Multiplies one or more numbers / tagged tuples with `t*2`, grouping from
  the right: (t* a b c) is (t*2 a (t*2 b c)). A single argument is returned
  as is."
  [x & xs]
  (let [factors (cons x xs)]
    (reduce (fn [acc left] (t*2 left acc))
            (last factors)
            (reverse (butlast factors)))))

;([][])[]=[]
; ->   []=([][])[]
;(()())[]=()
; ->   []=[[][]]()
;[[][]]()=[]
; ->   ()=(()())[]
;[()()]()=()
; ->   ()=[()()]
(defn tinv
  "Inverse of a number or of a two-level tagged tuple.

  A non-collection is inverted with `(/ x)`. Otherwise `x` is treated as a
  tuple of tuples: the tags are stripped, the remaining rows are transposed,
  inverted with `inv-mat` and transposed back. The result is re-tagged with
  the flipped (up <-> down) tag of the first inner tuple on the outside and
  the flipped tag of `x` on each inner tuple."
  [x]
  (if-not (coll? x)
    (/ x)
    (let [flip      '{up down, down up}
          outer-tag (flip (first (nth x 1)))
          inner-tag (flip (first x))
          transpose (fn [rows] (apply map vector rows))
          inverted  (-> (map rest (rest x))
                        transpose
                        inv-mat
                        transpose)]
      (into [outer-tag]
            (map #(into [inner-tag] %) inverted)))))


;; Declared ^:dynamic so the earmuffed name matches the var's behaviour: it
;; can be rebound with `binding` and no longer triggers the compiler's
;; "not declared dynamic" warning.
(def ^:dynamic *dx*
  "Width of the interval used by `partial-val` for its central differences."
  0.001)

(defn partial-val
  "Numerically differentiates `f` at the tagged tuple `x` with respect to the
  component of `x` addressed by `idxs` (0-based indices as taken by `tref`).

  When the addressed component is a number the result is the central
  difference (f(x + dx/2) - f(x - dx/2)) / dx with dx = `*dx*`, combined with
  `t-` and `t*` so that `f` may return a number or a tuple.

  When the addressed component is itself a tuple the result is a vector with
  one partial derivative per sub-component (computed recursively), tagged
  with the flipped (up <-> down) tag of that component."
  [f idxs x]
  (letfn [(rec [idxs]
            (let [el (apply tref x idxs)]
              (if (coll? el)
                ;; The tag comes from the component being differentiated
                ;; against (`el`), not from the whole state `x`; the two
                ;; differ when a nested tuple has another variance than the
                ;; outer one.
                (vec (cons ('{up down, down up} (first el))
                           (map #(rec (conj idxs %))
                                (range (dec (count el))))))
                (t* (/ 1.0 *dx*)
                    (t- (f (tfassoc x idxs #(+ % (/ *dx* 2))))
                        (f (tfassoc x idxs #(- % (/ *dx* 2)))))))))]
    ;; `conj` above must append, so the index path is kept as a vector.
    (rec (vec idxs))))


; Lagrangean equation
; D(d2 L o G[q]) - d1 L o G[q] = 0
; (D d2 L o G[q])D G[q] = d1 L o G[q]
; d0 d2 L o G[q] + (d1 d2 L o G[q])Dq + (d2 d2 L o G[q])D^2q = d1 L o G[q]
; D^2q =
;  (d2 d2 L o G[q])^(-1) (d1 L o G[q] - d0 d2 L o G[q] - (d1 d2 L o G[q])Dq)
; A = (d2 d2 L)^(-1) (d1 L - d0 d2 L - (d1 d2 L)I2)
(defn L->accer
  "Given a function `L` of a state tuple [up t q qdot], returns a function of
  such a state that evaluates
    (d2 d2 L)^(-1) (d1 L - d0 d2 L - (d1 d2 L) qdot)
  with every derivative taken numerically by `partial-val` and the inverse by
  `tinv`."
  [L]
  (let [d2L (fn [state] (partial-val L [2] state))]
    (fn [x] ; x -> [up t q qdot]
      (let [d2d2L-inv  (tinv (partial-val d2L [2] x))
            d1L        (partial-val L [1] x)
            d0d2L      (partial-val d2L [0] x)
            d1d2L-qdot (t* (partial-val d2L [1] x)
                           (nth x 3))]
        (t* d2d2L-inv
            (t- d1L d0d2L d1d2L-qdot))))))


; dx_i/dt = f_i(x)
; x^i(t+h) = x^i(t) + f^i(x(t+h))h ; implicit Euler method
; x^i(t+h) = x^i(t) + f^i(x(t) + x(t+h) - x(t))h
; x^i(t+h) = x^i(t) + f^i(x(t))h + df^i(x(t))/dx^j * (x^j(t+h) - x^j(t)) * h
; x^i(t+h) - h*df^i(x(t))/dx^j*x^j(t+h)
;  = x^i(t) + f^i(x(t))h - h*df^i(x(t))/dx^j*x^j(t)
;; Declared ^:dynamic so the earmuffed name matches the var's behaviour: it
;; can be rebound with `binding` and no longer triggers the compiler's
;; "not declared dynamic" warning.
(def ^:dynamic *h*
  "Step size h used by `next-iter`."
  0.02)

(defn next-iter
  "Advances the tagged state tuple `xis` by one step of size `*h*` for the
  system described by `phys` (a function from a state tuple to a tuple).

  The derivative of `phys` at `xis` is taken numerically with `partial-val`,
  the linear system
    (I - h * J) x_next = xis + h * phys(xis) - h * (J xis)
  is assembled as an augmented matrix and reduced with `lin-solve`, and its
  last column becomes the components of the result. The returned vector
  carries the same tag as `xis`."
  [phys xis]
  (let [jaco      (partial-val phys [] xis)
        dim       (dec (count xis))
        ;; strip the tags and transpose
        h-jac     (s*m *h*
                       (vec (apply map vector
                                   (map rest (rest jaco)))))
        lhs       (m-m (i-mat dim) h-jac)
        rhs       (rest (t+ xis
                            (t* *h* (phys xis))
                            (t- (t* *h* jaco xis))))
        augmented (map conj (map vec lhs) rhs)
        solved    (lin-solve augmented)
        last-col  (count solved)]
    (into [(first xis)]
          (map #(nth % last-col) solved))))
