66 lines
1.7 KiB
Python
66 lines
1.7 KiB
Python
from typing import Callable
|
|
import matplotlib.pyplot as plt
|
|
from dataclasses import dataclass
|
|
from copy import deepcopy
|
|
|
|
# Mapping from metric name to that metric's current value. The simulation keeps
# one live container and records a copied snapshot of it per period.
MetricsContainer = dict[str, float]
|
|
|
|
@dataclass()
class MetricConfig:
    """Describe how one metric starts, evolves each period, and is displayed.

    Attributes:
        initial_value: Value the metric holds before the first period.
        period_calculator: Called with the current metrics container; its
            return value becomes the metric's new value for the period.
        plot: Whether the metric is drawn when histories are plotted.
    """

    # Starting value placed into the metrics container.
    initial_value: float
    # Receives the whole metrics container, so one metric may depend on others.
    period_calculator: Callable[[MetricsContainer], float]
    # Set to False to keep the metric out of the plot.
    plot: bool = True
|
|
|
|
def create_metrics_container(
    metric_configs: dict[str, MetricConfig]
) -> MetricsContainer:
    """Build the initial metrics container from the metric configurations.

    Args:
        metric_configs: Metric name -> configuration.

    Returns:
        A new dict mapping every configured metric name to its
        ``initial_value``. Keys keep the iteration order of
        ``metric_configs``.
    """
    # A dict comprehension replaces the manual build-and-assign loop. It also
    # avoids instantiating the parameterized alias ``dict[str, float]()``.
    return {
        name: config.initial_value
        for name, config in metric_configs.items()
    }
|
|
|
|
def plot_metric_histories(
    metric_histories: list[MetricsContainer],
    metric_configs: dict[str, MetricConfig],
    *,
    show: bool = True,
):
    """Draw one line per plottable metric across all recorded snapshots.

    Args:
        metric_histories: Metric snapshots in period order. The list index is
            used as the x coordinate.
        metric_configs: Metric name -> configuration. Metrics whose ``plot``
            flag is false are skipped; the rest are drawn in this order.
        show: When true (the default, matching the previous behaviour),
            ``plt.show()`` is called after drawing. Pass ``False`` to keep the
            figure for saving, further customisation, or headless use.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.

    Raises:
        KeyError: If a snapshot lacks a metric that is configured for plotting.
    """
    # Collect the data before creating the figure, so a missing key
    # (KeyError) does not leave an orphaned figure behind.
    plot_data = {
        name: [snapshot[name] for snapshot in metric_histories]
        for name, config in metric_configs.items()
        if config.plot
    }

    fig, ax = plt.subplots(figsize=(14, 8))

    for label, values in plot_data.items():
        ax.plot(values, label=label)

    # With no labelled lines, ``ax.legend()`` only emits a "No artists with
    # labels found" warning, so skip it when nothing was plotted.
    if plot_data:
        ax.legend()

    ax.set_title("Metrics Over Simulation Period")
    ax.set_xlabel("Period")
    ax.set_ylabel("Value")
    ax.grid(True)

    if show:
        plt.show()

    return fig, ax
|
|
|
|
def simulate(
    num_periods: int,
    metric_configs: dict[str, MetricConfig],
    *,
    plot: bool = True,
) -> list[MetricsContainer]:
    """Advance every metric for ``num_periods`` periods and record the history.

    A snapshot of all metrics is taken before the first period and after each
    period. The returned history therefore has ``num_periods + 1`` entries, or
    just the initial snapshot when ``num_periods <= 0``.

    Args:
        num_periods: Number of periods to advance.
        metric_configs: Metric name -> configuration. Calculators are invoked
            in this iteration order every period.
        plot: When true (the default, matching the previous behaviour), the
            history is passed to ``plot_metric_histories``.

    Returns:
        The list of per-period metric snapshots.
    """
    metrics = create_metrics_container(metric_configs)
    metric_histories = [deepcopy(metrics)]

    for _ in range(num_periods):
        for metric_name, metric_config in metric_configs.items():
            # ``metrics`` is updated in place. A calculator therefore sees
            # this period's new values for metrics listed before it, and the
            # previous period's values for itself and metrics listed after it.
            # NOTE(review): confirm this ordering dependency is intended
            # rather than a simultaneous update from the previous snapshot.
            metrics[metric_name] = metric_config.period_calculator(metrics)

        # Snapshot, so later in-place updates do not rewrite recorded history.
        metric_histories.append(deepcopy(metrics))

    if plot:
        plot_metric_histories(metric_histories, metric_configs)

    return metric_histories
|