[1113] | 1 | from typing import List, Callable, Union |
---|
| 2 | |
---|
| 3 | from evolalg.base.step import Step |
---|
| 4 | import pickle |
---|
| 5 | |
---|
| 6 | from evolalg.base.union_step import UnionStep |
---|
| 7 | from evolalg.selection.selection import Selection |
---|
| 8 | from evolalg.utils.stable_generation import StableGeneration |
---|
| 9 | |
---|
| 10 | |
---|
class Experiment:
    """Drive an evolutionary run: build a population, evolve it, finalise it.

    The population is threaded through user-supplied steps:

    * ``init_population`` callables build the initial population (``init``);
    * a ``StableGeneration`` built from ``selection``/``new_generation_steps``
      produces every new generation (``run``);
    * ``generation_modification`` (wrapped in a ``UnionStep``) is applied to
      each new generation;
    * ``end_steps`` (wrapped in a ``UnionStep``) is applied once after the
      last generation.

    If ``checkpoint_path`` and a non-zero ``checkpoint_interval`` are given,
    the whole experiment is pickled to ``checkpoint_path`` every
    ``checkpoint_interval`` generations. A run can later be resumed with
    ``Experiment.restore(path).run(num_generations)``.
    """

    def __init__(self, init_population: List[Callable],
                 selection: "Selection",
                 new_generation_steps: List[Union[Callable, "Step"]],
                 generation_modification: List[Union[Callable, "Step"]],
                 end_steps: List[Union[Callable, "Step"]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """Store the pipeline configuration; no step is initialised or run yet.

        Args:
            init_population: callables applied in order by ``init``; each
                receives the current population (``None`` on a freshly
                constructed experiment) and returns the new population.
            selection: selection operator forwarded to ``StableGeneration``.
            new_generation_steps: forwarded to ``StableGeneration`` as its
                ``steps``.
            generation_modification: steps combined into a ``UnionStep`` and
                applied to every new generation.
            end_steps: steps combined into a ``UnionStep`` and applied once
                after the final generation.
            population_size: forwarded to ``StableGeneration``.
            checkpoint_path: file the experiment is pickled to; ``None``
                disables checkpointing.
            checkpoint_interval: checkpoint every this many generations;
                ``None`` or ``0`` disables checkpointing.
        """
        self.init_population = init_population
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        self.generation = 0
        self.population = None

    def init(self):
        """Reset the generation counter, initialise every step, build the first population.

        All ``Step`` instances are initialised before any of them is called,
        then the ``init_population`` callables are applied in order.
        """
        self.generation = 0
        for s in self.init_population:
            # Plain callables have no init hook; only Step objects do.
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()

        # NOTE(review): self.population is not reset to None here, so calling
        # init() again feeds the previous population into the first init
        # step — confirm this is intended.
        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve the population up to generation ``num_generations``, then apply end steps.

        Generations are numbered from 1 and continue from ``self.generation``,
        so an experiment restored from a checkpoint resumes after the last
        saved generation. The end steps run even when no generation remains.

        Args:
            num_generations: number of the last generation to produce
                (inclusive).
        """
        for i in range(self.generation + 1, num_generations + 1):
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            if self._should_checkpoint(i):
                self._save_checkpoint()

        self.population = self.end_steps(self.population)

    def _should_checkpoint(self, generation):
        """Return True if a checkpoint is due after ``generation``."""
        # A falsy interval (None or 0) disables checkpointing; 0 would
        # otherwise raise ZeroDivisionError in the modulo below.
        if self.checkpoint_path is None or not self.checkpoint_interval:
            return False
        return generation % self.checkpoint_interval == 0

    def _save_checkpoint(self):
        """Pickle the experiment to ``checkpoint_path`` without clobbering the old file on failure."""
        import os

        # Write to a sibling temp file and atomically swap it in, so a failing
        # pickle.dump (e.g. an unpicklable step) leaves the previous
        # checkpoint intact instead of truncated.
        tmp_path = os.fspath(self.checkpoint_path) + ".tmp"
        try:
            with open(tmp_path, "wb") as file:
                pickle.dump(self, file)
            os.replace(tmp_path, self.checkpoint_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def restore(path):
        """Load an experiment previously written by a checkpoint.

        Warning: unpickling can execute arbitrary code; only restore
        checkpoint files from a trusted source.

        Args:
            path: checkpoint file to read.

        Returns:
            The unpickled ``Experiment``; call ``run`` on it to resume.
        """
        # Pickle data is binary; a text-mode handle makes pickle.load fail.
        with open(path, "rb") as file:
            return pickle.load(file)