investment_simulator/simulator.py
2024-08-08 11:03:14 +10:00

66 lines
1.7 KiB
Python

from typing import Callable
import matplotlib.pyplot as plt
from dataclasses import dataclass
from copy import deepcopy
# A snapshot of metric values keyed by metric name; one container holds the
# value of every metric at a single point in the simulation.
MetricsContainer = dict[str, float]
@dataclass
class MetricConfig:
    """Describe how a single simulated metric starts and evolves.

    The simulation seeds the metric with ``initial_value`` and, once per
    period, replaces it with whatever ``period_calculator`` returns when it
    is handed the current metrics container.
    """

    # Value the metric holds before the first period is simulated.
    initial_value: float
    # Called with the current metrics container; the value it returns becomes
    # the metric's new value.
    period_calculator: Callable[[MetricsContainer], float]
    # Whether this metric's history is drawn when the results are plotted.
    plot: bool = True
def create_metrics_container(
    metric_configs: dict[str, MetricConfig]
) -> MetricsContainer:
    """Build the starting metrics container for a simulation.

    Args:
        metric_configs: Metric configurations keyed by metric name.

    Returns:
        A new dict mapping every metric name to its configured
        ``initial_value``, in the iteration order of ``metric_configs``.
    """
    return {
        metric_name: metric_config.initial_value
        for metric_name, metric_config in metric_configs.items()
    }
def plot_metric_histories(
    metric_histories: list[MetricsContainer],
    metric_configs: dict[str, MetricConfig]
):
    """Draw one line per plottable metric and display the figure.

    Args:
        metric_histories: Metric snapshots in period order; each snapshot
            must contain every metric whose config has ``plot`` enabled.
        metric_configs: Metric configurations keyed by metric name. Metrics
            whose ``plot`` flag is false are left out of the chart.
    """
    # Turn the list of per-period snapshots into one value series per metric.
    plot_data = {
        metric_name: [snapshot[metric_name] for snapshot in metric_histories]
        for metric_name, metric_config in metric_configs.items()
        if metric_config.plot
    }

    _, ax = plt.subplots(figsize=(14, 8))
    for label, values in plot_data.items():
        ax.plot(values, label=label)
    # legend() warns when there is nothing labelled to list, so only ask for
    # one when at least one series was drawn.
    if plot_data:
        ax.legend()
    ax.set_title("Metrics Over Simulation Period")
    ax.set_xlabel("Period")
    ax.set_ylabel("Value")
    ax.grid(True)
    plt.show()
def simulate(
    num_periods: int,
    metric_configs: dict[str, MetricConfig]
) -> list[MetricsContainer]:
    """Run the simulation, plot the result, and return the recorded history.

    Args:
        num_periods: Number of periods to simulate. Zero or a negative
            number simulates nothing, leaving only the initial snapshot.
        metric_configs: Metric configurations keyed by metric name.

    Returns:
        The recorded snapshots: the initial values followed by one snapshot
        taken at the end of each simulated period.
    """
    metrics = create_metrics_container(metric_configs)
    metric_histories = [deepcopy(metrics)]
    for _ in range(num_periods):
        # Metrics are updated in place and in config order, so a calculator
        # sees this period's values for metrics configured before it and the
        # previous period's values for those configured after it.
        for metric_name, metric_config in metric_configs.items():
            metrics[metric_name] = metric_config.period_calculator(metrics)
        # Store a copy: `metrics` itself keeps being mutated next period.
        metric_histories.append(deepcopy(metrics))
    plot_metric_histories(metric_histories, metric_configs)
    return metric_histories