Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
163 commits
Select commit Hold shift + click to select a range
21b1a72
switch to lax import
rhayes777 Feb 20, 2023
4bb32a0
started moving towards arrays as wrappers rather than inheritance
rhayes777 Feb 20, 2023
6899393
more conversion
rhayes777 Feb 20, 2023
90fb220
more conversion
rhayes777 Feb 20, 2023
c18b209
more conversion...
rhayes777 Feb 20, 2023
3fc02f1
Merge branch 'main' into feature/jax
rhayes777 Feb 27, 2023
1e62bc7
new to init
rhayes777 Feb 27, 2023
8732366
implementing comparison operators
rhayes777 Feb 27, 2023
776c4d1
remove need for reduce
rhayes777 Feb 27, 2023
efeaac1
mul and rmul
rhayes777 Feb 27, 2023
048f2f7
array property
rhayes777 Feb 27, 2023
5078d58
explicity pass array reference when plotting
rhayes777 Feb 27, 2023
e874d22
irregular 2d conversion
rhayes777 Feb 27, 2023
2af77ea
float and array property
rhayes777 Feb 27, 2023
7c40da6
div
rhayes777 Feb 27, 2023
ac7b9e6
unwrap array
rhayes777 Feb 27, 2023
bb22438
unwrap arrays
rhayes777 Feb 27, 2023
7c10dcf
to new array
rhayes777 Feb 27, 2023
929428c
abs
rhayes777 Feb 27, 2023
7415559
write array
rhayes777 Feb 27, 2023
175b320
set item...
rhayes777 Feb 27, 2023
4fe0aa3
cast multiplied arrays to new array; now at a new jax error
rhayes777 Feb 27, 2023
ce9e0c8
unwrap when multiplying
rhayes777 Feb 27, 2023
8b3c5f2
fix is_all_false
rhayes777 Feb 27, 2023
f482b19
a
rhayes777 Feb 27, 2023
c4c5366
add and sub
rhayes777 Feb 27, 2023
be4939e
pow
rhayes777 Feb 27, 2023
41c7450
use - instead of numpy for now
rhayes777 Feb 27, 2023
ebc4cd1
cast to iterable
rhayes777 Feb 27, 2023
2c7cfdd
sqrt
rhayes777 Feb 27, 2023
a0e59c5
remove numba import assertion
rhayes777 Feb 27, 2023
30a6ebc
dtype, ndim, max, min
rhayes777 Feb 27, 2023
2b5705a
array representation
rhayes777 Feb 27, 2023
1393c75
__array__
rhayes777 Mar 6, 2023
6ff48b5
ignore files/
rhayes777 Mar 6, 2023
3ca7987
mask 1d new -> init
rhayes777 Mar 6, 2023
897586e
array 1d new init conversion
rhayes777 Mar 6, 2023
d474474
replacing numpy calls with operators to preserve type
rhayes777 Mar 6, 2023
8e0b836
converting more functions...
rhayes777 Mar 6, 2023
86b4d1b
breaking up test and using allclose
rhayes777 Mar 6, 2023
044b3e3
fixing more preprocessing functions...
rhayes777 Mar 6, 2023
bd0df13
new -> init for visibilities
rhayes777 Mar 6, 2023
a8e772a
fixed visibilities init
rhayes777 Mar 6, 2023
328d765
more moving away from np functions
rhayes777 Mar 6, 2023
ae4701c
astype
rhayes777 Mar 6, 2023
6f0f490
grid1d new -> init
rhayes777 Mar 6, 2023
2dfa95e
grid2d iterate new -> init
rhayes777 Mar 6, 2023
f72b9cc
import and format
rhayes777 Mar 6, 2023
0f3d8ce
more new->init
rhayes777 Mar 6, 2023
0b9d2db
real and imag
rhayes777 Mar 6, 2023
1b17c4e
all
rhayes777 Mar 6, 2023
f8e5c15
converting fit util methods. Some used out= but motivation is unclear
rhayes777 Mar 6, 2023
626c453
added with_new_array method and used for two fit fieldds
rhayes777 Mar 13, 2023
3c8a61c
optionally pass dtype when calling abstract nd array's array method
rhayes777 Mar 13, 2023
e76ce36
fix return type
rhayes777 Mar 13, 2023
ed08121
all close
rhayes777 Mar 13, 2023
ffff8ca
to_new_array decorator for remaining fit_util functions
rhayes777 Mar 13, 2023
5df0814
new->init for grid2dsparse
rhayes777 Mar 13, 2023
1bbfa8d
commented out previously broken test
rhayes777 Mar 13, 2023
4b64f88
unwrap other arrays in constructor
rhayes777 Mar 13, 2023
1336e6a
use an allclose
rhayes777 Mar 13, 2023
11d084b
broader type check
rhayes777 Mar 13, 2023
edfb3d1
copy
rhayes777 Mar 13, 2023
0f62afd
use normal numpy for masks (for now)
rhayes777 Mar 13, 2023
c6b63ba
more new to init
rhayes777 Mar 13, 2023
71bc453
more new to init
rhayes777 Mar 13, 2023
8a8e3ee
init to new and generalised type check
rhayes777 Mar 13, 2023
19af0d2
flipped
rhayes777 Mar 13, 2023
9ef41d4
new to init
rhayes777 Mar 13, 2023
2360f5c
new->init all tests pass
rhayes777 Mar 13, 2023
82f6293
convert properties to avoid using masked array as it is not implement…
rhayes777 Mar 13, 2023
607963a
simplification
rhayes777 Mar 13, 2023
c59302c
avoid assignment
rhayes777 Mar 13, 2023
8925aa3
jax does not like casting to float
rhayes777 Mar 13, 2023
9d7f58c
added jax to requirements
rhayes777 Mar 13, 2023
41c6fc7
added jaxlib to requirements
rhayes777 Mar 13, 2023
bb4a667
swap around requirements
rhayes777 Mar 13, 2023
227a48f
fixed scikit image
rhayes777 Mar 13, 2023
b461575
make assertion approx
rhayes777 Mar 13, 2023
f91ea09
Merge branch 'main' into feature/jax
rhayes777 Mar 20, 2023
4e288f1
removed single explicit dependency on jax numpy
rhayes777 Mar 24, 2023
4217b0e
use numpy wrapper for specific function
rhayes777 Mar 24, 2023
a5f970e
use setitem from jax branch
rhayes777 Mar 24, 2023
3d3d646
use numpy wrapper for noise_normalization_from function
rhayes777 Mar 24, 2023
9eee5b3
pytree functions from jax branch
rhayes777 Mar 24, 2023
bfc1be2
constructor from jax branch including implicit pytree registration
rhayes777 Mar 24, 2023
7038757
pytree DeriveIndexes2D
rhayes777 Mar 24, 2023
5b7281c
use `__no_flatten__` class attribute to prevent cyclic references whe…
rhayes777 Mar 24, 2023
2b6500e
try-except on shape; access array property to ensure return type is a…
rhayes777 Mar 24, 2023
90f136d
merge
rhayes777 Mar 24, 2023
6b71ed3
dropped calling .array on likelihood
rhayes777 Apr 17, 2023
9752993
Revert "dropped calling .array on likelihood"
rhayes777 Apr 24, 2023
4f4f90c
switching jax on and off
rhayes777 Apr 24, 2023
7da8d07
merged main
rhayes777 Dec 11, 2023
ab2ed69
numba util
rhayes777 Dec 11, 2023
07b34a3
merge
rhayes777 Dec 11, 2023
46efc92
attempting to fix tests...
rhayes777 Dec 11, 2023
ec30292
took abstract array from jax branch
rhayes777 Dec 11, 2023
91fc6e2
casting masks to arrays for use in numba
rhayes777 Dec 11, 2023
c64010a
more casting arrays
rhayes777 Dec 11, 2023
3a47c04
more casting to array
rhayes777 Dec 11, 2023
78de17f
more casting to array
rhayes777 Dec 11, 2023
80c1c70
consistent attribute name usage in decorator
rhayes777 Dec 11, 2023
3d6c1aa
more casting to array
rhayes777 Dec 11, 2023
8e3b781
more casting to array
rhayes777 Dec 18, 2023
c81c1e1
more casting to array
rhayes777 Dec 18, 2023
43a5f3f
more casting to array
rhayes777 Dec 18, 2023
64a5890
more casting to array
rhayes777 Dec 18, 2023
2fb990b
more casting to array
rhayes777 Dec 18, 2023
ee292df
more casting to array
rhayes777 Dec 18, 2023
c5463f4
more casting to array
rhayes777 Dec 18, 2023
8648a0e
fix a broken reference in init
rhayes777 Dec 18, 2023
a6b32ce
more casting to array
rhayes777 Dec 18, 2023
0bd4ce8
more casting to array
rhayes777 Dec 18, 2023
9c55d5c
more casting to array
rhayes777 Dec 18, 2023
ce054ed
more casting to array
rhayes777 Dec 18, 2023
b006743
lift flip_hdu_for_ds9 method from main branch
rhayes777 Dec 18, 2023
24d145a
more casting to array
rhayes777 Dec 18, 2023
ba599f1
more casting to array
rhayes777 Dec 18, 2023
b79ae22
more casting to array
rhayes777 Dec 18, 2023
1fc20ae
more casting to array
rhayes777 Dec 18, 2023
a3d89e9
more casting to array
rhayes777 Dec 18, 2023
05ae5f0
more casting to array
rhayes777 Dec 18, 2023
e95916b
more casting to array
rhayes777 Dec 18, 2023
120072f
more casting to array
rhayes777 Dec 18, 2023
563f89c
cast to array
rhayes777 Dec 18, 2023
8e8abc0
cast to array
rhayes777 Dec 18, 2023
ddacab7
remove sparse_index_for_slim_index from Grid2DSparse
rhayes777 Dec 18, 2023
4b85374
cast to array
rhayes777 Dec 18, 2023
cc302be
cast to array
rhayes777 Dec 18, 2023
e8c97f3
cast to array
rhayes777 Dec 18, 2023
4ed9b8c
cast to array
rhayes777 Dec 18, 2023
07db622
fixed remaining tests
rhayes777 Dec 18, 2023
4cd3ace
merge main
rhayes777 Dec 18, 2023
cc989b4
ensure tests pass without jax being installed
rhayes777 Dec 18, 2023
a1669cf
Grid2DTransformedNumpy as AbstractNDArray to fix autogalaxy
rhayes777 Dec 18, 2023
7cb6cc1
a slice of an array of a given type is still of that same type
rhayes777 Dec 18, 2023
912874e
by default get attributes from the underlying array
rhayes777 Dec 18, 2023
881edc2
fix autogalaxy test by casting transformed grid back to grid
rhayes777 Dec 18, 2023
b2e18f0
more casting to array
rhayes777 Dec 18, 2023
dcbb09b
fix noise map function
rhayes777 Dec 18, 2023
3b4e467
fix
rhayes777 Dec 18, 2023
481c70d
merge main
rhayes777 Jan 8, 2024
6b085df
fixes
rhayes777 Jan 8, 2024
2578532
changes to support autolens...
rhayes777 Jan 8, 2024
fc7047a
more casting to array so that indexed/sliced arrays have numba compli…
rhayes777 Jan 8, 2024
4642432
more fixes
rhayes777 Jan 8, 2024
6137d03
fix test
rhayes777 Jan 8, 2024
8c846c9
ensure underlying array gets copied
rhayes777 Jan 8, 2024
054567c
ensure underlying array is always copied
rhayes777 Jan 8, 2024
8ca6a89
invert and copy
rhayes777 Jan 8, 2024
1e24c9b
ensure copying works
rhayes777 Jan 17, 2024
45dcb52
in place multiply
rhayes777 Jan 17, 2024
cef9a68
deepcopy
rhayes777 Jan 17, 2024
1cf6766
change where array conversion occurs
rhayes777 Jan 17, 2024
ee057f7
debugging - revert
rhayes777 Jan 22, 2024
f5ed301
do not cache derive_indexes property
rhayes777 Jan 22, 2024
1173f69
Revert "debugging - revert"
rhayes777 Jan 22, 2024
deddad3
Merge branch 'main' into feature/jax_merge
rhayes777 Jan 22, 2024
a23e419
fix issue created by merge
rhayes777 Jan 22, 2024
760dd91
docs
rhayes777 Jan 22, 2024
6de2446
Merge branch 'main' into feature/jax_merge
rhayes777 Jan 29, 2024
ae7e488
fixed numba calls
rhayes777 Jan 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
files/
test_autoarray/dataset/files/array/output_test/uv_wavelengths.fits
test_autoarray/dataset/files/array/output_test/visibilities.fits
test_autoarray/dataset/plot/files/
Expand Down
348 changes: 326 additions & 22 deletions autoarray/abstract_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,272 @@
from __future__ import annotations

