| 1 | import os |
|---|
| 2 | from typing import List, Callable, Union |
|---|
| 3 | |
|---|
| 4 | from evolalg_steps.base.step import Step |
|---|
| 5 | import pickle |
|---|
| 6 | import time |
|---|
| 7 | |
|---|
| 8 | from evolalg_steps.base.union_step import UnionStep |
|---|
| 9 | from evolalg_steps.selection.selection import Selection |
|---|
| 10 | from evolalg_steps.utils.stable_generation import StableGeneration |
|---|
| 11 | import logging |
|---|
| 12 | |
|---|
class Experiment:
    """Run an evolutionary experiment and optionally checkpoint it to disk.

    Each generation applies, in order:
      1. a ``StableGeneration`` step built from ``selection``,
         ``new_generation_steps`` and ``population_size``;
      2. the ``generation_modification`` steps, combined into a ``UnionStep``.
    After the generation loop in ``run`` finishes, the ``end_steps`` (also
    combined into a ``UnionStep``) are applied to the final population.

    Checkpoints pickle the whole ``Experiment`` object, so every step passed
    in must be picklable.
    """

    def __init__(self, init_population: List[Callable],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """Store the experiment configuration; no population is built yet.

        Args:
            init_population: callables applied in order, as a pipeline,
                to build the initial population (see ``init``).
            selection: selection passed to the ``StableGeneration`` step.
            new_generation_steps: steps passed to ``StableGeneration``.
            generation_modification: steps applied to the population after
                each ``StableGeneration`` step.
            end_steps: steps applied once, after the generation loop.
            population_size: forwarded to ``StableGeneration``.
            checkpoint_path: file to pickle the experiment into; ``None``
                disables checkpointing.
            checkpoint_interval: save a checkpoint every this many
                generations; ``None`` disables checkpointing.
        """
        self.init_population = init_population
        # Accumulated wall-clock seconds spent evolving (excludes checkpoint I/O).
        self.running_time = 0
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        # Last completed generation; ``run`` resumes from ``generation + 1``.
        self.generation = 0
        # Populated by ``init()``; ``None`` until then.
        self.population = None

    def init(self):
        """Reset the generation counter, initialise all steps, and build the
        initial population.

        Every ``Step`` in ``init_population`` has ``init()`` called on it, as
        do the generation, modification and end steps. Then each element of
        ``init_population`` is applied in order, starting from an empty list.
        """
        self.generation = 0
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()
        self.population = []
        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve the population up to and including ``num_generations``.

        The loop starts at ``self.generation + 1``, so an experiment restored
        from a checkpoint continues where it left off. ``end_steps`` are
        applied once after the loop, on every call.

        Args:
            num_generations: index of the last generation to run, counted
                from the start of the experiment (not from this call).
        """
        for i in range(self.generation + 1, num_generations + 1):
            start_time = time.time()
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            # Measured before checkpointing so disk I/O is not counted.
            self.running_time += time.time() - start_time
            if (self.checkpoint_path is not None
                    and self.checkpoint_interval is not None
                    and i % self.checkpoint_interval == 0):
                self.save_checkpoint()

        self.population = self.end_steps(self.population)

    def save_checkpoint(self):
        """Pickle this experiment to ``self.checkpoint_path``.

        The data is first written to ``<checkpoint_path>_tmp`` and only then
        moved over the real path. A failed write (e.g. disk full) therefore
        never destroys the previous checkpoint.

        Raises:
            RuntimeError: if writing or replacing the checkpoint fails; the
                original exception is chained as ``__cause__``.
        """
        tmp_filepath = self.checkpoint_path + "_tmp"
        try:
            with open(tmp_filepath, "wb") as file:
                pickle.dump(self, file)
            # Only reached once the temporary file was fully written, so the
            # old checkpoint is replaced by a complete one.
            os.replace(tmp_filepath, self.checkpoint_path)
        except Exception as ex:
            raise RuntimeError("Failed to save checkpoint '%s' (because: %s). This does not prevent the experiment from continuing, but let's stop here to fix the problem with saving checkpoints." % (tmp_filepath, ex)) from ex

    @staticmethod
    def restore(path):
        """Load and return an ``Experiment`` previously saved by
        ``save_checkpoint``.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
        file. Only restore checkpoints from trusted sources.
        """
        # Pickle data is binary; text mode makes pickle.load fail.
        with open(path, "rb") as file:
            res = pickle.load(file)
        return res
|---|