project/analysis/analyzers/render/locomotion.py

88 lines
2.7 KiB
Python

import json
import logging
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from . import Render
from .. import Result, LocomotionActionAnalyzer
log = logging.getLogger(__name__)
def default(item):
    """Identity sort key: hand back *item* exactly as received."""
    unchanged = item
    return unchanged
def sort(item):
    """Sort key selecting the first element (index 0) of *item*."""
    first = item[0]
    return first
def sort_sum(item):
    """Sort key returning the arithmetic sum of the elements of *item*."""
    total = sum(item)
    return total
def plot(results: List[List[float]], ylabel: str, title: str, legend: (str,) = ("Locomotion", "Action")):
    """Draw and show a stacked bar chart with one bar per row of *results*.

    Args:
        results: one row per session; column 0 is drawn as the bottom (red)
            segment and column 1 is stacked on top of it (green). Extra
            columns are ignored.
        ylabel: y-axis label.
        title: chart title.
        legend: legend labels for the bottom and top segments, in that order.

    With no rows, a warning is logged and nothing is drawn.
    """
    if not results:
        # zip(*[]) yields nothing, so indexing data[0] below would raise IndexError.
        log.warning("no results to plot for %r", title)
        return
    size = len(results)
    # Transpose rows into per-column sequences.
    data = list(zip(*results))
    ind = np.arange(size)
    width = 0.85
    loc = plt.bar(ind, data[0], width=width, color="red")
    act = plt.bar(ind, data[1], width=width, bottom=data[0], color="green")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xlabel("sessions")
    # Blank tick labels: one tick per session, but no per-session names shown.
    plt.xticks(ind, [""] * size)
    plt.legend((loc[0], act[0]), legend)
    plt.show()
def plot_line(results: List[List[float]], ylabel="Ratio", title="Locomotion/Action "):
    """Draw and show column 0 of *results* as a line, one point per row.

    Args:
        results: one row per session; only the first column is plotted.
        ylabel: y-axis label.
        title: chart title.

    With no rows, a warning is logged and nothing is drawn.
    """
    if not results:
        # Without this guard, indexing the (empty) transposed data raises IndexError.
        log.warning("no results to plot for %r", title)
        return
    size = len(results)
    values = [row[0] for row in results]
    ind = np.arange(size)
    plt.plot(ind, values, label="ratio", marker=".")
    plt.ylabel(ylabel)
    plt.title(title)
    # Same x-axis convention as the stacked bar chart: one position per session.
    plt.xlabel("sessions")
    plt.xticks(ind, [""] * size)
    plt.show()
def filter_results(raw_results: [Result], keys, sort=default) -> [[int]]:
    """Pull the values of *keys* out of every result and return the rows sorted.

    Args:
        raw_results: results whose ``get()`` return value is indexed by key.
        keys: keys to extract from each ``get()`` value, in column order.
        sort: key function for ``sorted``; the module-level identity
            ``default`` orders rows by their own values.

    Returns:
        A new list with one row (values in ``keys`` order) per result.
    """
    # The generator calls get() exactly once per result, in input order.
    extracted = [
        [mapping[key] for key in keys]
        for mapping in (result.get() for result in raw_results)
    ]
    return sorted(extracted, key=sort)
class LocomotionActionRender(Render):
    """Common base for the locomotion/action renders in this module.

    ``result_types`` names ``LocomotionActionAnalyzer``; presumably the base
    ``Render.filter`` uses it to select matching results — confirm in Render.
    """

    result_types = [LocomotionActionAnalyzer]
class LocomotionActionAbsoluteRender(LocomotionActionRender):
    """Stacked bars of ``locomotion_sum`` and ``action_sum`` per result."""

    def render(self, results: List[Result], name=None):
        """Filter *results* and plot their absolute locomotion/action sums.

        ``name`` is accepted but not used.
        """
        accepted = self.filter(results)
        rows = filter_results(accepted, ['locomotion_sum', 'action_sum'])
        plot(rows, "time", "abs loc/action")
class LocomotionActionRelativeRender(LocomotionActionRender):
    """Stacked bars of ``locomotion_relative`` and ``action_relative`` per result."""

    def render(self, results: List[Result], name=None):
        """Filter *results* and plot their relative locomotion/action shares.

        ``name`` is accepted but not used.
        """
        accepted = self.filter(results)
        rows = filter_results(accepted, ['locomotion_relative', 'action_relative'])
        plot(rows, "fraction of time", "rel loc/action")
class LocomotionActionRatioRender(LocomotionActionRender):
    """Line plot of ``locomotion_action_ratio`` across results."""

    def render(self, results: List[Result], name=None):
        """Filter *results* and draw their locomotion/action ratio as a line.

        ``name`` is accepted but not used.
        """
        accepted = self.filter(results)
        rows = filter_results(accepted, ['locomotion_action_ratio'])
        plot_line(rows, ylabel="Ratio", title="Locomotion/Action Ratio")
class LocomotionActionRatioHistRender(LocomotionActionRender):
    """Histogram of ``locomotion_action_ratio`` across results."""

    def render(self, results: List[Result], name=None):
        """Filter *results* and show a histogram of their locomotion/action ratio.

        Args:
            results: candidate results, narrowed with ``self.filter``.
            name: accepted for signature parity with the sibling renders; unused.

        With no matching results, a warning is logged and nothing is drawn.
        """
        rows = filter_results(self.filter(results), ['locomotion_action_ratio'])
        if not rows:
            # bins=len(rows) would be 0, which numpy's histogram rejects.
            log.warning("no results to plot for 'locomotion/action' histogram")
            return
        # filter_results returns one single-element row per result. Flatten to a
        # 1-D sequence so hist sees a single dataset; wrapping the 2-D rows in
        # another list (as before) produces 3-D input that matplotlib rejects.
        ratios = [row[0] for row in rows]
        plt.title("locomotion/action")
        plt.xlabel("ratio")
        plt.ylabel("frequency")
        plt.hist(ratios, bins=len(ratios))
        plt.show()