13 lines
505 B
Python
13 lines
505 B
Python
from hypothesis import given, strategies as st
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from models.common import initialise_random_tensor
|
|
|
|
# Draw shapes of 1 to 4 axes, each axis of length 1..10. The previous
# ``st.tuples(st.integers(...))`` only ever produced 1-tuples (and its
# ``.map(tuple)`` was a no-op), so multi-dimensional shapes were never tested.
# The bounds cap a generated tensor at 10**4 elements to keep the test fast.
@given(
    dims=st.lists(
        st.integers(min_value=1, max_value=10), min_size=1, max_size=4
    ).map(tuple)
)
def test_initialise_random_tensor_dtype_and_shape(dims):
    """Check that ``initialise_random_tensor`` honours the requested shape and dtype.

    ``dims`` is a tuple of axis lengths generated by Hypothesis. It is unpacked
    into ``initialise_random_tensor``, and the result must be ``float32`` with
    ``shape == dims``.
    """
    tensor = initialise_random_tensor(*dims)

    assert tensor.dtype == np.float32, "The dtype of the tensor should be float32"
    assert tensor.shape == dims, f"Expected tensor shape {dims}, but got {tensor.shape}"