StochTree 0.0.1
Loading...
Searching...
No Matches
prior.h
1
5#ifndef STOCHTREE_PRIOR_H_
6#define STOCHTREE_PRIOR_H_
7
8#include <Eigen/Dense>
9#include <stochtree/log.h>
10
11namespace StochTree {
12
// NOTE(review): this listing skips original prior.h lines 13 and 15. Line 13 is
// presumably the class head `class RandomEffectsGaussianPrior {`; confirm against
// the source header and restore both lines before compiling.
14 public:
// Virtual, defaulted destructor. This suggests the class is meant to be used as
// a polymorphic base. With a virtual destructor, deleting a derived object
// through a pointer to this type is well-defined.
16 virtual ~RandomEffectsGaussianPrior() = default;
17};
18
20 public:
21 RandomEffectsRegressionGaussianPrior(double a, double b, int32_t num_components, int32_t num_groups) {
22 a_ = a;
23 b_ = b;
24 num_components_ = num_components;
25 num_groups_ = num_groups;
26 }
28 double GetPriorVarianceShape() {return a_;}
29 double GetPriorVarianceScale() {return b_;}
30 int32_t GetNumComponents() {return num_components_;}
31 int32_t GetNumGroups() {return num_groups_;}
32 void SetPriorVarianceShape(double a) {a_ = a;}
33 void SetPriorVarianceScale(double b) {b_ = b;}
34 void SetNumComponents(int32_t num_components) {num_components_ = num_components;}
35 void SetNumGroups(int32_t num_groups) {num_groups_ = num_groups;}
// Storage for the values passed to the constructor and the Set* methods and
// returned by the Get* methods.
36 private:
// Prior variance shape (GetPriorVarianceShape / SetPriorVarianceShape).
37 double a_;
// Prior variance scale (GetPriorVarianceScale / SetPriorVarianceScale).
38 double b_;
// Number of components (GetNumComponents / SetNumComponents).
39 int32_t num_components_;
// Number of groups (GetNumGroups / SetNumGroups).
40 int32_t num_groups_;
41};
42
/*!
 * \brief Hyperparameters of the tree prior.
 *
 * Holds `alpha`, `beta`, the minimum number of samples allowed in a leaf, and
 * a maximum depth. This is a plain value holder: values are stored and
 * returned exactly as given, with no range validation.
 *
 * No user-declared destructor or copy/move members are needed (Rule of Zero),
 * because every member is a scalar.
 */
class TreePrior {
 public:
  /*!
   * \param alpha Value returned by GetAlpha().
   * \param beta Value returned by GetBeta().
   * \param min_samples_in_leaf Value returned by GetMinSamplesLeaf().
   * \param max_depth Value returned by GetMaxDepth(); defaults to -1.
   *        NOTE(review): -1 presumably means "no depth limit"; confirm in the sampler.
   */
  TreePrior(double alpha, double beta, int32_t min_samples_in_leaf, int32_t max_depth = -1)
      : alpha_(alpha), beta_(beta), min_samples_in_leaf_(min_samples_in_leaf), max_depth_(max_depth) {}

  // Read accessors are const so the prior can be queried through const references.
  double GetAlpha() const {return alpha_;}
  double GetBeta() const {return beta_;}
  int32_t GetMinSamplesLeaf() const {return min_samples_in_leaf_;}
  int32_t GetMaxDepth() const {return max_depth_;}

  // Mutators overwrite the stored value without validation.
  void SetAlpha(double alpha) {alpha_ = alpha;}
  void SetBeta(double beta) {beta_ = beta;}
  void SetMinSamplesLeaf(int32_t min_samples_in_leaf) {min_samples_in_leaf_ = min_samples_in_leaf;}
  void SetMaxDepth(int32_t max_depth) {max_depth_ = max_depth;}

 private:
  double alpha_;
  double beta_;
  int32_t min_samples_in_leaf_;
  int32_t max_depth_;
};
66
68 public:
69 IGVariancePrior(double shape, double scale) {
70 shape_ = shape;
71 scale_ = scale;
72 }
74 double GetShape() {return shape_;}
75 double GetScale() {return scale_;}
76 void SetShape(double shape) {shape_ = shape;}
77 void SetScale(double scale) {scale_ = scale;}
// Storage for the values passed to the constructor and the Set* methods.
78 private:
// Shape value (GetShape / SetShape).
79 double shape_;
// Scale value (GetScale / SetScale).
80 double scale_;
81};
82
83} // namespace StochTree
84
85#endif // STOCHTREE_PRIOR_H_
Definition prior.h:67
Definition prior.h:43
Definition category_tracker.h:40