from copy import copy

from abc import ABC
from abc import abstractmethod
import numpy as np
from pathlib import Path
from typing import Union, TYPE_CHECKING

from autoconf import conf
from autoarray.numpy_wrapper import numpy as npw, register_pytree_node, Array

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from autoarray.structures.abstract_structure import Structure

from autoarray.structures.arrays import array_2d_util
from autoconf import conf


class AbstractNDArray(np.ndarray, ABC):
def __reduce__(self):
pickled_state = super().__reduce__()
def to_new_array(func):
"""
Decorator for functions that returns an array. The array is wrapped in a new instance of the class.

class_dict = {}
for key, value in self.__dict__.items():
class_dict[key] = value
new_state = pickled_state[2] + (class_dict,)
Parameters
----------
func
The function to be decorated.

return pickled_state[0], pickled_state[1], new_state
Returns
-------
The decorated function.
"""

# noinspection PyMethodOverriding
def __setstate__(self, state):
for key, value in state[-1].items():
setattr(self, key, value)
super().__setstate__(state[0:-1])
def wrapper(self, *args, **kwargs) -> "AbstractNDArray":
return self.with_new_array(func(self, *args, **kwargs))

@property
@abstractmethod
def native(self) -> Structure:
return wrapper


def unwrap_array(func):
"""
Decorator for functions that take an array as an argument. If the argument is an AbstractNDArray, the underlying
array is used instead.

Parameters
----------
func
The function to be decorated.

Returns
-------
The decorated function.
"""

