simplification, need to test
This commit is contained in:
parent
7e7d5523e9
commit
35ad71ffc3
@ -19,7 +19,7 @@ pub struct PushArgs {
|
|||||||
pub dont_end: bool, // If true, keep running until limit regardless of success
|
pub dont_end: bool, // If true, keep running until limit regardless of success
|
||||||
// pub downsample: bool, // Whether or not to downsample. TODO later with all the related args
|
// pub downsample: bool, // Whether or not to downsample. TODO later with all the related args
|
||||||
pub elitism: bool, // Whether to always add the best individual to next generation
|
pub elitism: bool, // Whether to always add the best individual to next generation
|
||||||
pub error_function: fn(&PushArgs, DataFrame, Vec<Gene>) -> Series, // The error function
|
pub error_function: fn(&PushArgs, &DataFrame, Vec<Gene>) -> Vec<Decimal>, // The error function
|
||||||
pub instructions: Vec<Gene>, // Instructions to use in a run
|
pub instructions: Vec<Gene>, // Instructions to use in a run
|
||||||
pub max_init_plushy_size: usize, // max initial plushy size
|
pub max_init_plushy_size: usize, // max initial plushy size
|
||||||
pub max_generations: usize, // Max amount of generations
|
pub max_generations: usize, // Max amount of generations
|
||||||
@ -29,6 +29,8 @@ pub struct PushArgs {
|
|||||||
pub use_simplification: bool, // Whether to use simplification at end of run
|
pub use_simplification: bool, // Whether to use simplification at end of run
|
||||||
pub simplification_k: usize, // Max amt of genes to attempt removal during one round of simplification process
|
pub simplification_k: usize, // Max amt of genes to attempt removal during one round of simplification process
|
||||||
pub simplification_steps: usize, // How many attempts to find simplified genomes
|
pub simplification_steps: usize, // How many attempts to find simplified genomes
|
||||||
|
pub simplification_verbose: bool, // Whether to send extra messages about simplification or not
|
||||||
|
pub solution_error_threshold: Decimal, // Max total error for solutions
|
||||||
pub use_single_thread: bool, // if true, only single threaded
|
pub use_single_thread: bool, // if true, only single threaded
|
||||||
pub step_limit: usize, // Amount of steps a push interpreter can run for
|
pub step_limit: usize, // Amount of steps a push interpreter can run for
|
||||||
pub testing_data: DataFrame, // The testing data, must be formatted the same as training data
|
pub testing_data: DataFrame, // The testing data, must be formatted the same as training data
|
||||||
|
@ -1,8 +1,23 @@
|
|||||||
|
use args::PushArgs;
|
||||||
|
|
||||||
pub mod args;
|
pub mod args;
|
||||||
pub mod genome;
|
pub mod genome;
|
||||||
pub mod individual;
|
pub mod individual;
|
||||||
pub mod selection;
|
pub mod selection;
|
||||||
|
pub mod simplification;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub mod variation;
|
pub mod variation;
|
||||||
|
|
||||||
// pub fn gp_loop
|
pub fn gp_loop(push_args: PushArgs) -> bool {
|
||||||
|
let pop_size = push_args.pop_size;
|
||||||
|
let max_gens = push_args.max_generations;
|
||||||
|
let error_func = push_args.error_function;
|
||||||
|
let solution_error_threshold = push_args.solution_error_threshold;
|
||||||
|
let dont_end = push_args.dont_end;
|
||||||
|
let elitism = push_args.elitism;
|
||||||
|
let training_data = push_args.training_data;
|
||||||
|
let testing_data = push_args.testing_data;
|
||||||
|
let simplification = push_args.use_simplification;
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
76
src/gp/simplification.rs
Normal file
76
src/gp/simplification.rs
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
use super::args::PushArgs;
|
||||||
|
use crate::push::state::Gene;
|
||||||
|
use polars::prelude::*;
|
||||||
|
use rand::Rng;
|
||||||
|
use rand::prelude::SliceRandom;
|
||||||
|
use rand::rng;
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
/// Takes k random indices from the given range
|
||||||
|
fn choose_random_k(k: usize, indices_count: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
|
let mut indices: Vec<usize> = (0..indices_count).collect();
|
||||||
|
indices.shuffle(rng);
|
||||||
|
indices.truncate(k);
|
||||||
|
indices
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a copy of `plushy` without the elements located at `indices`.
///
/// Duplicate and out-of-range entries in `indices` are harmless: an element is
/// dropped exactly when its position appears in `indices`. The relative order
/// of the surviving elements is preserved.
fn delete_at_indices<T: Clone>(indices: &[usize], plushy: &[T]) -> Vec<T> {
    // Set lookup keeps the membership test cheap for every position.
    let doomed: HashSet<usize> = indices.iter().copied().collect();

    plushy
        .iter()
        .enumerate()
        .filter(|(position, _)| !doomed.contains(position))
        .map(|(_, item)| item.clone())
        .collect()
}
|
||||||
|
|
||||||
|
/// Deletes k random instructions from the plushy
|
||||||
|
fn delete_k_random<T: Clone>(k: usize, plushy: &[T], rng: &mut impl Rng) -> Vec<T> {
|
||||||
|
let actual_k = std::cmp::min(k, plushy.len());
|
||||||
|
if actual_k == 0 {
|
||||||
|
return plushy.to_vec();
|
||||||
|
}
|
||||||
|
let indices = choose_random_k(actual_k, plushy.len(), rng);
|
||||||
|
delete_at_indices(&indices, plushy)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn auto_simplify_plushy<F>(plushy: Vec<Gene>, error_func: F, push_args: PushArgs) -> Vec<Gene>
|
||||||
|
where
|
||||||
|
F: Fn(&PushArgs, &DataFrame, Vec<Gene>) -> Vec<Decimal>,
|
||||||
|
{
|
||||||
|
if push_args.simplification_verbose {
|
||||||
|
println!(
|
||||||
|
"{{ start_plushy_length: {}, k: {} }}",
|
||||||
|
plushy.len(),
|
||||||
|
push_args.simplification_k
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let initial_errors = error_func(&push_args, &push_args.training_data, plushy.clone());
|
||||||
|
let mut step = 0;
|
||||||
|
let mut curr_plushy = plushy;
|
||||||
|
|
||||||
|
while step < push_args.simplification_steps {
|
||||||
|
let mut rng = rng();
|
||||||
|
let random_k = rng.random_range(1..=push_args.simplification_k);
|
||||||
|
|
||||||
|
let new_plushy = delete_k_random(random_k, &curr_plushy, &mut rng);
|
||||||
|
let new_plushy_errors =
|
||||||
|
error_func(&push_args, &push_args.training_data, new_plushy.clone());
|
||||||
|
|
||||||
|
if new_plushy_errors.iter().sum::<Decimal>() <= initial_errors.iter().sum() {
|
||||||
|
curr_plushy = new_plushy;
|
||||||
|
}
|
||||||
|
|
||||||
|
step += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
curr_plushy
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user