diff --git a/src/propeller/downsample.cljc b/src/propeller/downsample.cljc index b1e8d02..3e24f26 100644 --- a/src/propeller/downsample.cljc +++ b/src/propeller/downsample.cljc @@ -1,18 +1,60 @@ -(ns propeller.downsample) +(ns propeller.downsample + (:require [propeller.tools.math :as math])) (defn assign-indices-to-data "assigns an index to each training case in order to differentiate them when downsampling" - [{:keys [training-data]}] + [training-data] (map (fn [data-map index] (let [data-m (if (map? data-map) data-map (assoc {} :data data-map))] ;if data is not in a map, make it one (assoc data-m :index index))) training-data (range (count training-data)))) +(defn initialize-case-distances + [{:keys [training-data population-size]}] + (map #(assoc % :distances (vec (repeat (count training-data) population-size))) training-data)) + (defn select-downsample-random "Selects a downsample from the training cases and returns it" [training-data {:keys [downsample-rate]}] (take (int (* downsample-rate (count training-data))) (shuffle training-data))) -(defn update-case-data - "updates the case metadata field of argmap, should be called after evaluation of individuals" - [argmap]) \ No newline at end of file +(defn get-distance-between-cases + "returns the distance between two cases given a list of individual error vectors, and the index these + cases exist in the error vector" + [error-lists case-index-1 case-index-2] + (if (or (< (count (first error-lists)) case-index-1) + (< (count (first error-lists)) case-index-2) + (neg? case-index-1) (neg? case-index-2)) + (count error-lists) ;return the max distance + (let [errors-1 (map #(nth % case-index-1) error-lists) + errors-2 (map #(nth % case-index-2) error-lists)] + ;compute distance between errors-1 and errors-2 + (reduce + (map (fn [e1 e2] (math/abs (- (math/step e1) (math/step e2)))) errors-1 errors-2))))) + +(defn update-at-indices + "merges two vectors at the indices provided by a third vector" + [big-vec small-vec indices] + (->> big-vec + (map-indexed (fn [idx itm] (let [index (.indexOf indices idx)] + (if (not= -1 index) (nth small-vec index) itm)))) + vec)) + +(defn merge-map-lists-at-index + "merges two lists of maps, replacing the maps in the big + list with their corresponding (based on index) maps in the small list" + [big-list small-list] + (map + #(let [corresponding-small (some (fn [c] (when (= (:index %) (:index c)) c)) small-list)] + (if (nil? corresponding-small) % corresponding-small)) + big-list)) + +(defn update-case-distances + "updates the case distance field of training-data list, should be called after evaluation of individuals + evaluated-pop should be a list of individuals that all have the :errors field with a list of this + individuals performance on the each case in the ds-data, in order" + [evaluated-pop ds-data training-data] + (let [ds-indices (map #(:index %) ds-data) errors (map #(:errors %) evaluated-pop)] + (merge-map-lists-at-index training-data + (map-indexed (fn [idx d-case] + (update-in d-case [:distances] #(update-at-indices % + (map (fn [other] (get-distance-between-cases errors idx other)) (range (count ds-indices))) ds-indices))) ds-data)))) diff --git a/src/propeller/gp.cljc b/src/propeller/gp.cljc index 3049c98..d46e19a 100644 --- a/src/propeller/gp.cljc +++ b/src/propeller/gp.cljc @@ -47,7 +47,7 @@ population (mapper (fn [_] {:plushy (genome/make-random-plushy instructions max-initial-plushy-size)}) (range population-size)) - indexed-training-data (downsample/assign-indices-to-data argmap)] + indexed-training-data (downsample/assign-indices-to-data (downsample/initialize-case-distances argmap))] (let [training-data (if (= (:parent-selection argmap) :ds-lexicase) (downsample/select-downsample-random indexed-training-data argmap) indexed-training-data) @@ -55,18 +55,32 @@ (mapper (partial error-function argmap training-data) population)) - best-individual (first evaluated-pop)] + best-individual (first evaluated-pop) + best-individual-passes-ds (and (= (:parent-selection argmap) :ds-lexicase) (<= (:total-error best-individual) solution-error-threshold)) + tot-evaluated-pop (when best-individual-passes-ds ;evaluate the whole pop on all training data + (sort-by :total-error + (mapper + (partial error-function argmap (:training-data argmap)) + population))) + ;;best individual on all training-cases + tot-best-individual (if best-individual-passes-ds (first tot-evaluated-pop) best-individual)] + (prn (first training-data)) (if (:custom-report argmap) ((:custom-report argmap) evaluated-pop generation argmap) (report evaluated-pop generation argmap)) + ;;did the indvidual pass all cases in ds? + (when best-individual-passes-ds + (prn {:semi-success-generation generation})) (cond ;; Success on training cases is verified on testing cases - (<= (:total-error best-individual) solution-error-threshold) + (or (and best-individual-passes-ds (<= (:total-error tot-best-individual) solution-error-threshold)) + (and (not= (:parent-selection argmap) :ds-lexicase) + (<= (:total-error best-individual) solution-error-threshold))) (do (prn {:success-generation generation}) (prn {:total-test-error - (:total-error (error-function argmap (:testing-data argmap) best-individual))}) + (:total-error (error-function argmap (:testing-data argmap) tot-best-individual))}) (when (:simplification? argmap) - (let [simplified-plushy (simplification/auto-simplify-plushy (:plushy best-individual) error-function argmap)] + (let [simplified-plushy (simplification/auto-simplify-plushy (:plushy tot-best-individual) error-function argmap)] (prn {:total-test-error-simplified (:total-error (error-function argmap (:testing-data argmap) (hash-map :plushy simplified-plushy)))})))) ;; (>= generation max-generations) @@ -79,4 +93,4 @@ (first evaluated-pop)) (repeatedly population-size #(variation/new-individual evaluated-pop argmap))) - (update-case-metadata evaluated-pop)))))) + (downsample/update-case-distances evaluated-pop training-data indexed-training-data)))))) diff --git a/src/propeller/selection.cljc b/src/propeller/selection.cljc index 487ecf2..8fbf09f 100755 --- a/src/propeller/selection.cljc +++ b/src/propeller/selection.cljc @@ -26,4 +26,5 @@ [pop argmap] (case (:parent-selection argmap) :tournament (tournament-selection pop argmap) - :lexicase (lexicase-selection pop argmap))) + :lexicase (lexicase-selection pop argmap) + :ds-lexicase (lexicase-selection pop argmap))) diff --git a/test/propeller/utils_test.cljc b/test/propeller/utils_test.cljc index b709921..9029b05 100644 --- a/test/propeller/utils_test.cljc +++ b/test/propeller/utils_test.cljc @@ -86,12 +86,12 @@ (t/deftest assign-indices-to-data-test (t/testing "assign-indices-to-data" (t/testing "should return a map of the same length" - (t/is (= (count (ds/assign-indices-to-data {:training-data (range 10)})) 10)) - (t/is (= (count (ds/assign-indices-to-data {:training-data (range 0)})) 0))) + (t/is (= (count (ds/assign-indices-to-data (range 10))) 10)) + (t/is (= (count (ds/assign-indices-to-data (range 0))) 0))) (t/testing "should return a map where each element has an index key" - (t/is (every? #(:index %) (ds/assign-indices-to-data {:training-data (map #(assoc {} :input %) (range 10))})))) + (t/is (every? #(:index %) (ds/assign-indices-to-data (map #(assoc {} :input %) (range 10)))))) (t/testing "should return distinct indices" - (t/is (= (map #(:index %) (ds/assign-indices-to-data {:training-data (range 10)})) (range 10)))))) + (t/is (= (map #(:index %) (ds/assign-indices-to-data (range 10))) (range 10)))))) (t/deftest select-downsample-random-test (t/testing "select-downsample-random" @@ -109,3 +109,50 @@ (t/testing "should not return more elements than available" (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 2})) 10)) (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 1.5})) 10))))) + +(t/deftest get-distance-between-cases-test + (t/testing "get-distance-between-cases" + (t/testing "should return correct distance" + (t/is (= 3 (ds/get-distance-between-cases '((0 1 1) (0 1 1) (1 0 1)) 0 1)))) + (t/testing "should return 0 for the distance of a case to itself" + (t/is (= 0 (ds/get-distance-between-cases '((0 1 1) (0 1 1) (1 0 1)) 0 0)))) + (t/testing "should work for non binary values (0 is solved)" + (t/is (= 1 (ds/get-distance-between-cases '((0 2 2) (0 2 2) (1 0 50)) 1 2)))) + (t/testing "should return the max distance if one of the cases does not exist" + (t/is (= 3 (ds/get-distance-between-cases '((0 1 1) (0 1 1) (1 0 1)) 0 4)))))) + +(t/deftest merge-map-lists-at-index-test + (t/testing "merge-map-lists-at-index" + (t/testing "works properly" + (t/is (= '({:index 0 :a 3 :b 2} {:index 1 :a 2 :b 3}) (ds/merge-map-lists-at-index '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2}) '({:index 1 :a 2 :b 3}))))) + (t/testing "doesn't change big list if no indices match" + (t/is (= '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2}) (ds/merge-map-lists-at-index '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2}) '({:index 3 :a 2 :b 3}))))) + (t/testing "doesn't fail on empty list" + (t/is (= '() (ds/merge-map-lists-at-index '() '())))) + (t/testing "shouldn't fail merging non-empty with empty" + (t/is (= '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2}) (ds/merge-map-lists-at-index '({:index 0 :a 3 :b 2} {:index 1 :a 1 :b 2}) '())))))) + +(t/deftest update-at-indices-test + (t/testing "update-at-indices" + (t/testing "should update at correct indices" + (t/is (= (ds/update-at-indices [1 2 3 4] [5] [0]) [5 2 3 4])) + (t/is (= (ds/update-at-indices [1 2 3 4] [5] [0]) [5 2 3 4]))) + (t/testing "should update nothing if index list is empty" + (t/is (= (ds/update-at-indices [6 5 4 0 0] [] []) [6 5 4 0 0]))) + (t/testing "should update nothing if index list is out of bounds" + (t/is (= (ds/update-at-indices [6 5 4 0 0] [4 5 1] [-1 5 6]) [6 5 4 0 0]))) + (t/testing "should update only when indices are available (length mismatch)" + (t/is (= (ds/update-at-indices [6 5 4 0 0] [1 2 3 4] [0 1]) [1 2 4 0 0]))) + (t/testing "should not care about index order" + (t/is (= (ds/update-at-indices [6 5 4 0 0] [2 1] [1 0]) [1 2 4 0 0]))) + (t/testing "should work when input is a list" + (t/is (= (ds/update-at-indices '(6 5 4 0 0) '(2 1) '(1 0)) [1 2 4 0 0]))))) + +(t/deftest update-case-distances-test + (t/testing "update-case-distances" + (t/testing "should ..." + (t/is (= (ds/update-case-distances '({:errors (0 0)} {:errors (0 0)}) + '({:index 3 :distances [2 2 2 2 2]} {:index 4 :distances [2 2 2 2 2]}) + '({:index 0 :distances [2 2 2 2 2]} {:index 1 :distances [2 2 2 2 2]} {:index 2 :distances [2 2 2 2 2]} {:index 3 :distances [2 2 2 2 2]} {:index 4 :distances [2 2 2 2 2]})) + '({:index 0 :distances [2 2 2 2 2]} {:index 1 :distances [2 2 2 2 2]} {:index 2 :distances [2 2 2 2 2]} + {:index 3 :distances [2 2 2 0 0]} {:index 4 :distances [2 2 2 0 0]})))))) \ No newline at end of file