"""Matplotlib renderers for locomotion/action analysis results."""
import json
import logging
from typing import List, Sequence

import matplotlib.pyplot as plt
import numpy as np

from . import Render
from .. import Result, LocomotionActionAnalyzer

log = logging.getLogger(__name__)


def default(item):
    """Identity sort key: rows are compared as they are."""
    return item


def sort(item):
    """Sort key that orders rows by their first element."""
    return item[0]


def sort_sum(item):
    """Sort key that orders rows by the sum of their elements."""
    return sum(item)


def plot(results: List[List[int]], ylabel: str, title: str,
         legend: Sequence[str] = ("Locomotion", "Action")):
    """Show a stacked bar chart with one bar per row of *results*.

    Args:
        results: Rows whose first element is drawn as the lower (red) bar
            segment and whose second element is stacked on top (green).
            Any further elements are ignored.
        ylabel: Label of the y axis.
        title: Title of the figure.
        legend: Legend labels for the lower and upper segment, in that order.

    With no rows there is nothing to draw; a warning is logged and no
    figure is shown.
    """
    if not results:
        log.warning("no results to plot for %r", title)
        return

    size = len(results)
    columns = list(zip(*results))
    lower, upper = columns[0], columns[1]
    ind = np.arange(size)
    width = 0.85

    loc = plt.bar(ind, lower, width=width, color="red")
    act = plt.bar(ind, upper, width=width, bottom=lower, color="green")

    plt.ylabel(ylabel)
    plt.title(title)
    plt.xlabel("sessions")
    # One tick per bar, but with empty tick labels.
    plt.xticks(ind, [""] * size)
    plt.legend((loc[0], act[0]), legend)
    plt.show()


def plot_line(results: List[List[int]], ylabel="Ratio",
              title="Locomotion/Action "):
    """Show a line chart of the first element of every row of *results*.

    Args:
        results: Rows whose first element is the value to plot.
        ylabel: Label of the y axis.
        title: Title of the figure.

    With no rows there is nothing to draw; a warning is logged and no
    figure is shown.
    """
    if not results:
        log.warning("no results to plot for %r", title)
        return

    size = len(results)
    values = list(zip(*results))[0]
    ind = np.arange(size)

    plt.plot(ind, values, label="ratio", marker=".")
    plt.ylabel(ylabel)
    plt.title(title)
    # One tick per point, but with empty tick labels.
    plt.xticks(ind, [""] * size)
    plt.show()


def filter_results(raw_results: List[Result], keys,
                   sort=default) -> List[List[int]]:
    """Extract the values stored under *keys* from every result, sorted.

    Args:
        raw_results: Objects whose ``get()`` returns a mapping that contains
            every entry of *keys*.
        keys: Keys to look up, in the order they should appear in each row.
        sort: Key function handed to ``sorted`` to order the rows.

    Returns:
        One row per result, each holding the values for *keys*, ordered by
        *sort*.
    """
    rows = []
    for result in raw_results:
        # Call get() once per result and reuse the mapping for every key.
        raw = result.get()
        rows.append([raw[key] for key in keys])
    return sorted(rows, key=sort)


class LocomotionActionRender(Render):
    """Common base of the renderers for LocomotionActionAnalyzer results."""

    # NOTE(review): presumably consumed by Render.filter to select the
    # matching results -- Render is defined outside this module; confirm.
    result_types = [LocomotionActionAnalyzer]


class LocomotionActionAbsoluteRender(LocomotionActionRender):
    """Stacked bars of the 'locomotion_sum' and 'action_sum' values."""

    def render(self, results: List[Result], name=None):
        """Plot the absolute values; *name* is accepted but not used."""
        results = filter_results(self.filter(results),
                                 ['locomotion_sum', 'action_sum'])
        plot(results, "time", "abs loc/action")


class LocomotionActionRelativeRender(LocomotionActionRender):
    """Stacked bars of 'locomotion_relative' and 'action_relative'."""

    def render(self, results: List[Result], name=None):
        """Plot the relative values; *name* is accepted but not used."""
        results = filter_results(self.filter(results),
                                 ['locomotion_relative', 'action_relative'])
        plot(results, "fraction of time", "rel loc/action")


class LocomotionActionRatioRender(LocomotionActionRender):
    """Line chart of the 'locomotion_action_ratio' values."""

    def render(self, results: List[Result], name=None):
        """Plot the ratio per result; *name* is accepted but not used."""
        results = filter_results(self.filter(results),
                                 ['locomotion_action_ratio'])
        plot_line(results, ylabel="Ratio", title="Locomotion/Action Ratio")


class LocomotionActionRatioHistRender(LocomotionActionRender):
    """Histogram of the 'locomotion_action_ratio' values."""

    def render(self, results: List[Result], name=None):
        """Plot the ratio histogram with one bin per result.

        *name* is accepted for parity with the sibling renderers and is
        not used.
        """
        results = filter_results(self.filter(results),
                                 ['locomotion_action_ratio'])
        if not results:
            log.warning("no results to plot for %r", "locomotion/action")
            return

        # filter_results yields one-element rows; hist expects a flat
        # sequence of values for a single dataset.
        ratios = [row[0] for row in results]

        plt.title("locomotion/action")
        plt.xlabel("ratio")
        plt.ylabel("frequency")
        plt.hist(ratios, bins=len(ratios))
        plt.show()