def wrapper(self, other):
try:
return func(self, other.array)
except AttributeError:
return func(self, other)

return wrapper


class AbstractNDArray(ABC):
def __init__(self, array):
while isinstance(array, AbstractNDArray):
array = array.array
self._array = array
try:
register_pytree_node(
type(self),
self.instance_flatten,
self.instance_unflatten,
)
except ValueError:
pass

__no_flatten__ = ()

def invert(self):
new = self.copy()
new._array = np.invert(new._array)
return new

@classmethod
def instance_flatten(cls, instance):
"""
Returns the data structure in its `native` format which contains all unmaksed values to the native dimensions.
Flatten an instance of an autoarray class into a tuple of its attributes (i.e.. a pytree)
"""
keys, values = zip(
*sorted(
{
key: value
for key, value in instance.__dict__.items()
if key not in cls.__no_flatten__
}.items()
)
)
return values, keys

@staticmethod
def flip_hdu_for_ds9(values):
if conf.instance["general"]["fits"]["flip_for_ds9"]:
return np.flipud(values)
return values

def output_to_fits(self, file_path: Union[Path, str], overwrite: bool = False):
@classmethod
def instance_unflatten(cls, aux_data, children):
"""
Unflatten a tuple of attributes (i.e. a pytree) into an instance of an autoarray class
"""
instance = cls.__new__(cls)
for key, value in zip(aux_data, children[1:]):
setattr(instance, key, value)
return instance

