import os
import pickle
from typing import Callable, List, Union

from evolalg.base.step import Step
from evolalg.base.union_step import UnionStep
from evolalg.selection.selection import Selection
from evolalg.utils.stable_generation import StableGeneration


class Experiment:
    """Drive an evolutionary run with optional pickle checkpointing.

    ``init()`` builds the initial population. ``run(n)`` then evolves it
    generation by generation. Each generation applies a ``StableGeneration``
    (selection plus ``new_generation_steps``) followed by
    ``generation_modification``. After the last generation, ``end_steps`` is
    applied once.
    """

    def __init__(self, init_population: List[Union[Callable, Step]],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """
        Args:
            init_population: callables chained by ``init()``. Each receives
                the current population and returns the next one. Entries
                that are ``Step`` instances also get their ``init()`` called.
            selection: selection forwarded to ``StableGeneration``.
            new_generation_steps: steps forwarded to ``StableGeneration``.
            generation_modification: steps wrapped in a ``UnionStep`` and
                applied after every generation.
            end_steps: steps wrapped in a ``UnionStep`` and applied once at
                the end of ``run()``.
            population_size: forwarded to ``StableGeneration``.
            checkpoint_path: file the experiment pickles itself to. When
                ``None``, checkpointing is disabled.
            checkpoint_interval: save a checkpoint every this many
                generations. When ``None``, checkpointing is disabled. Must be
                positive when set, since it is used as a modulus.
        """
        self.init_population = init_population
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        self.generation = 0
        self.population = None

    def init(self):
        """Reset the generation counter, initialize all steps, and build the
        initial population.

        The ``init_population`` callables are chained starting from the
        current ``self.population``. That value is ``None`` after
        construction unless the caller has set it beforehand.
        """
        self.generation = 0
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()

        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve the population through generation ``num_generations``.

        Iteration starts at ``self.generation + 1``. An experiment obtained
        from ``restore()`` therefore continues after the generation it was
        saved at. ``end_steps`` is applied once after the loop, even when no
        generations remained to run.

        Args:
            num_generations: index of the last generation to run (inclusive).
        """
        for i in range(self.generation + 1, num_generations + 1):
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            if (self.checkpoint_path is not None
                    and self.checkpoint_interval is not None
                    and i % self.checkpoint_interval == 0):
                self._save_checkpoint()

        self.population = self.end_steps(self.population)

    def _save_checkpoint(self):
        """Pickle this experiment to ``checkpoint_path`` without risking the
        previous checkpoint."""
        # Write to a sibling temp file and atomically swap it in. A crash or
        # pickling error mid-dump then leaves the last good checkpoint intact.
        tmp_path = f"{self.checkpoint_path}.tmp"
        try:
            with open(tmp_path, "wb") as file:
                pickle.dump(self, file)
            os.replace(tmp_path, self.checkpoint_path)
        except BaseException:
            # Do not leave a half-written temp file behind; re-raise the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def restore(path):
        """Load an experiment previously pickled by ``run()``'s checkpointing.

        Warning: unpickling can execute arbitrary code. Only restore
        checkpoint files from trusted sources.

        Args:
            path: path to the checkpoint file.

        Returns:
            The unpickled experiment. Call ``run()`` on it to continue.
        """
        # pickle needs a binary stream. Opening in text mode made
        # pickle.load raise TypeError.
        with open(path, "rb") as file:
            return pickle.load(file)