[1127] | 1 | import os |
---|
[1113] | 2 | from typing import List, Callable, Union |
---|
| 3 | |
---|
[1185] | 4 | from evolalg_steps.base.step import Step |
---|
[1113] | 5 | import pickle |
---|
[1127] | 6 | import time |
---|
[1113] | 7 | |
---|
[1185] | 8 | from evolalg_steps.base.union_step import UnionStep |
---|
| 9 | from evolalg_steps.selection.selection import Selection |
---|
| 10 | from evolalg_steps.utils.stable_generation import StableGeneration |
---|
[1139] | 11 | import logging |
---|
[1113] | 12 | |
---|
class Experiment:
    """Generational evolutionary experiment assembled from composable steps.

    The experiment builds an initial population, then repeatedly produces a
    new generation (via ``StableGeneration``) and post-processes it (via the
    ``generation_modification`` steps). It optionally pickles itself to a
    checkpoint file every ``checkpoint_interval`` generations. After the last
    generation the ``end_steps`` are applied once.
    """

    def __init__(self, init_population: List[Callable],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """
        :param init_population: callables applied in order. Each one receives
            the population built so far (starting from an empty list) and
            returns the updated population. Entries that are ``Step``
            instances also get their ``init()`` called by :meth:`init`.
        :param selection: selection object passed to ``StableGeneration``.
        :param new_generation_steps: steps passed to ``StableGeneration``.
            Presumably these are applied to selected individuals to create
            the next generation (TODO confirm against StableGeneration).
        :param generation_modification: steps combined into a ``UnionStep``
            and applied to the population after every generation.
        :param end_steps: steps combined into a ``UnionStep`` and applied once
            after the final generation in :meth:`run`.
        :param population_size: passed to ``StableGeneration``.
        :param checkpoint_path: file path the experiment is pickled to. No
            checkpoints are written when ``None``.
        :param checkpoint_interval: save a checkpoint every this many
            generations. No checkpoints are written when ``None``. Expected
            to be a positive int (0 would make the modulo in :meth:`run`
            raise ZeroDivisionError).
        """
        self.init_population = init_population
        # Cumulative seconds spent producing generations (checkpoint saving
        # excluded). It is pickled together with the experiment, so it keeps
        # accumulating across restored runs.
        self.running_time = 0
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        # Number of the last completed generation; run() resumes after it.
        self.generation = 0
        self.population = None

    def init(self):
        """Reset the experiment and build the initial population.

        Calls ``init()`` on every ``Step`` among the init-population
        callables and on the generation, modification and end steps. It then
        builds the population by threading an empty list through
        ``init_population`` in order.
        """
        self.generation = 0
        # A fresh start also restarts the time accounting. Otherwise a
        # re-initialized experiment would report time from earlier runs.
        self.running_time = 0
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()

        self.population = []
        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve up to generation ``num_generations``, then apply end steps.

        ``num_generations`` is the absolute number of the last generation,
        not a count of additional generations. The loop continues from
        ``self.generation + 1``, so an experiment restored from a checkpoint
        resumes where it stopped. If ``num_generations <= self.generation``,
        no generation is produced, but the end steps are still applied.

        :param num_generations: number of the last generation to produce.
        :raises RuntimeError: if saving a checkpoint fails.
        """
        for i in range(self.generation + 1, num_generations + 1):
            # Monotonic clock: durations are unaffected by system clock jumps.
            start_time = time.perf_counter()
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            # Measured before checkpointing so that I/O time is not counted.
            self.running_time += time.perf_counter() - start_time
            if (self.checkpoint_path is not None
                    and self.checkpoint_interval is not None
                    and i % self.checkpoint_interval == 0):
                self.save_checkpoint()

        self.population = self.end_steps(self.population)

    def save_checkpoint(self):
        """Pickle the whole experiment to ``self.checkpoint_path``.

        The data is first written to ``<checkpoint_path>_tmp`` and only then
        moved over the real checkpoint with ``os.replace``. An existing
        checkpoint is therefore overwritten only after the new one was
        written completely (e.g. there was enough free space on the device).

        :raises RuntimeError: if writing or replacing the checkpoint fails.
            The original exception is chained as ``__cause__``.
        """
        tmp_filepath = self.checkpoint_path + "_tmp"
        try:
            with open(tmp_filepath, "wb") as file:
                pickle.dump(self, file)
            os.replace(tmp_filepath, self.checkpoint_path)
        except Exception as ex:
            # Best-effort removal of a partially written temporary file. The
            # file may not exist (e.g. open() itself failed).
            try:
                os.remove(tmp_filepath)
            except OSError:
                pass
            raise RuntimeError("Failed to save checkpoint '%s' (because: %s). This does not prevent the experiment from continuing, but let's stop here to fix the problem with saving checkpoints." % (tmp_filepath, ex)) from ex

    @staticmethod
    def restore(path):
        """Load an experiment previously written by :meth:`save_checkpoint`.

        The file is opened in binary mode, as ``pickle.load`` requires.
        Security: unpickling can execute arbitrary code, so only restore
        checkpoints from trusted sources.

        :param path: path of the checkpoint file.
        :return: the unpickled object (an ``Experiment`` when the file was
            written by :meth:`save_checkpoint`).
        """
        with open(path, "rb") as file:
            res = pickle.load(file)
        return res
---|