Upload submission
This commit is contained in:
commit
75e019efc0
159
.gitignore
vendored
Normal file
159
.gitignore
vendored
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
junit/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv’s lock file may be unsuitable for version control.
|
||||||
|
Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
bin/
|
||||||
|
Scripts/
|
||||||
|
pyvenv.cfg
|
||||||
|
*.pyvenv
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# VSCode
|
||||||
|
.vscode/
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
# Editors and IDEs
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
*.sublime-workspace
|
||||||
|
*.sublime-project
|
||||||
|
|
||||||
|
# User-specific stuff:
|
||||||
|
*.DS_Store
|
||||||
|
*.AppleDouble
|
||||||
|
*.LSOverride
|
||||||
|
Icon?
|
||||||
|
._*
|
||||||
|
.Spotlight-V100
|
||||||
|
.Trashes
|
||||||
|
ehthumbs.db
|
||||||
|
Thumbs.db
|
||||||
|
Desktop.ini
|
||||||
|
|
||||||
|
mlruns/
|
127
README.md
Normal file
127
README.md
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
# Mantel Group Technical Challenge
|
||||||
|
|
||||||
|
This repository contains a productionised version (including testing and benchmarking) of the Kohonen Self Organising Map model implemented in the problem specification. Please let me know if there are any issues running the code.
|
||||||
|
|
||||||
|
## Setting up this repository
|
||||||
|
|
||||||
|
Firstly, due to time constraints, there are components of this codebase that are not abstracted in such a way that they are conducive to collaborative work. There are certain hardcoded elemtns - particularly in the testing and benchmarking modules. I have chosen what I believe to be reasonable defaults for running testing and benchmarking on an average machine.
|
||||||
|
|
||||||
|
### Setup environment
|
||||||
|
|
||||||
|
Python Poetry is used as a package manager for this project. It can be isntalled by following the [documentation](https://python-poetry.org/docs/).
|
||||||
|
Once Poetry is installed, run `poetry install` to create a virtual environment with the necessary dependencies.
|
||||||
|
|
||||||
|
### Running tests
|
||||||
|
|
||||||
|
Run tests using `pytest tests`. Tests are contained in `./tests`.
|
||||||
|
|
||||||
|
### Running benchmarking
|
||||||
|
|
||||||
|
Run benchmarks using `pytest benchmarks`. Running benchmarks will create MLFlow experiments that can be inspected in the MLFlow dashboard that will be served at `http://127.0.0.1:5000/` after running `mlflow ui`. Benchmarks are contained in `./benchmarks`.
|
||||||
|
|
||||||
|
## Mantel Code Assessment
|
||||||
|
|
||||||
|
Whilst the tests, benchmarking and organisation of my Kohonen Network implementation should intimate a more *production-ready* codebase, I will highlight some key improvements I have made to the implementation.
|
||||||
|
|
||||||
|
### Using ASCII variables for variable names
|
||||||
|
The components of a Kohonen Network can all be expressed using mathematical notation, so it seems logical to use these same mathematical symbols in one's code. However, I would suggest two key reasons to Sam why using non-ASCII characters are usually undesirbale.
|
||||||
|
- Developers reading and maintaining the code might not be familiar with the mathematical formulae and therefore, would be unable to make sense of semantic sense of such variables.
|
||||||
|
- Non-ASCII characters are not typically found on standard keyboard, making typing such variable names inconvenient
|
||||||
|
|
||||||
|
Thus, I always use ASCII characters and make a careful effort to utilise descriptive and unambiguous variable names rather than short and forgettable variable names.
|
||||||
|
|
||||||
|
### Packaging
|
||||||
|
|
||||||
|
Sam's implementation expects to be executed as a Python module due to the `if __name__ == '__main__':` block. Whilst this is okay during first-pass development, I would ask Sam how he would advise others to use his module in this way, and it would soon come to light that there is no CLI argument parsing or other method by which someone could use the module on their own data.
|
||||||
|
|
||||||
|
I would suggest that Sam package his code in a way similar to how I have done - where I've created a `models` package that includes a `kohonen_network.py` module and directly exposes the `train_kohon_network`, allowing for anyone to easily `from models import train_kohon_network` to train their own model on their own data.
|
||||||
|
|
||||||
|
|
||||||
|
### Modularisation
|
||||||
|
|
||||||
|
Sam's `train` function is dense and does not make use of any helper functions. For a complex algorithm such as training a Kohonen Network, there are several downsides to having a monolothic fucntion:
|
||||||
|
|
||||||
|
- Reduced readability
|
||||||
|
- More difficult to make localised changes
|
||||||
|
- Isolating bugs is more challenging
|
||||||
|
- Cannot test individual components
|
||||||
|
- Cannot reuse code
|
||||||
|
|
||||||
|
I would advise Sam to consider the example of initialising the model weights. While my `initialise_random_tensor` might have a near-identical implementation to his, I have extended it to support any arbitrary dimensionality and have encapsulated the `numpy` implementation details. This means that `initialise_random_tensor` could be used in any other model that needs to randomly initialise a tensor and the `numpy` implementation could, for example, be swapped out with another implementation such as `jax.numpy` for GPU acceleration, without having to manually update each of those functions.
|
||||||
|
|
||||||
|
### Typing and comments
|
||||||
|
|
||||||
|
Code should tell a story to the reader and be as clear and simple as possible to follow. I would suggest to Sam that he would have better control over the story he is trying to tell if he used type hinting and comments. However, these should only be used if they *add signal* to the code. Commenting `# Add numbers A and B` above `return sum(A, B)` does not add any signal - it dilutes the current signal.
|
||||||
|
|
||||||
|
Let's consider my function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
Node = namedtuple("Node", ['i', 'j'])
|
||||||
|
|
||||||
|
def _find_best_matching_unit(weights: NDArray[np.float32], x: NDArray[np.float32],
|
||||||
|
width: int, height: int) -> Node:
|
||||||
|
"""Finds the network node that is currently most similar to `x`."""
|
||||||
|
bmu = np.argmin(jnp.sum((weights - x) ** 2, axis=2))
|
||||||
|
return Node(*jnp.unravel_index(bmu, (height, width)))
|
||||||
|
```
|
||||||
|
|
||||||
|
By abstracting this implementation into a helper function, I get the same modularisation benefits as discussed above because someone following a story that uses `_find_best_matching_unit` only needs to understand what it does to continue the story, not how it does what it does. I have done a few things to communicate *what this function does* as concisely as possible.
|
||||||
|
|
||||||
|
- Including argument type hints provides important context for the reader/user to ascertain what the author expects the function to operate on.
|
||||||
|
- Declaring a named tuple return type, `Node`, informs the reader what the function returns and is more descriptive than simply `Tuple[int, int]`.
|
||||||
|
- Adding a docstring that phrases the function's logic in natural language can aid in the reader's understanding without needing to read its implementation.
|
||||||
|
|
||||||
|
### Performance
|
||||||
|
|
||||||
|
Sam's implementation could benefit considerbaly from some vectorised operations instead of iterating over each node and updating its weight. Vectorised operations in `numpy` are implemented to leverage highly efficient, low-level code that can utilise hardware acceleration - often resulting in large speed-ups.
|
||||||
|
|
||||||
|
In my implementation, I abstracted out `_update_weights` so that I could wrap it in [JAX](https://github.com/google/jax) Just-In-Time compilation to compile the function using XLA (Accelerated Linear Algebra.) The below image shows the output of a benchmark comparing my implementation to Sam's implementation for random parameters and inputs. In this benchmark, JAX was configured to use my CPU - one could configure the module to use JAX on a GPU for even greater speed gains. While JAX does add some overhead and may be less efficient for very simple networks, it is orders of magnitude faster for complex networks.
|
||||||
|
|
||||||
|
![performance](documents/execution_time_comparison.png)
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
As part of productionising this application, it is critical to rigorously test the implementation so that any bugs are and issues are identified before launching to production, any bugs are picked up after making changes post launching to production and writing tests often highlights code-smells - encouraging better software development practices.
|
||||||
|
|
||||||
|
|
||||||
|
I have used `pytest` to write unit and integration tests. In addition, I have used the [Hypothesis](https://hypothesis.readthedocs.io/en/latest/index.html) library to employ property-based testing. Let's consider the below example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@given(
|
||||||
|
X=st.one_of(
|
||||||
|
st.lists(st.floats(allow_nan=False, allow_infinity=False, min_value=-1e6, max_value=1e6)), # Not 2D
|
||||||
|
st.just(np.array([])) # Empty
|
||||||
|
),
|
||||||
|
width=st.integers(max_value=MINIMUM_NETWORK_DIMENSION - 1), # Includes non-positive
|
||||||
|
height=st.integers(max_value=MINIMUM_NETWORK_DIMENSION - 1), # Includes non-positive
|
||||||
|
num_iterations=st.integers(max_value=0), # Includes non-positive
|
||||||
|
initial_learning_rate=st.one_of(
|
||||||
|
st.floats(max_value=0.0, allow_nan=False, allow_infinity=False), # Includes non-positive and NaN/inf
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_create_kohonen_params_invalid_failure(X, width, height, num_iterations, initial_learning_rate):
|
||||||
|
result = create_kohonen_params(X, width, height, num_iterations, initial_learning_rate)
|
||||||
|
assert isinstance(result, Failure) and isinstance(result.failure(), ValueError)
|
||||||
|
```
|
||||||
|
|
||||||
|
Rather than trying to think of differnt edge cases, `hypothesis` will simulate many different inputs that we expect to break the `create_kohonen_params` function. Importantly, `hypothesis` will try and break our test, i.e. find inputs that do not result in a `Failure` return type from `create_kohonen_params`, and it will then find the minimum failing example so the set of parameters that broke the test is as obvious as possible. If this test passes, I can be confident that `create_kohonen_params` is correctly returning a `ValueError` failure whenever it receives invalid inputs (e.g. negative width.)
|
||||||
|
|
||||||
|
The concept of a `Result` return type from a function is implemented by the [Returns](https://returns.readthedocs.io/en/latest/pages/result.html) library and is a functional nomad declaring that as a `Result`, it will either have a value if the function succeeded, or one of a defined set of exceptions if it failed. This enforces deliberate and explicit exception-handling/propagation. In the above example, we declare explicitly that `create_kohonen_params` should return a `ValueError` if it receives any invalid input.
|
||||||
|
|
||||||
|
### Benchmarking
|
||||||
|
|
||||||
|
I also used `hypothesis` to simulate different scenarios and thus evaluate the performance of both my implementation and Sam's. To assist in the benchmarking, I used MLFlow to allow me to visually compare the models and inspect different metrics. Another benefit of MLFlow is the ability to inspect metrics across iterations. For example, in the below image, we can see that both neighbourhood radius and learning rate reduce exponentially. This is a solid sanity check. Using a tool such as MLFlow makes comparing experiments and collaborating on models far easier. My inclusion of MLFlow in this codebase is pretty barebones and does not have any secret injection. I would encourage Sam to use MLFlow early into model development so he can measure and quantify how different model implementations and versions perform.
|
||||||
|
|
||||||
|
![MLFlow](documents/mlflow.png)
|
||||||
|
|
||||||
|
### Deployment
|
||||||
|
|
||||||
|
Since one of the primary benefits of the Kohonen Map is to perform dimensionality reduction, it is possible that all that need be deployed is the `models` package. One would likely want to name the package more descriptive, such as `kohonen_network`. It would be straightforward to use GitHub Actions to automatically update the package and submit it to a PyPi repository, allowing people to use `train_kohon_network` and perform dimensionality reduction in their own pipelines.
|
||||||
|
|
||||||
|
Alternatively, the package could be imported into a Python Flask application that is hosted in a Docker container in a Kubernetes cluster or on a bare-metal server. This would make for an endpoint that any other services could use. Rather than a self-managed Flask app, one could use Databricks to solve the model, which has the added benefit of natively fitting into Databricks pipelines.
|
||||||
|
|
||||||
|
A few example use-cases (beyond dimensionality reduction as a data preprocessing step) are:
|
||||||
|
|
||||||
|
- Image Processing. Kohonen Networks can cluster similar colors in an image for use in image compression, reducing the color palette to essential colors while maintaining visual fidelity.
|
||||||
|
- Detecting Abnormal Behavior in Industrial Systems. Kohonen Networks can monitor data streams from sensors in industrial settings, such as temperature or pressure readings, to detect deviations from standard operating conditions that may indicate equipment malfunctions or safety hazards.
|
||||||
|
- Market Segmentation. Kohonen networks can be used to cluster customers based on purchasing behavior and preferences, helping businesses tailor marketing strategies.
|
70
benchmarks/benchmark_performance.py
Normal file
70
benchmarks/benchmark_performance.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
from hypothesis import given, settings, HealthCheck, note, event, strategies as st
|
||||||
|
from returns.result import Success
|
||||||
|
import datetime
|
||||||
|
import numpy as np
|
||||||
|
import mlflow
|
||||||
|
import pytest
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from tests.models.common import generate_kohonen_samples, MINIMUM_NETWORK_DIMENSION
|
||||||
|
from models import train_kohonen_network, train_kohonen_network_sam, create_kohonen_params
|
||||||
|
|
||||||
|
mlflow_experiment_name = "Kohonen_Network_Benchmark"
|
||||||
|
timestamp = datetime.datetime.now(datetime.UTC).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||||
|
mlflow.set_experiment(f"{mlflow_experiment_name}_{timestamp}")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def mlflow_context():
|
||||||
|
context = []
|
||||||
|
yield context
|
||||||
|
|
||||||
|
df = pd.DataFrame(context)
|
||||||
|
|
||||||
|
# Plot scatter chart of execution_time vs msam_execution_time
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
plt.scatter(df['execution_time'], df['sam_execution_time'], alpha=0.7, color='blue')
|
||||||
|
plt.xlabel('Execution Time')
|
||||||
|
plt.ylabel('Sam Execution Time')
|
||||||
|
plt.title(f'Comparison of Execution Times (n = {len(context)})')
|
||||||
|
plt.plot([df['execution_time'].min(), df['sam_execution_time'].max()],
|
||||||
|
[df['execution_time'].min(), df['sam_execution_time'].max()], 'k--') # Diagonal line
|
||||||
|
plt.grid(True)
|
||||||
|
plt.savefig('benchmarks/execution_time_comparison.png')
|
||||||
|
|
||||||
|
@given(
|
||||||
|
data=st.data(),
|
||||||
|
feature_size=st.integers(min_value=1, max_value=10),
|
||||||
|
width=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=200),
|
||||||
|
height=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=200),
|
||||||
|
num_iterations=st.integers(min_value=10, max_value=1000),
|
||||||
|
initial_learning_rate=st.floats(min_value=1e-3, max_value=1.0, allow_nan=False, allow_infinity=False)
|
||||||
|
)
|
||||||
|
@settings(max_examples=20, deadline=None, suppress_health_check=(HealthCheck.too_slow,))
|
||||||
|
def benchmark_kohonen_networks_performance_mlflow(mlflow_context, data, feature_size, width, height, num_iterations, initial_learning_rate):
|
||||||
|
np.random.seed(42)
|
||||||
|
timestamp = datetime.datetime.now(datetime.UTC).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||||
|
|
||||||
|
X = data.draw(generate_kohonen_samples(feature_size))
|
||||||
|
|
||||||
|
# Train Kohonen Network and record metrics to MLFlow
|
||||||
|
with mlflow.start_run(run_name=f"train_kohonen_network_{timestamp}") as run:
|
||||||
|
params = create_kohonen_params(X, width, height, num_iterations, initial_learning_rate)
|
||||||
|
assert isinstance(params, Success)
|
||||||
|
_ = train_kohonen_network(X, params.unwrap(), use_mlflow=True)
|
||||||
|
assert isinstance(_, Success)
|
||||||
|
|
||||||
|
# Train Kohonen Network (Sam) and record metrics to MLFlow
|
||||||
|
with mlflow.start_run(run_name=f"train_kohonen_network_sam_{timestamp}") as sam_run:
|
||||||
|
_ = train_kohonen_network_sam(X, num_iterations, width, height,
|
||||||
|
feature_size, initial_learning_rate, use_mlflow=True)
|
||||||
|
|
||||||
|
# Read MLFlow to compare execution times
|
||||||
|
client = mlflow.tracking.MlflowClient()
|
||||||
|
execution_time = client.get_metric_history(run.info.run_id, "execution_time")[-1].value
|
||||||
|
sam_execution_time = client.get_metric_history(sam_run.info.run_id, "execution_time")[-1].value
|
||||||
|
|
||||||
|
mlflow_context.append({
|
||||||
|
"sam_execution_time": sam_execution_time,
|
||||||
|
"execution_time": execution_time
|
||||||
|
})
|
BIN
documents/execution_time_comparison.png
Normal file
BIN
documents/execution_time_comparison.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 38 KiB |
274
documents/kohonen.ipynb
Normal file
274
documents/kohonen.ipynb
Normal file
File diff suppressed because one or more lines are too long
BIN
documents/mlflow.png
Normal file
BIN
documents/mlflow.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 45 KiB |
2171
poetry.lock
generated
Normal file
2171
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
[tool.poetry]
|
||||||
|
name = "mantel"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["Harry Stuart <hj.stuart0003@gmail.com>"]
|
||||||
|
readme = "README.md"
|
||||||
|
packages = [{ include = "models", from = "src" }, { include = "tests" }]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.12"
|
||||||
|
numpy = "^1.26.4"
|
||||||
|
jax = "^0.4.28"
|
||||||
|
returns = "^0.22.0"
|
||||||
|
matplotlib = "^3.9.0"
|
||||||
|
jaxlib = "^0.4.28"
|
||||||
|
chex = "^0.1.86"
|
||||||
|
hypothesis = "^6.103.1"
|
||||||
|
mlflow = "^2.13.2"
|
||||||
|
pytest = "^8.2.2"
|
||||||
|
seaborn = "^0.13.2"
|
||||||
|
pandas = "^2.2.2"
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
16
pytest.ini
Normal file
16
pytest.ini
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
# pytest.ini
|
||||||
|
[pytest]
|
||||||
|
minversion = 6.0
|
||||||
|
addopts = -ra
|
||||||
|
testpaths =
|
||||||
|
tests
|
||||||
|
benchmarks
|
||||||
|
python_files =
|
||||||
|
test_*.py
|
||||||
|
benchmark_*.py
|
||||||
|
python_functions =
|
||||||
|
test_*
|
||||||
|
benchmark_*
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning
|
||||||
|
ignore::FutureWarning
|
3
src/models/__init__.py
Normal file
3
src/models/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .kohonen_network import train_kohonen_network, create_kohonen_params
|
||||||
|
from .kohonen_network_sam import train_kohonen_network_sam
|
||||||
|
from .kohonen_network import KohonenParams
|
8
src/models/common.py
Normal file
8
src/models/common.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
import tempfile
|
||||||
|
import mlflow
|
||||||
|
import os
|
||||||
|
|
||||||
|
def initialise_random_tensor(*dims: int) -> NDArray[np.float32]:
|
||||||
|
return np.random.rand(*dims).astype(np.float32)
|
149
src/models/kohonen_network.py
Normal file
149
src/models/kohonen_network.py
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax
|
||||||
|
import io
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from returns.result import Result, Success, Failure, safe
|
||||||
|
from typing import Optional
|
||||||
|
from collections import namedtuple
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
import chex
|
||||||
|
import mlflow
|
||||||
|
|
||||||
|
from .common import initialise_random_tensor
|
||||||
|
|
||||||
|
@chex.dataclass(frozen=True)
|
||||||
|
class KohonenParams:
|
||||||
|
initial_neighbourhood_radius: float
|
||||||
|
initial_learning_rate: float
|
||||||
|
exp_time_const: float
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
num_iterations: int
|
||||||
|
feature_size: int
|
||||||
|
|
||||||
|
Node = namedtuple("Node", ['i', 'j'])
|
||||||
|
|
||||||
|
def _find_best_matching_unit(weights: NDArray[np.float32], x: NDArray[np.float32],
|
||||||
|
width: int, height: int) -> Node:
|
||||||
|
"""Finds the network node that is currently most similar to `x`."""
|
||||||
|
bmu = np.argmin(jnp.sum((weights - x) ** 2, axis=2))
|
||||||
|
return Node(*jnp.unravel_index(bmu, (height, width)))
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=("params",))
|
||||||
|
def _update_weights(
|
||||||
|
x: NDArray[np.float32],
|
||||||
|
weights: NDArray[np.float32],
|
||||||
|
neighbourhood_radius: float,
|
||||||
|
learning_rate: float,
|
||||||
|
params: KohonenParams
|
||||||
|
) -> NDArray[np.float32]:
|
||||||
|
"""Updates weights given the current input vector, `x`."""
|
||||||
|
weights = weights.copy()
|
||||||
|
|
||||||
|
bmu = _find_best_matching_unit(weights, x, params.width, params.height)
|
||||||
|
|
||||||
|
i_indices, j_indices = jnp.indices([params.height, params.width])
|
||||||
|
distances = jnp.sqrt((i_indices - bmu.i) ** 2 + (j_indices - bmu.j) ** 2)
|
||||||
|
influences = jnp.exp(-distances ** 2 / (2 * (neighbourhood_radius ** 2)))
|
||||||
|
weights += learning_rate * influences[..., jnp.newaxis] * (x - weights)
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
def _calculate_neighbourhood_radius(iteration: int, params: KohonenParams):
|
||||||
|
"""Uses exponential decay to compute radius about best matching unit within
|
||||||
|
which nodes will be considered neighbours for a given iteration."""
|
||||||
|
return (params.initial_neighbourhood_radius *
|
||||||
|
np.exp(-iteration / params.exp_time_const))
|
||||||
|
|
||||||
|
def _calculate_learning_rate(iteration: int, params: KohonenParams):
|
||||||
|
"""Employes exponential decay to calculate learning rate for a given iteration."""
|
||||||
|
return (params.initial_learning_rate *
|
||||||
|
np.exp(-iteration / params.exp_time_const))
|
||||||
|
|
||||||
|
@safe
|
||||||
|
def train_kohonen_network(
|
||||||
|
X: NDArray[np.float32],
|
||||||
|
params: KohonenParams,
|
||||||
|
use_mlflow: bool = False
|
||||||
|
) -> NDArray[np.float32]:
|
||||||
|
"""Trains a Kohonen Network (self-organising map) and logs to mlflow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: An (m, n) matrix of training samples where m is the number of samples
|
||||||
|
and n is the feature vector size.
|
||||||
|
params: A KohonenTrainingParams object containing parameters
|
||||||
|
for model training.
|
||||||
|
use_mlflow: Toggle MLFlow logging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A (height, width, n) tensor representing the network of learned weights,
|
||||||
|
or an exception.
|
||||||
|
"""
|
||||||
|
if use_mlflow:
|
||||||
|
mlflow.log_params({
|
||||||
|
"initial_neighbourhood_radius": params.initial_neighbourhood_radius,
|
||||||
|
"initial_learning_rate": params.initial_learning_rate,
|
||||||
|
"exp_time_const": params.exp_time_const,
|
||||||
|
"width": params.width,
|
||||||
|
"height": params.height,
|
||||||
|
"num_iterations": params.num_iterations,
|
||||||
|
"feature_size": params.feature_size
|
||||||
|
})
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
weights = initialise_random_tensor(params.height, params.width, params.feature_size)
|
||||||
|
|
||||||
|
for t in range(params.num_iterations):
|
||||||
|
neighbourhood_radius = _calculate_neighbourhood_radius(t, params)
|
||||||
|
learning_rate = _calculate_learning_rate(t, params)
|
||||||
|
|
||||||
|
if use_mlflow:
|
||||||
|
mlflow.log_metrics({
|
||||||
|
"neighbourhood_radius": neighbourhood_radius,
|
||||||
|
"learning_rate": learning_rate
|
||||||
|
}, step=t)
|
||||||
|
|
||||||
|
for x in X:
|
||||||
|
# For current iteration parameters, updates network weights using current sample x
|
||||||
|
weights = _update_weights(x, weights, neighbourhood_radius, learning_rate, params)
|
||||||
|
|
||||||
|
if use_mlflow:
|
||||||
|
execution_time = time.perf_counter() - start_time
|
||||||
|
mlflow.log_metric("execution_time", execution_time)
|
||||||
|
|
||||||
|
# Convert from JAX array to NumPy array
|
||||||
|
return np.array(weights)
|
||||||
|
|
||||||
|
def create_kohonen_params(
|
||||||
|
X: NDArray[np.float32],
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
num_iterations: int,
|
||||||
|
initial_learning_rate: float = 0.1
|
||||||
|
) -> Result[KohonenParams, ValueError]:
|
||||||
|
"""Creates safe and valid KohonenParams using desired hyperparameters or returns
|
||||||
|
an error is given hyperparameters are not conducive."""
|
||||||
|
if width <= 2:
|
||||||
|
return Failure(ValueError("Width must be greater than two."))
|
||||||
|
if height <= 2:
|
||||||
|
return Failure(ValueError("Height must be greater than two."))
|
||||||
|
if num_iterations <= 0:
|
||||||
|
return Failure(ValueError("Number of iterations must be greater than zero."))
|
||||||
|
if initial_learning_rate <= 0:
|
||||||
|
return Failure(ValueError("Initial learning rate must be greater than zero."))
|
||||||
|
if len(X.shape) != 2:
|
||||||
|
return Failure(ValueError("Can only create KohonenParams for two dimensional input data."))
|
||||||
|
|
||||||
|
initial_neighbourhood_radius = max(width, height) / 2
|
||||||
|
return Success(KohonenParams(
|
||||||
|
initial_neighbourhood_radius=initial_neighbourhood_radius,
|
||||||
|
initial_learning_rate=initial_learning_rate,
|
||||||
|
exp_time_const=num_iterations / np.log(initial_neighbourhood_radius),
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_iterations=num_iterations,
|
||||||
|
feature_size=X.shape[-1]
|
||||||
|
))
|
30
src/models/kohonen_network_sam.py
Normal file
30
src/models/kohonen_network_sam.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
import mlflow
|
||||||
|
from .common import initialise_random_tensor
|
||||||
|
|
||||||
|
def train_kohonen_network_sam(input_data, n_max_iterations, width, height, feature_size,
|
||||||
|
initial_learning_rate=0.1, use_mlflow=False):
|
||||||
|
if use_mlflow:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
σ0 = max(width, height) / 2
|
||||||
|
α0 = initial_learning_rate
|
||||||
|
weights = initialise_random_tensor(width, height, feature_size)
|
||||||
|
λ = n_max_iterations / np.log(σ0)
|
||||||
|
for t in range(n_max_iterations):
|
||||||
|
σt = σ0 * np.exp(-t/λ)
|
||||||
|
αt = α0 * np.exp(-t/λ)
|
||||||
|
for vt in input_data:
|
||||||
|
bmu = np.argmin(np.sum((weights - vt) ** 2, axis=2))
|
||||||
|
bmu_x, bmu_y = np.unravel_index(bmu, (width, height))
|
||||||
|
for x in range(width):
|
||||||
|
for y in range(height):
|
||||||
|
di = np.sqrt(((x - bmu_x) ** 2) + ((y - bmu_y) ** 2))
|
||||||
|
θt = np.exp(-(di ** 2) / (2*(σt ** 2)))
|
||||||
|
weights[x, y] += αt * θt * (vt - weights[x, y])
|
||||||
|
|
||||||
|
if use_mlflow:
|
||||||
|
execution_time = time.perf_counter() - start_time
|
||||||
|
mlflow.log_metric("execution_time", execution_time)
|
||||||
|
return weights
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
27
tests/models/common.py
Normal file
27
tests/models/common.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
import numpy as np
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
MINIMUM_NETWORK_DIMENSION = 3
|
||||||
|
|
||||||
|
def generate_kohonen_samples(num_features):
|
||||||
|
"""Custom Hypothesis strategy to generate matrix X with consistent feature length."""
|
||||||
|
return st.lists(
|
||||||
|
st.lists(
|
||||||
|
st.floats(allow_nan=False, allow_infinity=False, min_value=-1e6, max_value=1e6),
|
||||||
|
min_size=num_features, max_size=num_features # Consistent feature length
|
||||||
|
),
|
||||||
|
min_size=1, max_size=100
|
||||||
|
).map(lambda x: np.array(x, dtype=np.float32))
|
||||||
|
|
||||||
|
def generate_kohonen_weights(width, height, feature_size):
|
||||||
|
"""Custom Hypothesis strategy to generate a weights matrix with specified dimensions and feature size."""
|
||||||
|
return st.lists(
|
||||||
|
st.lists(
|
||||||
|
st.lists(
|
||||||
|
st.floats(min_value=-10, max_value=10),
|
||||||
|
min_size=feature_size, max_size=feature_size # Each feature vector has a consistent feature size
|
||||||
|
),
|
||||||
|
min_size=width, max_size=width # Consistent width for each row
|
||||||
|
),
|
||||||
|
min_size=height, max_size=height # Consistent height for the matrix
|
||||||
|
).map(lambda x: np.array(x, dtype=np.float32))
|
13
tests/models/test_common.py
Normal file
13
tests/models/test_common.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
from hypothesis import given, strategies as st
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from models.common import initialise_random_tensor
|
||||||
|
|
||||||
|
@given(
|
||||||
|
dims=st.tuples(st.integers(min_value=1, max_value=10)).map(tuple)
|
||||||
|
)
|
||||||
|
def test_initialise_random_tensor_dtype_and_shape(dims):
|
||||||
|
tensor = initialise_random_tensor(*dims)
|
||||||
|
assert tensor.dtype == np.float32, "The dtype of the tensor should be float32"
|
||||||
|
assert tensor.shape == dims, f"Expected tensor shape {dims}, but got {tensor.shape}"
|
36
tests/models/test_create_kohonen_params.py
Normal file
36
tests/models/test_create_kohonen_params.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from hypothesis import given, strategies as st
|
||||||
|
import numpy as np
|
||||||
|
from returns.result import Success, Failure
|
||||||
|
|
||||||
|
from models import create_kohonen_params
|
||||||
|
from .common import generate_kohonen_samples, MINIMUM_NETWORK_DIMENSION
|
||||||
|
|
||||||
|
@given(
|
||||||
|
data=st.data(),
|
||||||
|
feature_size=st.integers(min_value=1, max_value=10),
|
||||||
|
width=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=1000),
|
||||||
|
height=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=1000),
|
||||||
|
num_iterations=st.integers(min_value=1, max_value=10000),
|
||||||
|
initial_learning_rate=st.floats(min_value=1e-5, max_value=1.0, allow_nan=False, allow_infinity=False)
|
||||||
|
)
|
||||||
|
def test_create_kohonen_params_valid_success(data, feature_size, width, height, num_iterations, initial_learning_rate):
|
||||||
|
X = data.draw(generate_kohonen_samples(feature_size))
|
||||||
|
|
||||||
|
result = create_kohonen_params(X, width, height, num_iterations, initial_learning_rate)
|
||||||
|
assert isinstance(result, Success)
|
||||||
|
|
||||||
|
@given(
|
||||||
|
X=st.one_of(
|
||||||
|
st.lists(st.floats(allow_nan=False, allow_infinity=False, min_value=-1e6, max_value=1e6)), # Not 2D
|
||||||
|
st.just(np.array([])) # Empty
|
||||||
|
),
|
||||||
|
width=st.integers(max_value=MINIMUM_NETWORK_DIMENSION - 1), # Includes non-positive
|
||||||
|
height=st.integers(max_value=MINIMUM_NETWORK_DIMENSION - 1), # Includes non-positive
|
||||||
|
num_iterations=st.integers(max_value=0), # Includes non-positive
|
||||||
|
initial_learning_rate=st.one_of(
|
||||||
|
st.floats(max_value=0.0, allow_nan=False, allow_infinity=False), # Includes non-positive and NaN/inf
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_create_kohonen_params_invalid_failure(X, width, height, num_iterations, initial_learning_rate):
|
||||||
|
result = create_kohonen_params(X, width, height, num_iterations, initial_learning_rate)
|
||||||
|
assert isinstance(result, Failure) and isinstance(result.failure(), ValueError)
|
146
tests/models/test_kohonen_network.py
Normal file
146
tests/models/test_kohonen_network.py
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
import numpy as np
|
||||||
|
from hypothesis import given, settings, HealthCheck, strategies as st
|
||||||
|
from returns.result import Success
|
||||||
|
|
||||||
|
from models.kohonen_network import Node, KohonenParams, train_kohonen_network, _find_best_matching_unit, _calculate_neighbourhood_radius, _calculate_learning_rate, _update_weights
|
||||||
|
from .common import generate_kohonen_weights, generate_kohonen_samples, MINIMUM_NETWORK_DIMENSION
|
||||||
|
|
||||||
|
@given(
|
||||||
|
data=st.data(),
|
||||||
|
width=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=100),
|
||||||
|
height=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=100),
|
||||||
|
feature_size=st.integers(min_value=1, max_value=10),
|
||||||
|
)
|
||||||
|
@settings(max_examples=5, deadline=None, suppress_health_check=(HealthCheck.too_slow,))
|
||||||
|
def test_find_best_matching_unit_success(data, width, height, feature_size):
|
||||||
|
weights = data.draw(generate_kohonen_weights(width, height, feature_size))
|
||||||
|
target_index = data.draw(st.tuples(st.integers(min_value=0, max_value=height-1),
|
||||||
|
st.integers(min_value=0, max_value=width-1)))
|
||||||
|
|
||||||
|
x = weights[target_index[0]][target_index[1]]
|
||||||
|
|
||||||
|
# Make a modification so `x` is exactly one of the weights
|
||||||
|
weights = np.array(weights)
|
||||||
|
x = weights[target_index[0], target_index[1], :].copy()
|
||||||
|
|
||||||
|
bmu = _find_best_matching_unit(weights, x, width, height)
|
||||||
|
|
||||||
|
# Check if the returned node is correct
|
||||||
|
assert bmu == Node(target_index[0], target_index[1]), f"Expected Node({target_index[0]}, {target_index[1]}), got {bmu}"
|
||||||
|
|
||||||
|
def test_neighbourhood_radius_strictly_decreasing():
|
||||||
|
params = KohonenParams(
|
||||||
|
initial_neighbourhood_radius=5.0,
|
||||||
|
initial_learning_rate=0.1,
|
||||||
|
exp_time_const=20.0, # Affects rate of radius decrease
|
||||||
|
width=10,
|
||||||
|
height=10,
|
||||||
|
num_iterations=100,
|
||||||
|
feature_size=3
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_radius = _calculate_neighbourhood_radius(0, params)
|
||||||
|
strictly_decreasing = True
|
||||||
|
|
||||||
|
# Check the radius for the first 50 iterations to ensure it's strictly decreasing
|
||||||
|
for iteration in range(1, 50):
|
||||||
|
current_radius = _calculate_neighbourhood_radius(iteration, params)
|
||||||
|
if current_radius >= previous_radius:
|
||||||
|
strictly_decreasing = False
|
||||||
|
break
|
||||||
|
previous_radius = current_radius
|
||||||
|
|
||||||
|
assert strictly_decreasing, "The neighbourhood radius is not strictly decreasing over iterations."
|
||||||
|
|
||||||
|
def test_learning_rate_strictly_decreasing():
|
||||||
|
params = KohonenParams(
|
||||||
|
initial_neighbourhood_radius=5.0,
|
||||||
|
initial_learning_rate=0.1,
|
||||||
|
exp_time_const=20.0, # Affects rate of learning rate decrease
|
||||||
|
width=10,
|
||||||
|
height=10,
|
||||||
|
num_iterations=100,
|
||||||
|
feature_size=3
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_rate = _calculate_learning_rate(0, params)
|
||||||
|
strictly_decreasing = True
|
||||||
|
|
||||||
|
# Check the learning rate for the first 50 iterations to ensure it's strictly decreasing
|
||||||
|
for iteration in range(1, 50):
|
||||||
|
current_rate = _calculate_learning_rate(iteration, params)
|
||||||
|
if current_rate >= previous_rate:
|
||||||
|
strictly_decreasing = False
|
||||||
|
break
|
||||||
|
previous_rate = current_rate
|
||||||
|
|
||||||
|
assert strictly_decreasing, "The learning rate is not strictly decreasing as expected."
|
||||||
|
|
||||||
|
@given(
|
||||||
|
data=st.data(),
|
||||||
|
feature_size=st.integers(min_value=1, max_value=3),
|
||||||
|
width=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=50),
|
||||||
|
height=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=50),
|
||||||
|
neighbourhood_radius=st.floats(min_value=0.01, max_value=1.0),
|
||||||
|
learning_rate=st.floats(min_value=0.001, max_value=0.01)
|
||||||
|
)
|
||||||
|
@settings(max_examples=5, deadline=None, suppress_health_check=(HealthCheck.too_slow,))
|
||||||
|
def test_update_weights_properties(data, feature_size, width, height, neighbourhood_radius, learning_rate):
|
||||||
|
X = data.draw(generate_kohonen_samples(feature_size))
|
||||||
|
weights = data.draw(generate_kohonen_weights(width, height, feature_size))
|
||||||
|
|
||||||
|
params = KohonenParams(
|
||||||
|
initial_neighbourhood_radius=5.0,
|
||||||
|
initial_learning_rate=0.1,
|
||||||
|
exp_time_const=20.0,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_iterations=10,
|
||||||
|
feature_size=feature_size
|
||||||
|
)
|
||||||
|
|
||||||
|
x = X[0]
|
||||||
|
updated_weights = _update_weights(x, weights, neighbourhood_radius, learning_rate, params)
|
||||||
|
|
||||||
|
if np.all(x == 0) and np.all(weights == 0):
|
||||||
|
assert np.all(updated_weights == 0), "Weights should be zero when x and initial weights are zero."
|
||||||
|
else:
|
||||||
|
assert not np.array_equal(weights, updated_weights), "Weights should have changed after update"
|
||||||
|
|
||||||
|
# Ensure all weights are updated towards the input vector x
|
||||||
|
original_distances = np.linalg.norm(weights - x, axis=2)
|
||||||
|
new_distances = np.linalg.norm(updated_weights - x, axis=2)
|
||||||
|
assert np.all(new_distances <= original_distances), "All weights should move closer to input vector x"
|
||||||
|
|
||||||
|
@given(
|
||||||
|
data=st.data(),
|
||||||
|
initial_neighbourhood_radius=st.floats(min_value=1, max_value=10),
|
||||||
|
initial_learning_rate=st.floats(min_value=0.01, max_value=10),
|
||||||
|
exp_time_const=st.floats(min_value=10, max_value=100),
|
||||||
|
width=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=50),
|
||||||
|
height=st.integers(min_value=MINIMUM_NETWORK_DIMENSION, max_value=50),
|
||||||
|
num_iterations=st.integers(min_value=0, max_value=100),
|
||||||
|
feature_size=st.integers(min_value=1, max_value=3)
|
||||||
|
)
|
||||||
|
@settings(max_examples=5, deadline=None, suppress_health_check=(HealthCheck.too_slow,))
|
||||||
|
def test_train_kohonen_network_valid_success(data, initial_neighbourhood_radius, initial_learning_rate, exp_time_const, width, height, num_iterations, feature_size):
|
||||||
|
X = data.draw(generate_kohonen_samples(feature_size))
|
||||||
|
|
||||||
|
params = KohonenParams(
|
||||||
|
initial_neighbourhood_radius=initial_neighbourhood_radius,
|
||||||
|
initial_learning_rate=initial_learning_rate,
|
||||||
|
exp_time_const=exp_time_const,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_iterations=num_iterations,
|
||||||
|
feature_size=feature_size
|
||||||
|
)
|
||||||
|
|
||||||
|
result = train_kohonen_network(X, params, use_mlflow=False)
|
||||||
|
|
||||||
|
# Assert that the result is a Success and the data type and shape are correct
|
||||||
|
assert isinstance(result, Success), "Expected the result to be a Success instance"
|
||||||
|
weights = result.unwrap()
|
||||||
|
assert isinstance(weights, np.ndarray), "Expected the result content to be a numpy array"
|
||||||
|
assert weights.dtype == np.float32, "Expected the numpy array to be of type float32"
|
||||||
|
assert weights.shape == (height, width, feature_size), "Unexpected shape of result weights"
|
Loading…
Reference in New Issue
Block a user