def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
"""
Copy this object but give it a new array.

This is used to ensure that when an array is modified, associated
attributes such as pixel size are retained.

Parameters
----------
array
The new array that is given to the copied object.

Returns
-------

"""
new_array = self.copy()
new_array._array = array
return new_array

def copy(self):
new = copy(self)
return new

def __copy__(self):
"""
When copying an autoarray also copy its underlying array.
"""
new = self.__new__(self.__class__)
new.__dict__.update(self.__dict__)
new._array = self._array.copy()
return new

def __deepcopy__(self, memo):
"""
When copying an autoarray also copy its underlying array.
"""
new = self.__new__(self.__class__)
new.__dict__.update(self.__dict__)
new._array = self._array.copy()
return new

def __iter__(self):
return iter(self._array)

@to_new_array
def sqrt(self):
return np.sqrt(self._array)

@property
def array(self):
return self._array

@unwrap_array
def __lt__(self, other):
return self._array < other

@unwrap_array
def __le__(self, other):
return self._array <= other

@unwrap_array
def __gt__(self, other):
return self._array > other

@unwrap_array
def __ge__(self, other):
return self._array >= other

@unwrap_array
def __eq__(self, other):
return self._array == other

@to_new_array
@unwrap_array
def __pow__(self, other):
return self._array**other

