StochTree 0.0.1
StochTree::ForestDataset Class Reference

API for loading and accessing data used to sample tree ensembles. The covariates, bases, and weights used in sampling forests are stored internally as a ForestDataset by the sampling functions (see Forest Sampler API).

#include <data.h>

Public Member Functions

 ForestDataset ()
 Default constructor. No data is loaded at construction time.
 
void AddCovariates (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
 Copy / load covariates from a raw memory buffer (often a pointer to data in an R matrix or NumPy array)
 
void AddBasis (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
 Copy / load basis matrix from a raw memory buffer (often a pointer to data in an R matrix or NumPy array)
 
void AddVarianceWeights (double *data_ptr, data_size_t num_row)
 Copy / load variance weights from a raw memory buffer (often a pointer to data in an R vector or NumPy array)
 
void AddCovariatesFromCSV (std::string filename, std::string column_index_string, bool header=true, bool precise_float_parser=false)
 Copy / load covariates from CSV file.
 
void AddBasisFromCSV (std::string filename, std::string column_index_string, bool header=true, bool precise_float_parser=false)
 Copy / load basis matrix from CSV file.
 
void AddVarianceWeightsFromCSV (std::string filename, int32_t column_index, bool header=true, bool precise_float_parser=false)
 Copy / load variance / case weights from CSV file.
 
bool HasCovariates ()
 Whether or not a ForestDataset has (yet) loaded covariate data.
 
bool HasBasis ()
 Whether or not a ForestDataset has (yet) loaded basis data.
 
bool HasVarWeights ()
 Whether or not a ForestDataset has (yet) loaded variance weights.
 
data_size_t NumObservations ()
 Number of observations (rows) in the dataset.
 
int NumCovariates ()
 Number of covariate columns in the dataset.
 
int NumBasis ()
 Number of bases in the dataset. This is 0 if the dataset has not been provided a basis matrix.
 
double CovariateValue (data_size_t row, int col)
 Returns a dataset's covariate value stored at (row, col)
 
double BasisValue (data_size_t row, int col)
 Returns a dataset's basis value stored at (row, col)
 
double VarWeightValue (data_size_t row)
 Returns a dataset's variance weight stored at element row
 
Eigen::MatrixXd & GetCovariates ()
 Return a reference to the raw Eigen::MatrixXd storing the covariate data.
 
Eigen::MatrixXd & GetBasis ()
 Return a reference to the raw Eigen::MatrixXd storing the basis data.
 
Eigen::VectorXd & GetVarWeights ()
 Return a reference to the raw Eigen::VectorXd storing the variance weights.
 
void UpdateBasis (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
 Update the data in the internal basis matrix to new values stored in a raw double array.
 
void UpdateVarWeights (double *data_ptr, data_size_t num_row, bool exponentiate=true)
 Update the data in the internal variance weight vector to new values stored in a raw double array.
 
void SetCovariateValue (data_size_t row_id, int col, double new_value)
 Update an observation in the internal covariate matrix to a new value.
 
void SetBasisValue (data_size_t row_id, int col, double new_value)
 Update an observation in the internal basis matrix to a new value.
 
void SetVarWeightValue (data_size_t row_id, double new_value, bool exponentiate=true)
 Update an observation in the internal variance weight vector to a new value.
 

Detailed Description

API for loading and accessing data used to sample tree ensembles. The covariates, bases, and weights used in sampling forests are stored internally as a ForestDataset by the sampling functions (see Forest Sampler API).

Member Function Documentation

◆ AddCovariates()

void StochTree::ForestDataset::AddCovariates (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
inline

Copy / load covariates from a raw memory buffer (often a pointer to data in an R matrix or NumPy array)

Parameters
data_ptr: Pointer to the first element of a contiguous array of data storing a covariate matrix
num_row: Number of rows in the covariate matrix
num_col: Number of columns / covariates in the covariate matrix
is_row_major: True if the data in data_ptr are stored in row-major order, false if they are stored in column-major order
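
The is_row_major flag matters because the same buffer describes two different matrices depending on how it is read: R matrices are column-major, while NumPy arrays are row-major by default. The sketch below illustrates only the index arithmetic involved. CopyToColumnMajor is a hypothetical helper, not part of StochTree, and it targets column-major output because that is the default layout of Eigen::MatrixXd.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of StochTree): copy a num_row x num_col matrix
// from a raw buffer into column-major storage, reading the buffer as either
// row-major or column-major depending on is_row_major.
std::vector<double> CopyToColumnMajor(const double* data_ptr, std::size_t num_row,
                                      std::size_t num_col, bool is_row_major) {
  std::vector<double> out(num_row * num_col);
  for (std::size_t i = 0; i < num_row; ++i) {
    for (std::size_t j = 0; j < num_col; ++j) {
      // Row-major: each row is contiguous. Column-major: each column is contiguous.
      const std::size_t src = is_row_major ? i * num_col + j : j * num_row + i;
      out[j * num_row + i] = data_ptr[src];
    }
  }
  return out;
}
```

Passing the wrong flag does not fail loudly: the copy succeeds, but the rows and columns of the resulting matrix are scrambled, so the flag should match the layout of the caller's array.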

◆ AddBasis()

void StochTree::ForestDataset::AddBasis (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
inline

Copy / load basis matrix from a raw memory buffer (often a pointer to data in an R matrix or NumPy array)

Parameters
data_ptr: Pointer to the first element of a contiguous array of data storing a basis matrix
num_row: Number of rows in the basis matrix
num_col: Number of columns in the basis matrix
is_row_major: True if the data in data_ptr are stored in row-major order, false if they are stored in column-major order

◆ AddVarianceWeights()

void StochTree::ForestDataset::AddVarianceWeights (double *data_ptr, data_size_t num_row)
inline

Copy / load variance weights from a raw memory buffer (often a pointer to data in an R vector or NumPy array)

Parameters
data_ptr: Pointer to the first element of a contiguous array of data storing weights
num_row: Number of rows in the weight vector

◆ AddCovariatesFromCSV()

void StochTree::ForestDataset::AddCovariatesFromCSV (std::string filename, std::string column_index_string, bool header=true, bool precise_float_parser=false)
inline

Copy / load covariates from CSV file.

Parameters
filename: Name of the file (including any necessary path prefixes)
column_index_string: Comma-delimited string listing columns to extract into covariates matrix
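
To make the column_index_string format concrete, here is a minimal sketch of how a comma-delimited column list could be split into integers. ParseColumnIndices is a hypothetical helper, not StochTree's parser, and it assumes the columns are given as integer indices (for example "0,2,5"). Whether indices are zero-based is not stated in this reference, so check against your data before relying on it.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not part of StochTree): split a comma-delimited list
// of integer column indices, e.g. "0,2,5", into a vector of ints.
// Assumption: every non-empty token is a valid integer.
std::vector<int> ParseColumnIndices(const std::string& column_index_string) {
  std::vector<int> indices;
  std::stringstream stream(column_index_string);
  std::string token;
  while (std::getline(stream, token, ',')) {
    // Skip empty tokens such as those produced by "0,,2".
    if (!token.empty()) indices.push_back(std::stoi(token));
  }
  return indices;
}
```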

◆ AddBasisFromCSV()

void StochTree::ForestDataset::AddBasisFromCSV (std::string filename, std::string column_index_string, bool header=true, bool precise_float_parser=false)
inline

Copy / load basis matrix from CSV file.

Parameters
filename: Name of the file (including any necessary path prefixes)
column_index_string: Comma-delimited string listing columns to extract into basis matrix

◆ AddVarianceWeightsFromCSV()

void StochTree::ForestDataset::AddVarianceWeightsFromCSV (std::string filename, int32_t column_index, bool header=true, bool precise_float_parser=false)
inline

Copy / load variance / case weights from CSV file.

Parameters
filename: Name of the file (including any necessary path prefixes)
column_index: Integer index of column containing weights

◆ CovariateValue()

double StochTree::ForestDataset::CovariateValue (data_size_t row, int col)
inline

Returns a dataset's covariate value stored at (row, col)

Parameters
row: Row number to query in the covariate matrix
col: Column number to query in the covariate matrix

◆ BasisValue()

double StochTree::ForestDataset::BasisValue (data_size_t row, int col)
inline

Returns a dataset's basis value stored at (row, col)

Parameters
row: Row number to query in the basis matrix
col: Column number to query in the basis matrix

◆ VarWeightValue()

double StochTree::ForestDataset::VarWeightValue ( data_size_t  row)
inline

Returns a dataset's variance weight stored at element row

Parameters
row: Index to query in the weight vector

◆ GetCovariates()

Eigen::MatrixXd & StochTree::ForestDataset::GetCovariates ( )
inline

Return a reference to the raw Eigen::MatrixXd storing the covariate data.

Returns
Reference to internal Eigen::MatrixXd
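
Because the getter returns a reference rather than a copy, how the caller binds the result decides whether later writes reach the dataset's internal storage. The sketch below shows that C++ behavior with a hypothetical stand-in class backed by std::vector, so it compiles without StochTree or Eigen. StorageSketch and its methods are illustrative names only.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in (not part of StochTree): a class that exposes its
// internal storage by non-const reference, like the getters documented here.
class StorageSketch {
 public:
  explicit StorageSketch(std::vector<double> values) : values_(std::move(values)) {}
  // Returns a reference to the internal storage, not a copy.
  std::vector<double>& GetValues() { return values_; }
  double Value(std::size_t i) const { return values_[i]; }

 private:
  std::vector<double> values_;
};

// Binding with `auto&` aliases the internal storage, so the write is visible
// through the object afterwards.
inline void WriteThroughReference(StorageSketch& s, std::size_t i, double v) {
  auto& ref = s.GetValues();
  ref[i] = v;
}

// Binding with plain `auto` makes a copy, so the write is lost when the
// copy goes out of scope.
inline void WriteToCopy(StorageSketch& s, std::size_t i, double v) {
  auto copy = s.GetValues();
  copy[i] = v;
}
```

The same rule applies to an Eigen::MatrixXd & return value: `Eigen::MatrixXd& X = ...` aliases the internal matrix, while `Eigen::MatrixXd X = ...` or `auto X = ...` copies it.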

◆ GetBasis()

Eigen::MatrixXd & StochTree::ForestDataset::GetBasis ( )
inline

Return a reference to the raw Eigen::MatrixXd storing the basis data.

Returns
Reference to internal Eigen::MatrixXd

◆ GetVarWeights()

Eigen::VectorXd & StochTree::ForestDataset::GetVarWeights ( )
inline

Return a reference to the raw Eigen::VectorXd storing the variance weights.

Returns
Reference to internal Eigen::VectorXd

◆ UpdateBasis()

void StochTree::ForestDataset::UpdateBasis (double *data_ptr, data_size_t num_row, int num_col, bool is_row_major)
inline

Update the data in the internal basis matrix to new values stored in a raw double array.

Parameters
data_ptr: Pointer to the first element of a contiguous array of data storing a basis matrix
num_row: Number of rows in the basis matrix
num_col: Number of columns in the basis matrix
is_row_major: True if the data in data_ptr are stored in row-major order, false if they are stored in column-major order

◆ UpdateVarWeights()

void StochTree::ForestDataset::UpdateVarWeights (double *data_ptr, data_size_t num_row, bool exponentiate=true)
inline

Update the data in the internal variance weight vector to new values stored in a raw double array.

Parameters
data_ptr: Pointer to the first element of a contiguous array of data storing a weight vector
num_row: Number of rows in the weight vector
exponentiate: Whether or not inputs should be exponentiated before being saved to the variance weight vector
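
The exponentiate flag defaults to true, so by default the values stored differ from the values passed in. The sketch below illustrates that behavior as described above. UpdateWeightsSketch is a hypothetical stand-in operating on a std::vector, not the library's implementation, and the size check is an assumption of this sketch.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in (not part of StochTree): overwrite `weights` with the
// values in data_ptr, applying std::exp to each one first when `exponentiate`
// is true.
void UpdateWeightsSketch(std::vector<double>& weights, const double* data_ptr,
                         std::size_t num_row, bool exponentiate = true) {
  // Assumption of this sketch: the new data must match the existing length.
  if (weights.size() != num_row) {
    throw std::invalid_argument("num_row does not match the weight vector length");
  }
  for (std::size_t i = 0; i < num_row; ++i) {
    weights[i] = exponentiate ? std::exp(data_ptr[i]) : data_ptr[i];
  }
}
```

Under this reading, a caller holding weights on the log scale can pass them directly with the default, while a caller holding weights on their natural scale would pass exponentiate=false to store them unchanged.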

◆ SetCovariateValue()

void StochTree::ForestDataset::SetCovariateValue (data_size_t row_id, int col, double new_value)
inline

Update an observation in the internal covariate matrix to a new value.

Parameters
row_id: Row number to be overwritten in the covariate matrix
col: Column number to be overwritten in the covariate matrix
new_value: New covariate value

◆ SetBasisValue()

void StochTree::ForestDataset::SetBasisValue (data_size_t row_id, int col, double new_value)
inline

Update an observation in the internal basis matrix to a new value.

Parameters
row_id: Row number to be overwritten in the basis matrix
col: Column number to be overwritten in the basis matrix
new_value: New basis value

◆ SetVarWeightValue()

void StochTree::ForestDataset::SetVarWeightValue (data_size_t row_id, double new_value, bool exponentiate=true)
inline

Update an observation in the internal variance weight vector to a new value.

Parameters
row_id: Row ID in the variance weight vector to be overwritten
new_value: New variance weight value
exponentiate: Whether or not the input should be exponentiated before being saved to the variance weight vector

The documentation for this class was generated from the following file: