source: framspy/evolalg_steps/utils/population_save.py @ 1333

Last change on this file since 1333 was 1185, checked in by Maciej Komosinski, 2 years ago

Renamed a module; new name is "evolalg_steps"

File size: 1.7 KB
RevLine 
[1185]1from evolalg_steps.base.step import Step
[1113]2import os
3from framsfiles import writer as framswriter
4
5
class PopulationSave(Step):
    """Step that saves individuals to a file in Framsticks file format.

    Each saved individual becomes one record written through
    ``framsfiles.writer.from_collection``. The record has ``_classname``
    set to ``"org"`` plus the keys described by ``fields`` (see
    ``__init__``). The population passed to :meth:`call` is returned
    unchanged, so the step only has a side effect (the file) and does not
    alter the pipeline's data.
    """

    def __init__(self, path, fields, provider=None, *args, **kwargs):
        """Configure the output file and the fields to write.

        :param path: path of the output file. It is truncated and rewritten
            on every call, and its parent directory is created if missing.
        :param fields: mapping ``{output_key: attribute_name}``. For each
            individual, ``getattr(individual, attribute_name)`` is written
            under ``output_key``.
        :param provider: optional iterable of individuals to save instead of
            the population passed to :meth:`call`. ``None`` means "save the
            population itself". NOTE(review): if this is a one-shot iterator
            (e.g. a generator), calls after the first will write an empty file.
        :param args: forwarded to :class:`Step`.
        :param kwargs: forwarded to :class:`Step`.
        """
        super().__init__(*args, **kwargs)
        self.path = path
        self.provider = provider
        self.fields = fields

    @staticmethod
    def ensure_dir(path):
        """Create the parent directory of ``path`` if it does not exist yet.

        Does nothing when ``path`` has no directory component, i.e. a bare
        file name relative to the current working directory.
        """
        directory = os.path.dirname(path)
        if not directory:
            return
        # exist_ok=True makes this safe when another process creates the same
        # directory concurrently; a separate exists() check could go stale.
        os.makedirs(directory, exist_ok=True)

    def call(self, population):
        """Write every individual from the provider (or ``population``) to ``self.path``.

        :param population: population passed along the pipeline. It is what
            gets saved when no ``provider`` was configured.
        :return: ``population``, unchanged.
        """
        super().call(population)
        PopulationSave.ensure_dir(self.path)
        individuals = population if self.provider is None else self.provider

        # TODO instead of "fitness", write all fields used as fitness source with their
        # original names (e.g. "velocity", "vertpos" etc.). In evaluation, set all
        # attributes we get from Framsticks so that here we get all original names and
        # values. Or, even better, introduce a dict field in Individual and assign to it
        # everything that we get from Framsticks on evaluation (and add a filtering
        # ability, i.e. when None - save all, when a list of field names - save only
        # the enumerated fields)
        saved = 0
        with open(self.path, "w") as outfile:
            for ind in individuals:
                # Construct a dictionary with criteria names and their values.
                keyval = {key: getattr(ind, self.fields[key]) for key in self.fields}
                outfile.write(framswriter.from_collection({"_classname": "org", **keyval}))
                outfile.write("\n")
                saved += 1

        # Count the records actually written rather than calling len(): this
        # also works when the provider is an iterable without __len__.
        print("Saved '%s' (%d)" % (self.path, saved))
        return population
Note: See TracBrowser for help on using the repository browser.