1 | import os |
---|
2 | from typing import List, Callable, Union |
---|
3 | |
---|
4 | from evolalg_steps.base.step import Step |
---|
5 | import pickle |
---|
6 | import time |
---|
7 | |
---|
8 | from evolalg_steps.base.union_step import UnionStep |
---|
9 | from evolalg_steps.selection.selection import Selection |
---|
10 | from evolalg_steps.utils.stable_generation import StableGeneration |
---|
11 | import logging |
---|
12 | |
---|
class Experiment:
    """Drive an evolutionary run with optional periodic pickle checkpoints.

    ``init()`` builds the initial population from ``init_population``. Each
    generation of ``run()`` applies a ``StableGeneration`` step followed by
    ``generation_modification``. ``end_steps`` are applied once after the
    generation loop. Checkpoints pickle the whole ``Experiment`` object, so a
    restored experiment can continue with another call to ``run()``.
    """

    def __init__(self, init_population: List[Callable],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """Configure the experiment; no population is built until ``init()``.

        Args:
            init_population: callables applied in order to build the initial
                population, starting from an empty list. Each one receives the
                current population and returns the new one.
            selection: passed to ``StableGeneration``.
            new_generation_steps: steps passed to ``StableGeneration``.
            generation_modification: steps combined with ``UnionStep`` and
                applied to the population after every generation.
            end_steps: steps combined with ``UnionStep`` and applied once at
                the end of ``run()``.
            population_size: passed to ``StableGeneration`` (presumably the
                size of each generation -- confirm in ``StableGeneration``).
            checkpoint_path: file to pickle the experiment to (``str`` or
                ``os.PathLike``), or ``None`` to disable checkpoints.
            checkpoint_interval: save a checkpoint every this many
                generations, or ``None`` to disable checkpoints.
        """
        self.init_population = init_population
        # Wall-clock seconds spent in run()'s generation loop, summed across
        # calls; time spent writing checkpoints is not counted.
        self.running_time = 0
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        self.generation = 0
        self.population = None

    def init(self):
        """Reset the generation counter, initialise all steps, and build the population."""
        self.generation = 0
        # Only Step instances have an init(); plain callables are used as-is.
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()
        # Thread an empty population through every initialiser in order.
        self.population = []
        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve the population up to generation ``num_generations`` (inclusive).

        The loop starts at ``self.generation + 1``, so an experiment restored
        from a checkpoint resumes after the last completed generation.
        ``end_steps`` are applied once after the loop, even when no
        generation was run.
        """
        for i in range(self.generation + 1, num_generations + 1):
            start_time = time.time()
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            # Measured before checkpointing, so save time is not counted.
            self.running_time += time.time() - start_time
            if (self.checkpoint_path is not None
                    and self.checkpoint_interval is not None
                    and i % self.checkpoint_interval == 0):
                self.save_checkpoint()

        self.population = self.end_steps(self.population)

    def save_checkpoint(self):
        """Pickle this experiment to ``checkpoint_path``.

        The object is first written to ``<checkpoint_path>_tmp``. Only after
        that write succeeds is it moved over the real checkpoint with
        ``os.replace``, so a failed write (e.g. not enough free space on the
        device) leaves the previous checkpoint intact.

        Raises:
            RuntimeError: if writing or replacing the checkpoint fails. The
                underlying exception is chained as ``__cause__``.
        """
        # os.fspath accepts both str and os.PathLike (e.g. pathlib.Path).
        checkpoint_path = os.fspath(self.checkpoint_path)
        tmp_filepath = checkpoint_path + "_tmp"
        try:
            with open(tmp_filepath, "wb") as file:
                pickle.dump(self, file)
            # Replace only after the new file was fully written.
            os.replace(tmp_filepath, checkpoint_path)
        except Exception as ex:
            raise RuntimeError("Failed to save checkpoint '%s' (because: %s). This does not prevent the experiment from continuing, but let's stop here to fix the problem with saving checkpoints." % (tmp_filepath, ex)) from ex

    @staticmethod
    def restore(path):
        """Load and return an experiment written by ``save_checkpoint``.

        Security: unpickling can execute arbitrary code, so only restore
        checkpoint files from trusted sources.
        """
        # Pickle data is binary; the file must be opened in "rb" mode,
        # otherwise pickle.load fails on a text stream.
        with open(path, "rb") as file:
            return pickle.load(file)