implemented downsampling and distance measuring/updating

Features:

- Tests!
- maintain index of all training-cases
- update distance of training-cases after every evaluation (only update cases that are in the downsample)
- when an individual passes all of the downsampled (DS) cases, the whole population is re-evaluated on all training-cases to check whether any individual passes every one of them. If so, evolution stops and training is complete (the best individual is then tested on the held-out test set and its performance is reported).
This commit is contained in:
Ryan Boldi 2022-03-12 13:59:47 -05:00
parent 83c7e440f6
commit 287799f194
4 changed files with 120 additions and 16 deletions

View File

@ -1,18 +1,60 @@
(ns propeller.downsample)
(ns propeller.downsample
(:require [propeller.tools.math :as math]))
(defn assign-indices-to-data
"Attaches a sequential :index (0, 1, 2, ...) to every training case so that cases can be
told apart when downsampling. A case that is not already a map is first wrapped as
{:data case}. Returns a lazy sequence of the indexed case maps."
; NOTE(review): two vector forms follow the docstring; as written the first one is the
; parameter list and the second is only a body expression -- confirm which signature is intended.
[{:keys [training-data]}]
[training-data]
(map (fn [data-map index]
(let [data-m (if (map? data-map) data-map (assoc {} :data data-map))] ;if data is not in a map, make it one
(assoc data-m :index index)))
training-data (range (count training-data))))
(defn initialize-case-distances
  "Returns the :training-data cases of the argmap as a lazy sequence in which every case
  has a :distances vector. The vector holds one entry per training case and every entry
  starts out as :population-size."
  [{:keys [training-data population-size]}]
  (for [training-case training-data]
    (->> population-size
         (repeat (count training-data))
         vec
         (assoc training-case :distances))))
(defn select-downsample-random
  "Returns a random downsample of the training cases: the cases are shuffled and the first
  (int (* downsample-rate n)) of them are kept, where n is the number of training cases.
  A rate above 1 therefore yields at most all n cases."
  [training-data {:keys [downsample-rate]}]
  (let [sample-size (int (* downsample-rate (count training-data)))]
    (->> training-data
         shuffle
         (take sample-size))))
(defn update-case-data
  "Placeholder for updating the case metadata held in argmap after individuals have been
  evaluated. It has no implementation yet: argmap is ignored and nil is returned."
  [argmap]
  nil)
