kernel.compute_forest_leaf_indices
```python
kernel.compute_forest_leaf_indices(
    model_object,
    covariates,
    forest_type=None,
    propensity=None,
    forest_inds=None,
)
```

Compute and return a vector representation of a forest's leaf indices for every observation in a dataset.
The vector has a “tree-major” format that can easily be re-represented as a CSR sparse matrix: the first n elements hold the leaf indices of all n observations in the first tree of the ensemble, the next n elements hold the leaf indices for the second tree, and so on. The value of each element is a uniquely mapped column index that identifies a single leaf of a single tree. For example, if tree 1 has 3 leaves, its column indices range from 0 to 2, tree 2's leaf indices begin at 3, and so on.
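The following is a minimal sketch of the CSR conversion described above, assuming the leaf indices arrive as a flat integer vector in tree-major order with globally unique column indices. The helpers `tree_major_to_csr` and `csr_to_dense` are hypothetical names for illustration and are not part of stochtree; the sample vector is made up. Rows of the resulting matrix are observations, columns are leaves, and each row has exactly one nonzero entry per tree.

```python
import numpy as np


def tree_major_to_csr(leaf_vec, num_obs, num_leaves=None):
    """Split a tree-major leaf index vector into CSR components.

    Rows of the implied matrix are observations and columns are the
    globally unique leaf indices, so each row has one nonzero per tree.
    """
    leaf_vec = np.asarray(leaf_vec, dtype=np.int64)
    if leaf_vec.size % num_obs != 0:
        raise ValueError("length of leaf_vec must be a multiple of num_obs")
    num_trees = leaf_vec.size // num_obs
    # Element t * num_obs + i is observation i's leaf in tree t, so reshape to
    # (num_trees, num_obs) and transpose to group entries by observation.
    indices = leaf_vec.reshape(num_trees, num_obs).T.ravel()
    # Every row holds exactly num_trees nonzero entries.
    indptr = np.arange(0, num_obs * num_trees + 1, num_trees)
    data = np.ones(indices.size, dtype=np.float64)
    if num_leaves is None:
        # Only a lower bound if the highest-numbered leaf is never visited.
        num_leaves = int(leaf_vec.max()) + 1
    return data, indices, indptr, (num_obs, num_leaves)


def csr_to_dense(data, indices, indptr, shape):
    """Expand CSR components into a dense matrix (for inspection only)."""
    dense = np.zeros(shape)
    for row in range(shape[0]):
        start, end = indptr[row], indptr[row + 1]
        dense[row, indices[start:end]] = data[start:end]
    return dense


# Made-up example: 4 observations and 2 trees, where tree 1 has
# leaves 0-2 and tree 2 has leaves 3-4.
leaf_vec = np.array([0, 2, 1, 0, 3, 4, 4, 3])
data, indices, indptr, shape = tree_major_to_csr(leaf_vec, num_obs=4)
dense = csr_to_dense(data, indices, indptr, shape)
# Row 0 (the first observation) has ones in columns 0 and 3.
print(dense)
```

If SciPy is available, `scipy.sparse.csr_matrix((data, indices, indptr), shape=shape)` builds the sparse matrix directly from the same three arrays, so the dense expansion is only needed for inspection. Passing `num_leaves` explicitly is safer than inferring it, because the highest-numbered leaf may not be reached by any observation in `covariates`.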
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| model_object | BARTModel, BCFModel, or ForestContainer | Object corresponding to a BART / BCF model with at least one forest sample, or a low-level ForestContainer object. | required |
| covariates | np.array or pd.DataFrame | Covariates to use for prediction. Must have the same number of columns and the same column types as the data used to train the forest. | required |
| forest_type | str | Which forest to use from model_object. Valid inputs depend on the model type and on whether a given forest was sampled in that model. See Notes for a mapping from model type to valid forest types. | None |
| propensity | np.array | Optional test set propensities. Must be provided if propensities were provided when the model was sampled. | None |
| forest_inds | int or np.ndarray | Indices of the forest sample(s) for which to compute leaf indices. If not provided, this function returns leaf indices for every sample of a forest. Indices are 0-based, so the first forest sample corresponds to forest_inds = 0, and so on. | None |
Returns
| Name | Type | Description |
|---|---|---|
| | np.array | Numpy array with dimensions num_obs by num_trees, where num_obs is the number of rows in covariates and num_trees is the number of trees in the relevant forest of model_object. |
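To move between the flat tree-major layout described above and an observation-by-tree matrix, a reshape and transpose are enough. This is a sketch that assumes the leaf indices for a single forest sample are available as a tree-major vector; the helper names are hypothetical and not part of stochtree.

```python
import numpy as np


def tree_major_to_obs_by_tree(leaf_vec, num_obs):
    """Reshape a tree-major vector into a (num_obs, num_trees) matrix.

    Entry [i, t] is the global leaf index of observation i in tree t.
    """
    leaf_vec = np.asarray(leaf_vec)
    if leaf_vec.size % num_obs != 0:
        raise ValueError("length of leaf_vec must be a multiple of num_obs")
    num_trees = leaf_vec.size // num_obs
    # Row t of the intermediate array holds every observation's leaf in tree t.
    return leaf_vec.reshape(num_trees, num_obs).T


def obs_by_tree_to_tree_major(leaf_mat):
    """Inverse of tree_major_to_obs_by_tree: flatten one tree (column) at a time."""
    return np.asarray(leaf_mat).T.ravel()


# Made-up tree-major vector for 4 observations and 2 trees.
mat = tree_major_to_obs_by_tree([0, 2, 1, 0, 3, 4, 4, 3], num_obs=4)
print(mat.shape)
```

Flattening the transpose is equivalent to `leaf_mat.ravel(order="F")`, that is, column-major order, which is exactly the tree-major ordering.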
Notes
Mapping from model type to forest types:
- BART
  - 'mean': Extracts leaf indices for the mean forest
  - 'variance': Extracts leaf indices for the variance forest
- BCF
  - 'prognostic': Extracts leaf indices for the prognostic forest
  - 'treatment': Extracts leaf indices for the treatment effect forest
  - 'variance': Extracts leaf indices for the variance forest
- ForestContainer
  - None: It is not necessary to disambiguate when this function is called directly on a ForestContainer object. This is the default value of this argument.
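The forest-type mapping above can be written as a small lookup table. The sketch below is a hypothetical validation helper (`VALID_FOREST_TYPES` and `check_forest_type` are illustrative names, not part of stochtree). It checks only the model-type mapping, not whether a given forest, such as the variance forest, was actually sampled in the model.

```python
# Hypothetical lookup mirroring the Notes above; not part of stochtree.
VALID_FOREST_TYPES = {
    "BARTModel": {"mean", "variance"},
    "BCFModel": {"prognostic", "treatment", "variance"},
    "ForestContainer": {None},
}


def check_forest_type(model_kind, forest_type):
    """Return forest_type if it is listed for model_kind, else raise ValueError."""
    valid = VALID_FOREST_TYPES.get(model_kind)
    if valid is None:
        raise ValueError(f"unknown model kind: {model_kind!r}")
    if forest_type not in valid:
        options = sorted(repr(v) for v in valid)
        raise ValueError(
            f"forest_type={forest_type!r} is not valid for {model_kind}; "
            f"expected one of {', '.join(options)}"
        )
    return forest_type
```

A real check would also need to confirm that the requested forest exists in the fitted model, which this sketch does not attempt.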