3.4. Solver
FitSNAP uses a Solver class which is the parent of all the different types of solvers, e.g. SVD and ARD for linear regression, PYTORCH and JAX for neural networks, etc.
- class fitsnap3lib.solvers.solver.Solver(name, pt, config, linear=True)
This class declares the method to solve the machine learning problem, e.g. linear regression, nonlinear regression, etc.
- fit
NumPy array containing the coefficients of the fit.
- error_analysis(a=None, b=None, w=None, fs_dict=None)
If linear fit: extracts and stores fitting data, such as descriptor values, truth values, and predictions, into a Pandas dataframe.
If nonlinear fit: evaluates the NN on all configurations to obtain the values needed for error calculation.
The optional arguments are for calculating errors on a given set of inputs. For linear models these inputs are the A matrix, the truth array, the weights, and an fs dictionary which contains group info. Care must be taken to ensure that these data structures are already processed and aligned row by row.
- Parameters:
a – Optional A matrix (NumPy array).
b – Optional truth array (NumPy array).
w – Optional weight array (NumPy array).
fs_dict – Optional fs dictionary, such as fs.pt.fitsnap_dict.
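For a linear model, the error calculation amounts to comparing the predictions a @ fit with the truth array. The sketch below assumes NumPy arrays, uses the weights only to skip rows with zero weight, and groups rows by a label array. The helper name error_metrics and its groups argument are hypothetical stand-ins for the group information carried in fs_dict; they are not part of the FitSNAP API.

```python
import numpy as np


def error_metrics(a, b, w, fit, groups):
    """Hypothetical helper: per-group MAE and RMSE of a linear model.

    a      -- design matrix, shape (n_rows, n_coeffs)
    b      -- truth array, shape (n_rows,)
    w      -- weight array, shape (n_rows,); rows with w == 0 are skipped
    fit    -- coefficient array, shape (n_coeffs,)
    groups -- group label for each row, shape (n_rows,)
    """
    residual = a @ fit - b  # prediction minus truth, one entry per row
    results = {}
    for g in np.unique(groups):
        mask = (groups == g) & (w != 0)
        if not mask.any():
            continue  # nothing to evaluate for this group
        r = residual[mask]
        results[g] = {
            "mae": float(np.mean(np.abs(r))),
            "rmse": float(np.sqrt(np.mean(r ** 2))),
        }
    return results
```

Reporting errors per group mirrors the idea that fs_dict carries group info, so that poorly fitted subsets of the training data stand out instead of being averaged away.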
- perform_fit()
Base class function for performing a fit.
- prepare_data(a, b, w, fs_dict)
Prepares a and b for fitting by multiplying each row of a and b by the corresponding entry of the weight array w.
- Parameters:
a (np.array) – design matrix
b (np.array) – truth array
w (np.array) – weight array
fs_dict (dict) – dictionary with a Testing key holding booleans that mark which structures to test on.
- Returns:
design matrix and truth array multiplied by weights.
- Return type:
aw, bw (np.array)
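A minimal sketch of the weighting step, assuming a is 2-D, b and w are 1-D with one entry per row, and the Testing entries of fs_dict are passed in directly as the hypothetical testing argument (rows marked True are held out of the returned arrays). This illustrates the operation described above; it is not FitSNAP's exact code.

```python
import numpy as np


def prepare_data(a, b, w, testing=None):
    """Sketch: drop held-out rows, then scale each row of a and b by w.

    testing -- hypothetical per-row booleans; True marks a test row.
    """
    if testing is not None:
        train = ~np.asarray(testing, dtype=bool)
        a, b, w = a[train], b[train], w[train]
    aw = w[:, np.newaxis] * a  # broadcast one weight across each row of a
    bw = w * b                 # elementwise weighting of the truth array
    return aw, bw
```

Scaling the rows by w turns an ordinary least-squares solve on (aw, bw) into a weighted least-squares fit of the original a and b.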
Specific solvers inherit from the base Solver class.
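The inheritance pattern can be illustrated with a hypothetical, heavily simplified pair of classes. The constructor here takes only name and linear, unlike the real Solver(name, pt, config, linear=True), and LeastSquares is an invented subclass used purely for illustration.

```python
import numpy as np


class Solver:
    """Hypothetical, simplified stand-in for the base Solver class."""

    def __init__(self, name, linear=True):
        self.name = name
        self.linear = linear
        self.fit = None  # coefficients are stored here by perform_fit()

    def perform_fit(self, a, b):
        # The base class only declares the interface.
        raise NotImplementedError("subclasses must implement perform_fit")


class LeastSquares(Solver):
    """Invented subclass that supplies a concrete perform_fit."""

    def perform_fit(self, a, b):
        self.fit, *_ = np.linalg.lstsq(a, b, rcond=None)
        return self.fit
```

Code that drives a fit only needs to call perform_fit() and read .fit, so swapping one solver for another does not change the calling code.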
3.4.1. SVD
This class is for performing SVD fits on linear systems.
- class fitsnap3lib.solvers.svd.SVD(name, pt, config)
- perform_fit(a=None, b=None, w=None, fs_dict=None, trainall=False)
Perform fit on a linear system. If no arguments are supplied, the fitting data in pt.shared_arrays is used.
- Parameters:
a (np.array) – Optional “A” matrix.
b (np.array) – Optional truth array.
w (np.array) – Optional Weight array.
fs_dict (dict) – Optional dictionary containing a Testing key that marks which A matrix rows should not be trained on.
trainall (bool) – Optional boolean declaring whether to train on all samples in the A matrix.
The fit is stored in the member fs.solver.fit.
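As a rough illustration of what an SVD fit computes, the sketch below solves the weighted least-squares problem with a truncated pseudo-inverse. The function name svd_fit and the rcond cutoff are assumptions made for this example; FitSNAP's own implementation may treat small singular values differently.

```python
import numpy as np


def svd_fit(aw, bw, rcond=1e-10):
    """Least-squares coefficients via a truncated-SVD pseudo-inverse.

    Solves min ||aw @ x - bw||^2 using aw = U S V^T, so x = V S^+ U^T bw.
    """
    u, s, vt = np.linalg.svd(aw, full_matrices=False)
    # Singular values come back sorted in descending order; drop the ones
    # that are tiny relative to the largest to stabilize the inverse.
    keep = s > rcond * s[0]
    s_inv = np.zeros_like(s)
    s_inv[keep] = 1.0 / s[keep]
    return vt.T @ (s_inv * (u.T @ bw))
```

Truncating small singular values is what keeps an SVD solve stable when descriptor columns are nearly linearly dependent.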
3.4.2. RIDGE
This class is for performing ridge regression fits on linear systems.
- class fitsnap3lib.solvers.ridge.RIDGE(name, pt, config)
- perform_fit(a=None, b=None, w=None, fs_dict=None, trainall=False)
Perform fit on a linear system. If no arguments are supplied, the fitting data in pt.shared_arrays is used.
- Parameters:
a (np.array) – Optional “A” matrix.
b (np.array) – Optional truth array.
w (np.array) – Optional Weight array.
fs_dict (dict) – Optional dictionary containing a Testing key that marks which A matrix rows should not be trained on.
trainall (bool) – Optional boolean declaring whether to train on all samples in the A matrix.
The fit is stored in the member fs.solver.fit.
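Ridge regression adds a penalty, alpha times the squared norm of the coefficients, to the least-squares objective. Below is a closed-form sketch assuming a hypothetical regularization strength alpha and no intercept term; it is not necessarily how the RIDGE class computes its solution.

```python
import numpy as np


def ridge_fit(aw, bw, alpha=1.0):
    """Ridge coefficients: minimize ||aw @ x - bw||^2 + alpha * ||x||^2.

    Setting the gradient to zero gives (aw^T aw + alpha I) x = aw^T bw.
    """
    aw = np.asarray(aw, dtype=float)
    bw = np.asarray(bw, dtype=float)
    n = aw.shape[1]
    lhs = aw.T @ aw + alpha * np.eye(n)
    rhs = aw.T @ bw
    return np.linalg.solve(lhs, rhs)  # lhs is positive definite for alpha > 0
```

Larger alpha shrinks the coefficients toward zero; as alpha approaches zero the result approaches the ordinary least-squares solution.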
3.4.3. PYTORCH
This class inherits from the Solver class, since it is a particular solver option.
- class fitsnap3lib.solvers.pytorch.PYTORCH(name)
Dummy class for the solver factory to read when torch cannot be imported.
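A dummy class like this is a common way to keep an optional dependency optional: the module always defines PYTORCH, so a lookup by name never fails, and the error only surfaces when someone actually tries to build the solver. A sketch of the pattern follows; the registry SOLVERS and the function make_solver are hypothetical, and the torch branch is only a placeholder for the real solver.

```python
try:
    import torch  # the real solver depends on this package
    _TORCH_ERROR = None
except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
    torch = None
    _TORCH_ERROR = exc  # keep a reference; `exc` is unbound after the block

if torch is not None:
    class PYTORCH:
        """Placeholder for the real solver, which does far more."""

        def __init__(self, name):
            self.name = name
else:
    class PYTORCH:
        """Dummy class: exists so the factory can still find the name."""

        def __init__(self, name):
            raise ImportError(f"solver {name!r} requires torch") from _TORCH_ERROR


SOLVERS = {"PYTORCH": PYTORCH}  # hypothetical name -> class registry


def make_solver(name):
    """Hypothetical factory: look the class up by name and instantiate it."""
    return SOLVERS[name](name)
```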