StochTree 0.4.1
Loading...
Searching...
No Matches
ordinal_sampler.h
1
5#ifndef STOCHTREE_ORDINAL_SAMPLER_H_
6#define STOCHTREE_ORDINAL_SAMPLER_H_
7
8#include <stochtree/data.h>
9#include <stochtree/ensemble.h>
10#include <stochtree/gamma_sampler.h>
11#include <stochtree/partition_tracker.h>
12#include <stochtree/tree.h>
13
#include <Eigen/Dense>
#include <cmath>
#include <random>
#include <vector>
17
18namespace StochTree {
19
/*!
 * \brief Inverse-CDF draw from an exponential distribution truncated to [low, high].
 *
 * The textbook inverse CDF is
 *   -log((1-u) * exp(-rate*low) + u * exp(-rate*high)) / rate.
 * Factoring out exp(-rate*low) gives the algebraically identical form
 *   low - log1p(u * expm1(-rate * (high - low))) / rate,
 * which is what is evaluated here. It never forms exp(-rate*low) directly, so
 * it does not underflow to log(0) = -inf (and hence return +inf) when
 * rate * low is large. It also keeps full precision when the interval is narrow.
 *
 * \param u    Uniform draw; presumably in [0, 1] (u = 0 maps to low, u = 1 to high).
 * \param rate Exponential rate parameter (assumed > 0).
 * \param low  Lower truncation bound.
 * \param high Upper truncation bound (assumed >= low).
 * \return A draw in [low, high] for u in [0, 1].
 */
static double sample_truncated_exponential_low_high(double u, double rate, double low, double high) {
  return low - std::log1p(u * std::expm1(-rate * (high - low))) / rate;
}
23
/*!
 * \brief Inverse-CDF draw from an exponential distribution truncated to [low, inf).
 *
 * The inverse CDF -log((1-u) * exp(-rate*low)) / rate simplifies exactly to
 *   low - log1p(-u) / rate,
 * i.e. low plus an untruncated exponential draw.
 *
 * Evaluating the simplified form avoids exp(-rate*low) underflowing to 0. That
 * underflow made the unsimplified expression return +inf once rate * low exceeded
 * roughly 745. Using log1p(-u) also matches sample_exponential and stays
 * accurate for small u.
 *
 * \param u    Uniform draw; presumably in [0, 1).
 * \param rate Exponential rate parameter (assumed > 0).
 * \param low  Lower truncation bound.
 * \return A draw in [low, inf).
 */
static double sample_truncated_exponential_low(double u, double rate, double low) {
  return low - std::log1p(-u) / rate;
}
27
/*!
 * \brief Inverse-CDF draw from an exponential distribution truncated to [0, high].
 *
 * Computes -log1p(u * expm1(-rate * high)) / rate. The expm1/log1p pair keeps
 * the result accurate when rate * high is small.
 *
 * \param u    Uniform draw; presumably in [0, 1] (u = 0 maps to 0, u = 1 to high).
 * \param rate Exponential rate parameter (assumed > 0).
 * \param high Upper truncation bound.
 * \return A draw in [0, high] for u in [0, 1].
 */
static double sample_truncated_exponential_high(double u, double rate, double high) {
  // exp(-rate * high) - 1, evaluated without cancellation.
  const double shifted_tail = std::expm1(-rate * high);
  return -(std::log1p(shifted_tail * u) / rate);
}
31
/*!
 * \brief Inverse-CDF draw from an untruncated exponential distribution.
 *
 * Returns -log(1 - u) / rate, using log1p so small u does not lose precision.
 *
 * \param u    Uniform draw; presumably in [0, 1).
 * \param rate Exponential rate parameter (assumed > 0).
 * \return A draw in [0, inf).
 */
static double sample_exponential(double u, double rate) {
  const double log_survival = std::log1p(-u);  // log(1 - u)
  return log_survival / -rate;
}
35
45 public:
47 gamma_sampler_ = GammaSampler();
48 }
50
62 static double SampleTruncatedExponential(std::mt19937& gen, double rate, double low = 0.0, double high = 1.0);
63
71 void UpdateLatentVariables(ForestDataset& dataset, Eigen::VectorXd& outcome, std::mt19937& gen);
72
83 void UpdateGammaParams(ForestDataset& dataset, Eigen::VectorXd& outcome,
84 double alpha_gamma, double beta_gamma,
85 double gamma_0, std::mt19937& gen);
86
93
94 private:
95 GammaSampler gamma_sampler_;
96};
97
98} // namespace StochTree
99
100#endif // STOCHTREE_ORDINAL_SAMPLER_H_
API for loading and accessing data used to sample tree ensembles. The covariates / bases / weights use...
Definition data.h:271
Definition gamma_sampler.h:10
Sampler for ordinal model hyperparameters.
Definition ordinal_sampler.h:44
void UpdateCumulativeExpSums(ForestDataset &dataset)
Update cumulative exponential sums (seg)
void UpdateLatentVariables(ForestDataset &dataset, Eigen::VectorXd &outcome, std::mt19937 &gen)
Update truncated exponential latent variables (Z)
static double SampleTruncatedExponential(std::mt19937 &gen, double rate, double low=0.0, double high=1.0)
Sample from truncated exponential distribution.
void UpdateGammaParams(ForestDataset &dataset, Eigen::VectorXd &outcome, double alpha_gamma, double beta_gamma, double gamma_0, std::mt19937 &gen)
Update gamma cutpoint parameters.
A collection of random number generation utilities.
Definition category_tracker.h:36