From 35ad71ffc3e2d19203386b8401a8db33b2ec31ec Mon Sep 17 00:00:00 2001 From: Rowan Torbitzky-Lane Date: Tue, 29 Apr 2025 17:47:01 -0500 Subject: [PATCH] simplification, need to test --- src/gp/args.rs | 4 ++- src/gp/mod.rs | 17 ++++++++- src/gp/simplification.rs | 76 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 src/gp/simplification.rs diff --git a/src/gp/args.rs b/src/gp/args.rs index fa936b8..479aa1e 100644 --- a/src/gp/args.rs +++ b/src/gp/args.rs @@ -19,7 +19,7 @@ pub struct PushArgs { pub dont_end: bool, // If true, keep running until limit regardless of success // pub downsample: bool, // Whether or not to downsample. TODO later with all the related args pub elitism: bool, // Whether to always add the best individual to next generation - pub error_function: fn(&PushArgs, DataFrame, Vec) -> Series, // The error function + pub error_function: fn(&PushArgs, &DataFrame, Vec) -> Vec, // The error function pub instructions: Vec, // Instructions to use in a run pub max_init_plushy_size: usize, // max initial plushy size pub max_generations: usize, // Max amount of generations @@ -29,6 +29,8 @@ pub struct PushArgs { pub use_simplification: bool, // Whether to use simplification at end of run pub simplification_k: usize, // Max amt of genes to attempt removal during one round of simplification process pub simplification_steps: usize, // How many attempts to find simplified genomes + pub simplification_verbose: bool, // Whether to send extra messages about simplification or not + pub solution_error_threshold: Decimal, // Max total error for solutions pub use_single_thread: bool, // if true, only single threaded pub step_limit: usize, // Amount of steps a push interpreter can run for pub testing_data: DataFrame, // The testing data, must be formatted the same as training data diff --git a/src/gp/mod.rs b/src/gp/mod.rs index c4ea562..6a6616e 100644 --- a/src/gp/mod.rs +++ b/src/gp/mod.rs @@ -1,8 +1,23 @@ +use args::PushArgs; + pub mod args; pub mod genome; pub mod individual; pub mod selection; +pub mod simplification; pub mod utils; pub mod variation; -// pub fn gp_loop +pub fn gp_loop(push_args: PushArgs) -> bool { + let pop_size = push_args.pop_size; + let max_gens = push_args.max_generations; + let error_func = push_args.error_function; + let solution_error_threshold = push_args.solution_error_threshold; + let dont_end = push_args.dont_end; + let elitism = push_args.elitism; + let training_data = push_args.training_data; + let testing_data = push_args.testing_data; + let simplification = push_args.use_simplification; + + true +} diff --git a/src/gp/simplification.rs b/src/gp/simplification.rs new file mode 100644 index 0000000..5037fc6 --- /dev/null +++ b/src/gp/simplification.rs @@ -0,0 +1,76 @@ +use super::args::PushArgs; +use crate::push::state::Gene; +use polars::prelude::*; +use rand::Rng; +use rand::prelude::SliceRandom; +use rand::rng; +use rust_decimal::Decimal; +use std::collections::HashSet; + +/// Takes k random indices from the given range +fn choose_random_k(k: usize, indices_count: usize, rng: &mut impl Rng) -> Vec { + let mut indices: Vec = (0..indices_count).collect(); + indices.shuffle(rng); + indices.truncate(k); + indices +} + +/// Deletes the values at the given set of indices +fn delete_at_indices(indices: &[usize], plushy: &[T]) -> Vec { + let indices_set: HashSet = indices.iter().cloned().collect(); + plushy + .iter() + .enumerate() + .filter_map(|(i, item)| { + if !indices_set.contains(&i) { + Some(item.clone()) + } else { + None + } + }) + .collect() +} + +/// Deletes k random instructions from the plushy +fn delete_k_random(k: usize, plushy: &[T], rng: &mut impl Rng) -> Vec { + let actual_k = std::cmp::min(k, plushy.len()); + if actual_k == 0 { + return plushy.to_vec(); + } + let indices = choose_random_k(actual_k, plushy.len(), rng); + delete_at_indices(&indices, plushy) +} + +pub fn auto_simplify_plushy(plushy: Vec, error_func: F, push_args: PushArgs) -> Vec +where + F: Fn(&PushArgs, &DataFrame, Vec) -> Vec, +{ + if push_args.simplification_verbose { + println!( + "{{ start_plushy_length: {}, k: {} }}", + plushy.len(), + push_args.simplification_k + ); + } + + let initial_errors = error_func(&push_args, &push_args.training_data, plushy.clone()); + let mut step = 0; + let mut curr_plushy = plushy; + + while step < push_args.simplification_steps { + let mut rng = rng(); + let random_k = rng.random_range(1..=push_args.simplification_k); + + let new_plushy = delete_k_random(random_k, &curr_plushy, &mut rng); + let new_plushy_errors = + error_func(&push_args, &push_args.training_data, new_plushy.clone()); + + if new_plushy_errors.iter().sum::() <= initial_errors.iter().sum() { + curr_plushy = new_plushy; + } + + step += 1; + } + + curr_plushy +}