from typing import List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np

from . import Render
from .. import Result, LocomotionActionAnalyzer


def plot(results: List[List[float]], ylabel: str, title: str,
         legend: Tuple[str, str] = ("Locomotion", "Action")):
    """Show a stacked bar chart with one bar per session.

    :param results: one row per session. Column 0 is drawn as the lower (red)
        segment and column 1 is stacked on top of it (green). Further columns
        are ignored.
    :param ylabel: label for the y axis.
    :param title: chart title.
    :param legend: legend entries for the lower and upper segments.

    Does nothing when ``results`` is empty.
    """
    if not results:
        # zip(*[]) yields no columns, so indexing columns[0] below would
        # raise IndexError; with no sessions there is nothing to draw.
        return
    size = len(results)
    columns = list(zip(*results))  # transpose: rows per session -> one tuple per key
    ind = np.arange(size)
    width = 0.85
    # Start a fresh figure so consecutive renders never draw onto the same axes.
    plt.figure()
    loc = plt.bar(ind, columns[0], width=width, color="red")
    act = plt.bar(ind, columns[1], width=width, bottom=columns[0], color="green")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xlabel("sessions")
    # Session identifiers are not available here, so tick labels stay blank.
    plt.xticks(ind, [""] * size)
    plt.legend((loc[0], act[0]), legend)
    plt.show()


def plot_line(results: List[List[float]], ylabel="Ratio", title="Locomotion/Action "):
    """Show a line chart of the first column of ``results``, one point per session.

    :param results: one row per session; only column 0 is plotted.
    :param ylabel: label for the y axis.
    :param title: chart title.

    Does nothing when ``results`` is empty.
    """
    if not results:
        # Nothing to draw, and indexing an empty transpose would raise.
        return
    size = len(results)
    ind = np.arange(size)
    values = [row[0] for row in results]
    # Start a fresh figure so consecutive renders never draw onto the same axes.
    plt.figure()
    plt.plot(ind, values, label="ratio", marker=".")
    plt.ylabel(ylabel)
    plt.title(title)
    # Matches the x-axis label used by plot().
    plt.xlabel("sessions")
    plt.xticks(ind, [""] * size)
    plt.show()


def filter_results(raw_results: Sequence[Result], keys: Sequence[str]) -> List[List[float]]:
    """Project each result onto ``keys``, producing one row per result.

    :param raw_results: results whose ``get()`` returns a mapping indexable by
        each key (presumably a dict produced by the analyzer -- confirm).
    :param keys: keys to extract, in column order.
    :return: ``[[raw[k] for k in keys] for each result]``.
    :raises KeyError: if a result lacks one of ``keys`` (assuming ``get()``
        returns a dict).
    """
    # get() is called exactly once per result, as before.
    raws = (result.get() for result in raw_results)
    return [[raw[k] for k in keys] for raw in raws]


class LocomotionActionRender(Render):
    """Common base for renders of LocomotionActionAnalyzer results."""

    # Analyzer types this render handles; presumably used by Render.filter()
    # to select matching results -- confirm against Render.
    result_types = [LocomotionActionAnalyzer]


class LocomotionActionAbsoluteRender(LocomotionActionRender):
    """Stacked bars of absolute locomotion and action time per session."""

    def render(self, results: List[Result], name=None):
        """Plot ``locomotion_sum`` stacked under ``action_sum`` for each result.

        :param results: analysis results; narrowed with ``self.filter``.
        :param name: accepted for interface compatibility; not used here.
        """
        rows = filter_results(self.filter(results), ['locomotion_sum', 'action_sum'])
        plot(rows, "time", "abs loc/action")


class LocomotionActionRelativeRender(LocomotionActionRender):
    """Stacked bars of relative locomotion and action time per session."""

    def render(self, results: List[Result], name=None):
        """Plot ``locomotion_relative`` stacked under ``action_relative``.

        :param results: analysis results; narrowed with ``self.filter``.
        :param name: accepted for interface compatibility; not used here.
        """
        rows = filter_results(self.filter(results), ['locomotion_relative', 'action_relative'])
        plot(rows, "fraction of time", "rel loc/action")


class LocomotionActionRatioRender(LocomotionActionRender):
    """Line chart of the locomotion/action ratio per session."""

    def render(self, results: List[Result], name=None):
        """Plot ``locomotion_action_ratio`` for each result as a line.

        :param results: analysis results; narrowed with ``self.filter``.
        :param name: accepted for interface compatibility; not used here.
        """
        rows = filter_results(self.filter(results), ['locomotion_action_ratio'])
        plot_line(rows, ylabel="Ratio", title="Locomotion/Action Ratio")