diff --git a/src/propeller/downsample.cljc b/src/propeller/downsample.cljc index 8c85ab8..b1e8d02 100644 --- a/src/propeller/downsample.cljc +++ b/src/propeller/downsample.cljc @@ -10,7 +10,7 @@ (defn select-downsample-random "Selects a downsample from the training cases and returns it" - [{:keys [downsample-rate training-data]}] + [training-data {:keys [downsample-rate]}] (take (int (* downsample-rate (count training-data))) (shuffle training-data))) (defn update-case-data diff --git a/src/propeller/gp.cljc b/src/propeller/gp.cljc index 3c89f00..4af4474 100644 --- a/src/propeller/gp.cljc +++ b/src/propeller/gp.cljc @@ -4,6 +4,7 @@ [propeller.genome :as genome] [propeller.simplification :as simplification] [propeller.variation :as variation] + [propeller.downsample :as downsample] [propeller.push.instructions.bool] [propeller.push.instructions.character] [propeller.push.instructions.code] @@ -44,12 +45,16 @@ ;; (loop [generation 0 population (mapper - (fn [_] {:plushy (genome/make-random-plushy instructions max-initial-plushy-size)}) - (range population-size))] - (let [evaluated-pop (sort-by :total-error + (fn [_] {:plushy (genome/make-random-plushy instructions max-initial-plushy-size)}) + (range population-size)) + indexed-training-data (downsample/assign-indices-to-data argmap)] + (let [training-data (if (= (:parent-selection argmap) :ds-lexicase) + (downsample/select-downsample-random indexed-training-data argmap) + indexed-training-data) + evaluated-pop (sort-by :total-error (mapper - (partial error-function argmap (:training-data argmap)) - population)) + (partial error-function argmap training-data) + population)) best-individual (first evaluated-pop)] (if (:custom-report argmap) ((:custom-report argmap) evaluated-pop generation argmap) @@ -61,7 +66,7 @@ (prn {:total-test-error (:total-error (error-function argmap (:testing-data argmap) best-individual))}) (if (:simplification? argmap) - (let [simplified-plushy (simplification/auto-simplify-plushy argmap (:plushy best-individual) (:simplification-steps argmap) error-function (:training-data argmap) (:simplification-k argmap) (:simplification-verbose? argmap))] + (let [simplified-plushy (simplification/auto-simplify-plushy argmap (:plushy best-individual) (:simplification-steps argmap) error-function (:testing-data argmap) (:simplification-k argmap) (:simplification-verbose? argmap))] (prn {:total-test-error-simplified (:total-error (error-function argmap (:testing-data argmap) (hash-map :plushy simplified-plushy)))})))) ;; (>= generation max-generations) @@ -73,4 +78,5 @@ #(variation/new-individual evaluated-pop argmap)) (first evaluated-pop)) (repeatedly population-size - #(variation/new-individual evaluated-pop argmap)))))))) + #(variation/new-individual evaluated-pop argmap))) + (update-case-metadata evaluated-pop)))))) diff --git a/test/propeller/utils_test.cljc b/test/propeller/utils_test.cljc index 1212e24..30ba7e5 100644 --- a/test/propeller/utils_test.cljc +++ b/test/propeller/utils_test.cljc @@ -96,16 +96,16 @@ (t/deftest select-downsample-random-test (t/testing "select-downsample-random" (t/testing "should select the correct amount of elements" - (t/is (= (count (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.1})) 1)) - (t/is (= (count (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.2})) 2)) - (t/is (= (count (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.5})) 5))) + (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 0.1})) 1)) + (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 0.2})) 2)) + (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 0.5})) 5))) (t/testing "should not return duplicate items (when called with set of numbers)" - (t/is (= (count (set (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.1}))) 1)) - (t/is (= (count (set (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.2}))) 2)) - (t/is (= (count (set (ds/select-downsample-random {:training-data (range 10) :downsample-rate 0.5}))) 5))) + (t/is (= (count (set (ds/select-downsample-random (range 10) {:downsample-rate 0.1}))) 1)) + (t/is (= (count (set (ds/select-downsample-random (range 10) {:downsample-rate 0.2}))) 2)) + (t/is (= (count (set (ds/select-downsample-random (range 10) {:downsample-rate 0.5}))) 5))) (t/testing "should round down the number of elements selected if not whole" - (t/is (= (count (ds/select-downsample-random {:training-data (range 3) :downsample-rate 0.5})) 1)) - (t/is (= (count (ds/select-downsample-random {:training-data (range 1) :downsample-rate 0.5})) 0))) + (t/is (= (count (ds/select-downsample-random (range 3) {:downsample-rate 0.5})) 1)) + (t/is (= (count (ds/select-downsample-random (range 1) {:downsample-rate 0.5})) 0))) (t/testing "should not return more elements than available" - (t/is (= (count (ds/select-downsample-random {:training-data (range 10) :downsample-rate 2})) 10)) - (t/is (= (count (ds/select-downsample-random {:training-data (range 10) :downsample-rate 1.5})) 10))))) + (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 2})) 10)) + (t/is (= (count (ds/select-downsample-random (range 10) {:downsample-rate 1.5})) 10)))))