source: framspy/evolalg_steps/fitness/fitness_step.py @ 1302

Last change to this file as of revision 1302 was revision 1185, checked in by Maciej Komosinski 2 years ago.

Renamed a module; new name is "evolalg_steps"

File size: 1.7 KB
RevLine 
from typing import Dict, List, Optional

from evolalg_steps.base.frams_step import FramsStep
from evolalg_steps.base.individual import Individual
import frams


class FitnessStep(FramsStep):
    """Evaluate a population with Framsticks and copy selected results onto individuals.

    For every individual and every ``(k, v)`` pair in ``fields``, the value found at
    ``result["evaluations"][""][k]`` of that individual's evaluation result is stored
    in the individual's attribute ``v``. If that value cannot be found, the fallback
    ``fields_defaults[k]`` is stored instead.
    """

    def __init__(self, frams_lib, fields: Dict, fields_defaults: Dict, commands: Optional[List[str]] = None,
                 vectorized: bool = True, evaluation_count=None, *args, **kwargs):
        """
        :param frams_lib: Framsticks library object, forwarded to FramsStep.
        :param fields: maps keys of the evaluation result to Individual attribute names.
        :param fields_defaults: maps the same keys to fallback values, used when a key
            is missing from an individual's evaluation result.
        :param commands: forwarded to FramsStep unchanged.
        :param vectorized: if True, all genotypes are passed to a single evaluate() call;
            otherwise evaluate() is called once per genotype.
        :param evaluation_count: if not None, the value written to
            frams.ExpProperties.evalcount in pre() and reverted in post().
        """
        super().__init__(frams_lib, commands, *args, **kwargs)
        self.fields = fields
        self.fields_defaults = fields_defaults
        self.vectorized = vectorized
        self.evaluation_count = evaluation_count
        self.evaluation_count_original = None  # to be able to restore to original value after it is changed

    def pre(self):
        """Override ExpProperties.evalcount with evaluation_count, remembering the previous value."""
        if self.evaluation_count is not None:
            # store original value so that post() can restore it
            self.evaluation_count_original = frams.ExpProperties.evalcount._value()
            frams.ExpProperties.evalcount = self.evaluation_count

    def post(self):
        """Restore the ExpProperties.evalcount value saved by pre(), if any."""
        # Only restore when pre() actually saved a value; otherwise evalcount would be set to None.
        if self.evaluation_count is not None and self.evaluation_count_original is not None:
            frams.ExpProperties.evalcount = self.evaluation_count_original
            self.evaluation_count_original = None

    def call(self, population: List[Individual]):
        """Evaluate every individual's genotype and store the requested fields on it.

        :param population: individuals to evaluate; they are modified in place.
        :return: the same population list.
        """
        super().call(population)
        genotypes = [ind.genotype for ind in population]
        if self.vectorized:
            data = self.frams.evaluate(genotypes)
        else:
            data = []
            for genotype in genotypes:
                # evaluate() returns a sequence with one result per genotype (the vectorized
                # branch relies on this too), so unwrap the single result here.
                result = self.frams.evaluate([genotype])
                data.append(result[0] if result else None)  # None falls back to fields_defaults below

        for ind, d in zip(population, data):
            for k, v in self.fields.items():
                try:
                    value = d["evaluations"][""][k]
                except (KeyError, IndexError, TypeError):
                    # missing key, or a missing/None evaluation entry: use the configured fallback
                    value = self.fields_defaults[k]
                setattr(ind, v, value)
        return population
Note: See TracBrowser for help on using the repository browser.