(ns src.math)
;(alias 'cl 'clojure.core)

(defn neg-index
  "Python-style index normalization: a negative index x counts back from
  the end of a collection of size cnt; non-negative x is returned as is."
  [cnt x]
  (cond-> x
    (neg? x) (+ cnt)))

(defn swap
  "Returns s as a vector with the elements at positions i and j exchanged.
  Negative positions count from the end (see neg-index)."
  [i j s]
  (let [n  (count s)
        i* (neg-index n i)
        j* (neg-index n j)]
    (-> (vec s)
        (assoc i* (nth s j*)
               j* (nth s i*)))))

(defn search-non-zero
  "Looks for a pivot candidate in mat at or below/right of (pivot, pivot).

  Columns pivot .. width-2 are scanned in order (the last column is the
  augmented right-hand side and is never searched). Within each column,
  rows pivot .. height-1 are scanned. Returns [row col] of the first
  numerically non-zero entry, or nil when there is none.

  Zero is tested with `zero?` rather than `(not= 0 x)`. In Clojure
  `(= 0 0.0)` is false, so the old test accepted a floating 0.0 as a pivot
  and the following division produced Infinity/NaN."
  [mat pivot]
  (let [width  (count (first mat))
        height (count mat)]
    (first (for [j (range pivot (dec width))
                 i (range pivot height)
                 :when (not (zero? (nth (nth mat i) j)))]
             [i j]))))

(defn sweep1
  "One Gauss-Jordan elimination step on column `pivot`.

  The pivot row is divided by its diagonal entry. Every other row has
  (row[pivot] / diagonal) times the unscaled pivot row subtracted from it.
  Returns a lazy seq of lazy row seqs."
  [mat pivot]
  (let [p-row (nth mat pivot)
        diag  (nth p-row pivot)]
    (map-indexed
     (fn [idx row]
       (if (not= idx pivot)
         (let [k (/ (nth row pivot) diag)]
           (map (fn [a b] (- a (* b k))) row p-row))
         (map #(/ % diag) row)))
     mat)))

(defn lin-solve
  "Gauss-Jordan elimination on the augmented matrix `mat`.

  Each row holds the coefficients followed by the right-hand-side value.
  Pivots are taken in order along the diagonal; before each step the rows
  are swapped so that a non-zero entry sits on the diagonal. Returns the
  reduced matrix as a vector of row vectors, so for a solvable square
  system the last column holds the solution.

  Elimination stops early and returns the partially reduced matrix when no
  usable pivot exists for the current column. That happens, for example,
  when the coefficient block is singular."
  [mat]
  (let [width    (count (first mat))
        height   (count mat)
        ;; At most one pivot per unknown column and one per row.
        n-pivots (min (dec width) height)]
    (loop [pivot 0
           m     (mapv vec mat)]
      (if (>= pivot n-pivots)
        m
        ;; Search the CURRENT matrix m: earlier sweeps have changed the
        ;; entries, so the original `mat` no longer reflects which pivots
        ;; are zero.
        (let [[row col] (search-non-zero m pivot)]
          (if (or (nil? row) (not= col pivot))
            m
            ;; Realize each step so lazy sweeps do not nest without bound.
            (recur (inc pivot)
                   (mapv vec (sweep1 (swap pivot row m) pivot)))))))))

(defn m*v
  "Matrix-vector product: a lazy seq holding the dot product of each row
  of m with v."
  [m v]
  (for [row m]
    (apply + (map * row v))))

(defn s*m
  "Multiplies every entry of matrix m by scalar s. Returns lazy seqs."
  [s m]
  (for [row m]
    (for [entry row]
      (* s entry))))

(defn m-m
  "Entry-wise matrix difference m0 - m1. Returns lazy seqs."
  [m0 m1]
  (map (fn [r0 r1] (map - r0 r1))
       m0
       m1))

(defn m*m
  "Matrix product m0 * m1 as a lazy seq of lazy row seqs.

  Entry (i, k) is the dot product of row i of m0 with column k of m1."
  [m0 m1]
  ;; Transpose m1 once. The previous version rebuilt this transpose for
  ;; every row of m0.
  (let [cols (apply map vector m1)]
    (map (fn [row]
           (map (fn [col] (apply + (map * row col)))
                cols))
         m0)))

(defn i-mat
  "size x size identity matrix with 1 on the diagonal and 0 elsewhere.
  Returns [] when size <= 0."
  [size]
  (mapv (fn [r]
          (mapv (fn [c] (if (= r c) 1 0))
                (range size)))
        (range size)))

(defn inv-mat
  "Inverse of the square matrix m.

  Row-reduces the augmented matrix [m | I] with lin-solve and keeps the
  right-hand n columns. If elimination stops early (no usable pivot), the
  result is not a true inverse."
  [m]
  (let [n (count m)]
    (->> (i-mat n)
         (map concat m)
         lin-solve
         (map #(drop n %)))))

(defn tref
  "Nested positional lookup: (tref t i j k) is (nth (nth (nth t i) j) k).
  With no indices, returns tuple itself."
  [tuple & idxs]
  (reduce nth tuple idxs))

(defn tfassoc
  "Returns x with (func old) substituted at the nested index path idxs.
  Every level along the path is rebuilt as a vector. With an empty path,
  returns (func x)."
  [x idxs func]
  (if-let [[k & more] (seq idxs)]
    (let [level (vec x)]
      (assoc level k (tfassoc (nth x k) more func)))
    (func x)))

(defn t+2
  "Element-wise sum of two equally shaped nested tensors, or of two
  scalars. Tensor results are vectors."
  [x y]
  (if (or (coll? x) (coll? y))
    (mapv t+2 x y)
    (+ x y)))

(defn t+
  "Element-wise sum of one or more nested tensors, folded left with t+2.
  A single argument is returned unchanged."
  [x & xs]
  (loop [acc  x
         more xs]
    (if (seq more)
      (recur (t+2 acc (first more)) (rest more))
      acc)))

(defn t-2
  "Element-wise difference x - y of two equally shaped nested tensors, or
  of two scalars. Tensor results are vectors."
  [x y]
  (if (or (coll? x) (coll? y))
    (mapv t-2 x y)
    (- x y)))

(defn t-
  "Tensor subtraction.
  With one argument, negates every scalar in a nested tensor.
  With several, subtracts the rest from the first, folding left with t-2."
  ([x]
   (if-not (coll? x)
     (- x)
     (mapv t- x)))
  ([x & xs]
   (loop [acc  x
          more xs]
     (if (seq more)
       (recur (t-2 acc (first more)) (rest more))
       acc))))

(defn s*t
  "Multiplies every scalar inside the nested tensor y by the scalar x.
  Tensor results are vectors; a scalar y gives (* x y)."
  [x y]
  (if-not (coll? y)
    (* x y)
    (mapv (partial s*t x) y)))

(defn diff-num
  "Numerical partial derivative of f with respect to part of its argument x.

  idxs is a nested index path into x. If the element there is a scalar, the
  result is the central difference (f(x + dx/2) - f(x - dx/2)) / dx, with
  only that scalar perturbed; f may return a scalar or a nested tensor. If
  the element is a collection, the result is a vector with one derivative
  per child, built recursively, so the output mirrors the structure at
  idxs."
  [f idxs x dx]
  (letfn [(walk [path]
            (let [node (apply tref x path)]
              (if-not (coll? node)
                (let [scale (/ 1.0 dx)
                      half  (/ dx 2.0)
                      f-hi  (f (tfassoc x path #(+ % half)))
                      f-lo  (f (tfassoc x path #(- % half)))]
                  (s*t scale (t- f-hi f-lo)))
                (mapv #(walk (conj path %))
                      (range (count node))))))]
    (walk (vec idxs))))

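; Creates an acceleration function from a Lagrangian L and an extra scalar
; function D, whose q.-gradient is subtracted as a generalized force
; (presumably a dissipation function -- confirm with callers).
; With W_ij = d^2L/(dq.^i dq.^j) and summation over repeated indices:
; q..^i =
;  (W^-1)^ij (dL/dq^j - d^2L/(dt dq.^j) - d^2L/(dq^k dq.^j) q.^k - dD/dq.^j)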
(defn L->accer
  "Builds an acceleration function from a Lagrangian L and a function D.

  The returned function takes a state x = [t q qdot] and evaluates
    W^-1 (dL/dq - d(dL/dqdot)/dt - (d(dL/dqdot)/dq)^T qdot - dD/dqdot),
  where W is the matrix of second derivatives of L with respect to qdot.
  All derivatives are central differences (diff-num), using step dt for the
  t slot, dq for q and dqdot for qdot."
  [dt dq dqdot L D]
  (let [dl-dqdot (fn [state] (diff-num L [2] state dqdot))]
    (fn [x] ; x -> [t q qdot]
      (let [;; W is symmetric, so the orientation diff-num returns is irrelevant.
            w-inv    (inv-mat (diff-num dl-dqdot [2] x dqdot))
            dl-dq    (diff-num L [1] x dq)
            dp-dt    (diff-num dl-dqdot [0] x dt)
            ;; diff-num indexes the outer level by q^k; transpose so row j
            ;; is d(dL/dq.^j)/dq^k.
            dp-dq-T  (apply map vector (diff-num dl-dqdot [1] x dq))
            qdot     (nth x 2)
            dd-dqdot (diff-num D [2] x dqdot)]
        (m*v w-inv
             (t- dl-dq
                 dp-dt
                 (m*v dp-dq-T qdot)
                 dd-dqdot))))))

; Implicit (backward) Euler step for the system dx^i/dt = f^i(x), step size h:
;   x^i(t+h) = x^i(t) + f^i(x(t+h)) h
; Write x(t+h) = x(t) + (x(t+h) - x(t)) and linearize f around x(t):
;   x^i(t+h) ~= x^i(t) + f^i(x(t))h + df^i(x(t))/dx^j * (x^j(t+h) - x^j(t)) * h
; Moving the unknown x(t+h) to the left gives a linear system in x(t+h):
;   x^i(t+h) - h*df^i(x(t))/dx^j*x^j(t+h)
;    = x^i(t) + f^i(x(t))h - h*df^i(x(t))/dx^j*x^j(t)
(defn eular-implicit
  "One linearized implicit (backward) Euler step of size dt for the system
  dx/dt = (phys x), starting from the state vector xis.

  The Jacobian J_ij = dphys^i/dx^j is estimated with diff-num using the
  finite-difference step dxis. The step then solves
    (I - dt J) x_next = xis + dt phys(xis) - dt J xis
  with lin-solve and returns x_next as a vector."
  [phys xis dt dxis]
  (let [n         (count xis)
        ;; diff-num yields rows indexed by x^j; transpose to get J[i][j].
        jacobian  (apply map vector (diff-num phys [] xis dxis))
        coeffs    (m-m (i-mat n) (s*m dt jacobian))
        rhs       (t+ xis
                      (map #(* % dt) (phys xis))
                      (t- (map #(* % dt) (m*v jacobian xis))))
        augmented (map conj (map vec coeffs) rhs)
        reduced   (lin-solve augmented)
        ;; Rows are [coefficients... rhs]; the solution sits in column n.
        sol-col   (count reduced)]
    (mapv #(nth % sol-col) reduced)))