@to_new_array
@unwrap_array
def __add__(self, other):
return self._array + other

@to_new_array
@unwrap_array
def __radd__(self, other):
return other + self._array

@to_new_array
@unwrap_array
def __sub__(self, other):
return self._array - other

@to_new_array
@unwrap_array
def __rsub__(self, other):
return other - self._array

@unwrap_array
def __ne__(self, other):
return self._array != other

@to_new_array
@unwrap_array
def __mul__(self, other):
return self._array * other

@to_new_array
@unwrap_array
def __rmul__(self, other):
return other * self._array

@to_new_array
def __neg__(self):
return -self._array

def __invert__(self):
return ~self._array

def __divmod__(self, other):
return divmod(self._array, other)

def __rdivmod__(self, other):
return divmod(other, self._array)

@to_new_array
@unwrap_array
def __truediv__(self, other):
return self._array / other

@to_new_array
@unwrap_array
def __rtruediv__(self, other):
return other / self._array

@to_new_array
def __abs__(self):
return abs(self._array)

def sum(self, *args, **kwargs):
return self._array.sum(*args, **kwargs)

def __float__(self):
return float(self._array)

@property
@abstractmethod
def native(self) -> Structure:
"""
Returns the data structure in its `native` format which contains all unmaksed values to the native dimensions.
"""

def output_to_fits(self, file_path: str, overwrite: bool = False):
"""
Output the grid to a .fits file.

Expand All @@ -55,5 +278,86 @@ def output_to_fits(self, file_path: Union[Path, str], overwrite: bool = False):
If a file already exists at the path, if overwrite=True it is overwritten else an error is raised.
"""
array_2d_util.numpy_array_2d_to_fits(
array_2d=self.native, file_path=file_path, overwrite=overwrite
array_2d=self.native.array, file_path=file_path, overwrite=overwrite
)

@property
def shape(self):
try:
return self._array.shape
except AttributeError:
return ()

@property
def size(self):
return self._array.size

@property
def dtype(self):
return self._array.dtype

@property
def ndim(self):
return self._array.ndim

def max(self, *args, **kwargs):
return self._array.max(*args, **kwargs)

def min(self, *args, **kwargs):
return self._array.min(*args, **kwargs)

@to_new_array
def reshape(self, *args, **kwargs):
return self._array.reshape(*args, **kwargs)

def __getattr__(self, item):
if item != "__setstate__":
try:
return getattr(self._array, item)
except AttributeError:
pass
raise AttributeError(
f"{self.__class__.__name__} does not have attribute {item}"
)

def __getitem__(self, item):
result = self._array[item]
if isinstance(item, slice):
result = self.with_new_array(result)
if isinstance(result, np.ndarray):
result = self.with_new_array(result)
return result

def __setitem__(self, key, value):
if isinstance(key, (np.ndarray, AbstractNDArray, Array)):
self._array = npw.where(key, value, self._array)
else:
self._array[key] = value

def __repr__(self):
return f"{self.__class__.__name__} {self.shape}"

def __array__(self, dtype=None):
if dtype:
return self._array.astype(dtype)
return self._array

def __len__(self):
return len(self._array)

@to_new_array
def astype(self, dtype):
return self._array.astype(dtype)

@property
@to_new_array
def real(self):
return self._array.real

@property
@to_new_array
def imag(self):
return self._array.imag

def all(self):
return self._array.all()
Loading