(defn get-distance-between-cases
  "Returns the distance between two cases.

  error-lists  - a sequence with one error sequence per individual
  case-index-1 - position of the first case inside each error sequence
  case-index-2 - position of the second case inside each error sequence

  The distance is the number of individuals for which (math/step error) differs between
  the two cases. If either index is negative or not a valid position in the first error
  sequence, the maximum distance (the number of individuals) is returned."
  [error-lists case-index-1 case-index-2]
  (let [num-cases (count (first error-lists))]
    ;; Valid positions are 0 .. (dec num-cases). An index equal to num-cases must also be
    ;; rejected here, otherwise nth below throws instead of yielding the max distance.
    (if (or (>= case-index-1 num-cases)
            (>= case-index-2 num-cases)
            (neg? case-index-1)
            (neg? case-index-2))
      (count error-lists) ;return the max distance
      (let [errors-1 (map #(nth % case-index-1) error-lists)
            errors-2 (map #(nth % case-index-2) error-lists)]
        ;; compute distance between errors-1 and errors-2
        ;; NOTE(review): math/step presumably maps 0 to 0 and any other error to 1 --
        ;; confirm against propeller.tools.math
        (reduce + (map (fn [e1 e2] (math/abs (- (math/step e1) (math/step e2))))
                       errors-1 errors-2))))))
(defn update-at-indices
  "Returns big-vec as a vector in which every position listed in indices has been replaced.
  The replacement for a position is the element of small-vec found at the same offset that
  the position occupies in indices. Positions not listed in indices keep their value, and
  entries of indices that are not positions of big-vec have no effect."
  [big-vec small-vec indices]
  (let [pick-value (fn [position current]
                     ;; offset of this position inside indices, or -1 when it is not listed
                     (let [offset (.indexOf indices position)]
                       (if (= -1 offset)
                         current
                         (nth small-vec offset))))]
    (vec (map-indexed pick-value big-vec))))
(defn merge-map-lists-at-index
  "Returns big-list as a lazy sequence in which every map that has a counterpart in
  small-list (a map with an equal :index) is replaced by that counterpart. Maps without
  a counterpart are kept as they are."
  [big-list small-list]
  (let [find-counterpart (fn [big-item]
                           (some (fn [candidate]
                                   (when (= (:index big-item) (:index candidate))
                                     candidate))
                                 small-list))]
    (for [big-item big-list]
      (let [counterpart (find-counterpart big-item)]
        (if-not (nil? counterpart)
          counterpart
          big-item)))))
(defn update-case-distances
  "Returns training-data with the :distances of every downsampled case refreshed; call it
  after the individuals have been evaluated.

  evaluated-pop - individuals that each carry an :errors sequence holding their error on
                  every case of ds-data, in ds-data order
  ds-data       - the downsampled cases, each with :index and :distances
  training-data - all training cases, each with :index and :distances

  For each case in ds-data, the entries of its :distances that belong to the other
  downsampled cases (located through their :index) are overwritten with the distances
  computed from the errors. The refreshed cases then replace the cases with the same
  :index in training-data."
  [evaluated-pop ds-data training-data]
  (let [sampled-indices (map :index ds-data)
        error-lists (map :errors evaluated-pop)
        sample-positions (range (count sampled-indices))
        refresh-case (fn [position sampled-case]
                       ;; distances from this case to every case of the downsample,
                       ;; addressed by position inside the error sequences
                       (let [fresh-distances (map (fn [other-position]
                                                    (get-distance-between-cases error-lists position other-position))
                                                  sample-positions)]
                         (update-in sampled-case [:distances]
                                    (fn [old-distances]
                                      (update-at-indices old-distances fresh-distances sampled-indices)))))]
    (merge-map-lists-at-index training-data
                              (map-indexed refresh-case ds-data))))

View File

@ -47,7 +47,7 @@
population (mapper
(fn [_] {:plushy (genome/make-random-plushy instructions max-initial-plushy-size)})
(range population-size))
indexed-training-data (downsample/assign-indices-to-data argmap)]
indexed-training-data (downsample/assign-indices-to-data (downsample/initialize-case-distances argmap))]
(let [training-data (if (= (:parent-selection argmap) :ds-lexicase)
(downsample/select-downsample-random indexed-training-data argmap)
indexed-training-data)
@ -55,18 +55,32 @@
(mapper
(partial error-function argmap training-data)
population))
best-individual (first evaluated-pop)]
best-individual (first evaluated-pop)
best-individual-passes-ds (and (= (:parent-selection argmap) :ds-lexicase) (<= (:total-error best-individual) solution-error-threshold))
tot-evaluated-pop (when best-individual-passes-ds ;evaluate the whole pop on all training data
(sort-by :total-error
(mapper
(partial error-function argmap (:training-data argmap))
population)))
;;best individual on all training-cases
tot-best-individual (if best-individual-passes-ds (first tot-evaluated-pop) best-individual)]
(prn (first training-data))
(if (:custom-report argmap)
((:custom-report argmap) evaluated-pop generation argmap)
(report evaluated-pop generation argmap))
;;did the individual pass all cases in ds?
(when best-individual-passes-ds
(prn {:semi-success-generation generation}))
(cond
;; Success on training cases is verified on testing cases
(<= (:total-error best-individual) solution-error-threshold)
(or (and best-individual-passes-ds (<= (:total-error tot-best-individual) solution-error-threshold))
(and (not= (:parent-selection argmap) :ds-lexicase)
(<= (:total-error best-individual) solution-error-threshold)))
(do (prn {:success-generation generation})
(prn {:total-test-error
(:total-error (error-function argmap (:testing-data argmap) best-individual))})
(:total-error (error-function argmap (:testing-data argmap) tot-best-individual))})
(when (:simplification? argmap)
(let [simplified-plushy (simplification/auto-simplify-plushy (:plushy best-individual) error-function argmap)]
(let [simplified-plushy (simplification/auto-simplify-plushy (:plushy tot-best-individual) error-function argmap)]
(prn {:total-test-error-simplified (:total-error (error-function argmap (:testing-data argmap) (hash-map :plushy simplified-plushy)))}))))
;;
(>= generation max-generations)
@ -79,4 +93,4 @@
(first evaluated-pop))
(repeatedly population-size
#(variation/new-individual evaluated-pop argmap)))
(update-case-metadata evaluated-pop))))))
(downsample/update-case-distances evaluated-pop training-data indexed-training-data))))))

View File

@ -26,4 +26,5 @@
[pop argmap]
(case (:parent-selection argmap)
:tournament (tournament-selection pop argmap)
:lexicase (lexicase-selection pop argmap)))
:lexicase (lexicase-selection pop argmap)
:ds-lexicase (lexicase-selection pop argmap)))

View File

