Source code for atomworks.ml.utils.numpy

"""General utility functions for working with numpy arrays."""

import networkx as nx
import numpy as np


def select_data_by_id(
    select_ids: np.ndarray,
    data_ids: np.ndarray,
    data: np.ndarray,
    axis: int = 0,
) -> np.ndarray:
    """
    Select data from an array based on matching IDs.

    Args:
        select_ids (np.ndarray): Array of IDs to select.
        data_ids (np.ndarray): Array of IDs corresponding to the data.
        data (np.ndarray): Data array from which to select.
        axis (int, optional): Axis along which to select data. Defaults to 0.

    Returns:
        np.ndarray: Array of selected data.

    Raises:
        AssertionError: If the shape of `data` along `axis` does not match the length of `data_ids`.
        AssertionError: If `data_ids` contains duplicate values.

    Example:
        >>> to_ids = np.array([1, 5, 2, 20, 20, 2])
        >>> from_array = np.arange(10).repeat(6).reshape(10, 6)
        >>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
        >>> select_data_by_id(to_ids, from_ids, from_array)
        array([[0, 0, 0, 0, 0, 0],
               [2, 2, 2, 2, 2, 2],
               [1, 1, 1, 1, 1, 1],
               [8, 8, 8, 8, 8, 8],
               [8, 8, 8, 8, 8, 8],
               [1, 1, 1, 1, 1, 1]])
    """
    assert data.shape[axis] == len(
        data_ids
    ), f"`data` must have `len(data_ids)` entries along axis `{axis}`, but got shape: {data.shape}"
    assert np.unique(data_ids).size == len(
        data_ids
    ), f"`data_ids` must be unique, but got duplicates in {data_ids}"
    # ...map each ID to its positional (integer) index in `data_ids`, then gather along `axis`
    id_to_idx = np.vectorize({id_: idx for idx, id_ in enumerate(data_ids)}.__getitem__, otypes=[np.intp])
    idxs_to_select = id_to_idx(select_ids)
    return np.take(data, idxs_to_select, axis=axis)
def insert_data_by_id_(
    to_fill: np.ndarray,
    to_fill_ids: np.ndarray,
    from_data: np.ndarray,
    from_data_ids: np.ndarray,
    axis: int = 0,
) -> None:
    """Insert data into an array, in place, based on matching IDs.

    Entries of `to_fill_ids` that do not appear in `from_data_ids` are left untouched.

    Args:
        to_fill (np.ndarray): Array to be filled (modified in place).
        to_fill_ids (np.ndarray): Array of IDs corresponding to `to_fill`.
        from_data (np.ndarray): Data array from which to insert.
        from_data_ids (np.ndarray): Array of IDs corresponding to `from_data`.
        axis (int, optional): Axis along which to insert data. Defaults to 0.

    Returns:
        None: `to_fill` is modified in place.

    Example:
        >>> to_array = np.zeros((7, 6))
        >>> to_ids = np.array([1, 5, 2, 17, 20, 20, 2])
        >>> from_array = np.arange(10).repeat(6).reshape(10, 6)
        >>> from_ids = np.array([1, 2, 5, 6, 7, 21, 22, 23, 20, 25])
        >>> insert_data_by_id_(to_array, to_ids, from_array, from_ids)
        >>> to_array
        array([[0., 0., 0., 0., 0., 0.],
               [2., 2., 2., 2., 2., 2.],
               [1., 1., 1., 1., 1., 1.],
               [0., 0., 0., 0., 0., 0.],
               [8., 8., 8., 8., 8., 8.],
               [8., 8., 8., 8., 8., 8.],
               [1., 1., 1., 1., 1., 1.]])
    """
    # ...only IDs that are present in `from_data_ids` can be filled
    is_available = np.isin(to_fill_ids, from_data_ids)
    # Build an indexer that applies the availability mask along `axis`
    indexer = [slice(None)] * to_fill.ndim
    indexer[axis] = is_available
    to_fill[tuple(indexer)] = select_data_by_id(to_fill_ids[is_available], from_data_ids, from_data, axis=axis)
def unique_by_first_occurrence(arr: np.ndarray) -> np.ndarray:
    """
    Return unique elements of an array while preserving the order of their first occurrence.

    Args:
        arr (np.ndarray): Input array.

    Returns:
        np.ndarray: Array of unique elements in the order of their first occurrence.

    Example:
        >>> unique_by_first_occurrence(np.array([4, 2, 2, 3, 1, 4]))
        array([4, 2, 3, 1])
    """
    _, idx = np.unique(arr, return_index=True)
    return arr[np.sort(idx, kind="mergesort")]
def is_mask_contiguous(mask: np.ndarray) -> bool:
    """Check whether the True entries of a boolean mask form a single contiguous block.

    An all-True or all-False mask is considered contiguous.
    """
    if np.all(mask == mask[0]):
        return True
    # Pad with False on both sides so every block of True values produces exactly one
    # rising (+1) and one falling (-1) edge in the first difference
    mask_with_two_falses = np.concatenate([[False], mask, [False]]).astype(np.int8)
    diffs = np.diff(mask_with_two_falses)
    vals, counts = np.unique(diffs, return_counts=True)
    # A single contiguous block yields exactly one +1 edge and one -1 edge
    return np.array_equal(vals, [-1, 0, 1]) and counts[0] == 1 and counts[2] == 1
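# Illustrative sketch (not part of the original module): a quick check of
# `is_mask_contiguous` on a few hand-built masks, assuming the function above is in scope.
_contiguous_mask = np.array([False, True, True, True, False])
_split_mask = np.array([True, False, True])
assert is_mask_contiguous(_contiguous_mask)  # a single block of True values
assert not is_mask_contiguous(_split_mask)  # True values separated by a False
assert is_mask_contiguous(np.array([False, False]))  # an all-False mask counts as contiguous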
def get_connected_components_from_adjacency(adjacency: np.ndarray) -> list[np.ndarray]:
    """
    Return a list of index arrays, one per connected component of the graph described by the
    given adjacency matrix.
    """
    graph = nx.from_numpy_array(adjacency)
    return [np.array(list(component)) for component in nx.connected_components(graph)]
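# Illustrative sketch (not part of the original module): connected components of a small
# symmetric adjacency matrix, assuming the function above is in scope. Nodes 0 and 1 are
# linked and node 2 is isolated, so we expect two components.
_adjacency = np.array(
    [
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 0],
    ]
)
_components = get_connected_components_from_adjacency(_adjacency)
# e.g. [array([0, 1]), array([2])]; component ordering follows networkx and is not guaranteed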
def not_isin(element: np.ndarray, test_element: np.ndarray) -> np.ndarray:
    """
    Return a boolean mask indicating where elements of `element` are not in `test_element`.

    Args:
        element (np.ndarray): Array to check.
        test_element (np.ndarray): Array to check against.

    Returns:
        np.ndarray: Boolean mask.

    Example:
        >>> not_isin(np.array([1, 2, 3, 4, 5]), np.array([2, 4, 6]))
        array([ True, False,  True, False,  True])
    """
    return np.isin(element, test_element, invert=True)
def get_nearest_true_index_for_each_false(arr: np.ndarray) -> np.ndarray:
    """
    Get the index of the nearest True for each False in the array, breaking ties by choosing the
    nearest True to the left.

    Args:
        arr (np.ndarray): A boolean numpy array.

    Returns:
        np.ndarray: An array of length `np.sum(~arr)` where each entry is the index of the nearest True.

    Example:
        >>> arr = np.array([False, True, True, False, False, True, False])
        >>> get_nearest_true_index_for_each_false(arr)
        array([1, 2, 5, 5])
    """
    # ...find the indices where the values are True and False
    true_indices = np.where(arr)[0]
    false_indices = np.where(~arr)[0]

    # Short-circuit if there are no True entries or no False entries, as we can't proceed
    if len(true_indices) == 0 or len(false_indices) == 0:
        return np.array([])

    # ...for each False entry, find the index of the nearest True.
    # Broadcasting gives the full distance matrix (an outer absolute difference):
    # the (i, j) entry is the distance between the i-th False and the j-th True.
    distances = np.abs(false_indices[:, np.newaxis] - true_indices)

    # np.argmin breaks ties by choosing the first occurrence, i.e. the leftmost True
    nearest_true_indices = true_indices[np.argmin(distances, axis=1)]

    return nearest_true_indices
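# Illustrative sketch (not part of the original module): the intermediate distance matrix
# for the docstring example above, reproduced with the same broadcasting step.
_false_positions = np.array([0, 3, 4, 6])  # indices of False in the docstring example
_true_positions = np.array([1, 2, 5])  # indices of True in the docstring example
_distances = np.abs(_false_positions[:, np.newaxis] - _true_positions)
# _distances ==
# [[1 2 5]
#  [2 1 2]
#  [3 2 1]
#  [5 4 1]]
# argmin along each row picks columns [0, 1, 2, 2], and indexing back into the True positions
# gives array([1, 2, 5, 5]). A tie would resolve to the leftmost True, since np.argmin returns
# the first minimum and the True positions are sorted ascending.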
def get_indices_of_non_constant_columns(arr: np.ndarray) -> np.ndarray:
    """Identify columns where values change between consecutive rows.

    Args:
        arr (np.ndarray): A 2D NumPy array in which to find columns with changing values.

    Returns:
        np.ndarray: An array of column indices where values change between consecutive rows.
            A column index appears once for each pair of consecutive rows in which it changes.

    Example:
        >>> arr = np.array(
        ...     [
        ...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161],
        ...         [151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 160],
        ...     ]
        ... )
        >>> get_indices_of_non_constant_columns(arr)
        array([ 9, 10])
    """
    # Compute the differences between consecutive rows
    differences = np.diff(arr, axis=0)

    # Get the (row, column) indices where the differences are non-zero
    changing_indices = np.nonzero(differences)

    # Return the column indices where changes occur
    return changing_indices[1]