diff --git a/src/propeller/downsample.cljc b/src/propeller/downsample.cljc
index 0e0f7d7..8037515 100644
--- a/src/propeller/downsample.cljc
+++ b/src/propeller/downsample.cljc
@@ -45,15 +45,14 @@
 (defn select-downsample-maxmin
   "selects a downsample that has it's cases maximally far away by sequentially
   adding cases to the downsample that have their closest case maximally far away"
-  [training-data {:keys [downsample-rate case-t-size]}]
+  [training-data {:keys [downsample-rate]}]
   (let [shuffled-cases (shuffle training-data)
         goal-size (int (* downsample-rate (count training-data)))]
     (loop [new-downsample (conj [] (first shuffled-cases))
            cases-to-pick-from (rest shuffled-cases)]
       (if (>= (count new-downsample) goal-size)
         new-downsample
-        (let [tournament (take case-t-size cases-to-pick-from)
-              rest-of-cases (drop case-t-size cases-to-pick-from)
+        (let [tournament cases-to-pick-from
               min-case-distances (metrics/min-of-colls
                                   (map (fn [distance-list]
                                          (utils/filter-by-index distance-list (map #(:index %) tournament)))
@@ -64,35 +63,35 @@
             (prn {:cases-in-ds (map #(:input1 %) new-downsample) :cases-in-tourn (map #(:input1 %) tournament)}))
           (prn {:min-case-distances min-case-distances :selected-case-index selected-case-index})
           (recur (conj new-downsample (nth tournament selected-case-index))
-                 (shuffle (concat (utils/drop-nth selected-case-index tournament)
-                                  rest-of-cases))))))))
+                 (shuffle (utils/drop-nth selected-case-index tournament))))))))
 
 (defn select-downsample-maxmin-adaptive
   "selects a downsample that has it's cases maximally far away by sequentially
   adding cases to the downsample that have their closest case maximally far away
   automatically stops when the maximum minimum distance is below delta"
-[training-data {:keys [case-t-size case-delta]}]
+[training-data {:keys [case-delta]}]
   (let [shuffled-cases (shuffle training-data)]
     (loop [new-downsample (conj [] (first shuffled-cases))
-           cases-to-pick-from (rest shuffled-cases)
-           end? false]
-      (if (or end? (zero? (count cases-to-pick-from)))
-        new-downsample
-        (let [tournament (take case-t-size cases-to-pick-from)
-              rest-of-cases (drop case-t-size cases-to-pick-from)
+           cases-to-pick-from (rest shuffled-cases)]
+      ;; Check for an exhausted pool before computing distances: with nothing
+      ;; left to pick from, min-case-distances is empty and (apply max ...) throws.
+      (if (empty? cases-to-pick-from)
+        new-downsample
+        (let [tournament cases-to-pick-from
               min-case-distances (metrics/min-of-colls
                                   (map (fn [distance-list]
                                          (utils/filter-by-index distance-list (map #(:index %) tournament)))
                                        (map #(:distances %) new-downsample)))
               selected-case-index (metrics/argmax min-case-distances)]
-          (if (sequential? (:input1 (first new-downsample)))
-            (prn {:cases-in-ds (map #(first (:input1 %)) new-downsample) :cases-in-tourn (map #(first (:input1 %)) tournament)})
-            (prn {:cases-in-ds (map #(:input1 %) new-downsample) :cases-in-tourn (map #(:input1 %) tournament)}))
-          (prn {:min-case-distances min-case-distances :selected-case-index selected-case-index})
-          (recur (conj new-downsample (nth tournament selected-case-index))
-                 (shuffle (concat (utils/drop-nth selected-case-index tournament)
-                                  rest-of-cases))
-                 (<= (apply max min-case-distances) case-delta)))))))
+          (if (<= (apply max min-case-distances) case-delta)
+            new-downsample
+            (do
+              (if (sequential? (:input1 (first new-downsample)))
+                (prn {:cases-in-ds (map #(first (:input1 %)) new-downsample) :cases-in-tourn (map #(first (:input1 %)) tournament)})
+                (prn {:cases-in-ds (map #(:input1 %) new-downsample) :cases-in-tourn (map #(:input1 %) tournament)}))
+              (prn {:min-case-distances min-case-distances :selected-case-index selected-case-index})
+              (recur (conj new-downsample (nth tournament selected-case-index))
+                     (shuffle (utils/drop-nth selected-case-index tournament)))))))))))
 
 (defn get-distance-between-cases
   "returns the distance between two cases given a list of individual error vectors, and the index these
diff --git a/test/propeller/utils_test.cljc b/test/propeller/utils_test.cljc
index 4bef17e..d909d52 100644
--- a/test/propeller/utils_test.cljc
+++ b/test/propeller/utils_test.cljc
@@ -166,10 +166,43 @@
                       {:input1 [2] :output1 [12] :index 2 :distances [0 5 0 0 0]}
                       {:input1 [3] :output1 [13] :index 3 :distances [0 5 0 0 0]}
                       {:input1 [4] :output1 [14] :index 4 :distances [0 5 0 0 0]})
-                    {:downsample-rate 0.4 :case-t-size 5})]
+                    {:downsample-rate 0.4})]
       (prn {:selected selected})
       (t/is (or (= (:index (first selected)) 1) (= (:index (second selected)) 1))))))
 
+(t/deftest case-maxmin-adaptive
+  (t/testing "case-maxmin-adaptive selects correct downsample simple"
+    (let [selected (ds/select-downsample-maxmin-adaptive
+                    '({:input1 [0] :output1 [10] :index 0 :distances [0 5 0 0 0]}
+                      {:input1 [1] :output1 [11] :index 1 :distances [5 0 5 5 5]}
+                      {:input1 [2] :output1 [12] :index 2 :distances [0 5 0 0 0]}
+                      {:input1 [3] :output1 [13] :index 3 :distances [0 5 0 0 0]}
+                      {:input1 [4] :output1 [14] :index 4 :distances [0 5 0 0 0]})
+                    {:case-delta 0})]
+      (prn {:selected selected})
+      (t/is (or (= (:index (first selected)) 1) (= (:index (second selected)) 1)))
+      (t/is (= 2 (count selected)))))
+  (t/testing "case-maxmin-adaptive selects correct downsample when all identical"
+    (let [selected (ds/select-downsample-maxmin-adaptive
+                    '({:input1 [0] :output1 [10] :index 0 :distances [0 0 0 0 0]}
+                      {:input1 [1] :output1 [11] :index 1 :distances [0 0 0 0 0]}
+                      {:input1 [2] :output1 [12] :index 2 :distances [0 0 0 0 0]}
+                      {:input1 [3] :output1 [13] :index 3 :distances [0 0 0 0 0]}
+                      {:input1 [4] :output1 [14] :index 4 :distances [0 0 0 0 0]})
+                    {:case-delta 0})]
+      (prn {:selected selected})
+      (t/is (= 1 (count selected)))))
+  (t/testing "case-maxmin-adaptive returns every case when the pool is exhausted before delta is reached"
+    (let [selected (ds/select-downsample-maxmin-adaptive
+                    '({:input1 [0] :output1 [10] :index 0 :distances [0 5 5 5 5]}
+                      {:input1 [1] :output1 [11] :index 1 :distances [5 0 5 5 5]}
+                      {:input1 [2] :output1 [12] :index 2 :distances [5 5 0 5 5]}
+                      {:input1 [3] :output1 [13] :index 3 :distances [5 5 5 0 5]}
+                      {:input1 [4] :output1 [14] :index 4 :distances [5 5 5 5 0]})
+                    {:case-delta 0})]
+      (prn {:selected selected})
+      (t/is (= 5 (count selected))))))
+
 (t/deftest hyperselection-test
   (let [parents1 '({:blah 3 :index 1}
                    {:blah 3 :index 1}