@ -86,12 +86,12 @@
(t/deftest assign-indices-to-data-test
(t/testing "assign-indices-to-data"
(t/testing "should return a map of the same length"
(t/is (= (count (ds/assign-indices-to-data {:training-data (range 10)})) 10))
(t/is (= (count (ds/assign-indices-to-data {:training-data (range 0)})) 0)))
(t/is (= (count (ds/assign-indices-to-data (range 10))) 10))
(t/is (= (count (ds/assign-indices-to-data (range 0))) 0)))
(t/testing "should return a map where each element has an index key"
(t/is (every? #(:index %) (ds/assign-indices-to-data {:training-data (map #(assoc {} :input %) (range 10))}))))
(t/is (every? #(:index %) (ds/assign-indices-to-data (map #(assoc {} :input %) (range 10))))))
(t/testing "should return distinct indices"
(t/is (= (map #(:index %) (ds/assign-indices-to-data {:training-data (range 10)})) (range 10))))))
(t/is (= (map #(:index %) (ds/assign-indices-to-data (range 10))) (range 10))))))
(t/deftest select-downsample-random-test
(t/testing "select-downsample-random"
@ -109,3 +109,50 @@
(t/testing "should not return more elements than available"
(t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 2})) 10))
(t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 1.5})) 10)))))
;; Exercises ds/get-distance-between-cases on small hand-written error lists
;; (one inner list per individual, one entry per case).
(t/deftest get-distance-between-cases-test
  (t/testing "get-distance-between-cases"
    (let [binary-errors '((0 1 1) (0 1 1) (1 0 1))
          graded-errors '((0 2 2) (0 2 2) (1 0 50))]
      (t/testing "should return correct distance"
        (t/is (= 3 (ds/get-distance-between-cases binary-errors 0 1))))
      (t/testing "should return 0 for the distance of a case to itself"
        (t/is (= 0 (ds/get-distance-between-cases binary-errors 0 0))))
      (t/testing "should work for non binary values (0 is solved)"
        (t/is (= 1 (ds/get-distance-between-cases graded-errors 1 2))))
      (t/testing "should return the max distance if one of the cases does not exist"
        (t/is (= 3 (ds/get-distance-between-cases binary-errors 0 4)))))))
;; Exercises ds/merge-map-lists-at-index: maps of the big list are replaced only by
;; small-list maps that carry the same :index.
(t/deftest merge-map-lists-at-index-test
  (t/testing "merge-map-lists-at-index"
    (let [big-list '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2})]
      (t/testing "works properly"
        (t/is (= '({:index 0 :a 3 :b 2} {:index 1 :a 2 :b 3})
                 (ds/merge-map-lists-at-index big-list '({:index 1 :a 2 :b 3})))))
      (t/testing "doesn't change big list if no indices match"
        (t/is (= big-list
                 (ds/merge-map-lists-at-index big-list '({:index 3 :a 2 :b 3})))))
      (t/testing "doesn't fail on empty list"
        (t/is (= '() (ds/merge-map-lists-at-index '() '()))))
      (t/testing "shouldn't fail merging non-empty with empty"
        (t/is (= big-list
                 (ds/merge-map-lists-at-index big-list '())))))))
;; Exercises ds/update-at-indices: values of the small collection are written into the
;; big collection at the positions named by the index collection.
(t/deftest update-at-indices-test
  (t/testing "update-at-indices"
    (t/testing "should update at correct indices"
      (t/is (= (ds/update-at-indices [1 2 3 4] [5] [0]) [5 2 3 4]))
      ;; the second assertion used to repeat the first one; cover the last position instead
      (t/is (= (ds/update-at-indices [1 2 3 4] [5] [3]) [1 2 3 5])))
    (t/testing "should update nothing if index list is empty"
      (t/is (= (ds/update-at-indices [6 5 4 0 0] [] []) [6 5 4 0 0])))
    (t/testing "should update nothing if index list is out of bounds"
      (t/is (= (ds/update-at-indices [6 5 4 0 0] [4 5 1] [-1 5 6]) [6 5 4 0 0])))
    (t/testing "should update only when indices are available (length mismatch)"
      (t/is (= (ds/update-at-indices [6 5 4 0 0] [1 2 3 4] [0 1]) [1 2 4 0 0])))
    (t/testing "should not care about index order"
      (t/is (= (ds/update-at-indices [6 5 4 0 0] [2 1] [1 0]) [1 2 4 0 0])))
    (t/testing "should work when input is a list"
      (t/is (= (ds/update-at-indices '(6 5 4 0 0) '(2 1) '(1 0)) [1 2 4 0 0])))))
;; Exercises ds/update-case-distances: two individuals with zero error on both downsampled
;; cases (indices 3 and 4) give distance 0 between those cases; everything else stays at 2.
(t/deftest update-case-distances-test
  (t/testing "update-case-distances"
    (t/testing "should only update the distances between cases that are in the downsample"
      (let [evaluated-pop '({:errors (0 0)} {:errors (0 0)})
            ds-data '({:index 3 :distances [2 2 2 2 2]} {:index 4 :distances [2 2 2 2 2]})
            training-data '({:index 0 :distances [2 2 2 2 2]} {:index 1 :distances [2 2 2 2 2]}
                            {:index 2 :distances [2 2 2 2 2]} {:index 3 :distances [2 2 2 2 2]}
                            {:index 4 :distances [2 2 2 2 2]})]
        (t/is (= (ds/update-case-distances evaluated-pop ds-data training-data)
                 '({:index 0 :distances [2 2 2 2 2]} {:index 1 :distances [2 2 2 2 2]}
                   {:index 2 :distances [2 2 2 2 2]}
                   {:index 3 :distances [2 2 2 0 0]} {:index 4 :distances [2 2 2 0 0]})))))))