I believe that this answer is more correct than the other answers here:
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
This prints out a valid Python function. Here's an example output for a tree that is trying to return its input, a number between 0 and 10.
def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
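For completeness, here is a minimal sketch of how the function above might be driven. The training data and the DecisionTreeRegressor settings are illustrative, not from the original answer, and the function is reproduced so the sketch is self-contained:

```python
# Usage sketch: fit a shallow regression tree on the identity function
# over 0..10, then print it as Python source with tree_to_code.
from sklearn.tree import DecisionTreeRegressor, _tree

def tree_to_code(tree, feature_names):
    # Same function as in the answer above, repeated for self-containment.
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)

X = [[i] for i in range(11)]   # single feature f0, values 0..10
y = list(range(11))            # target equals the input
model = DecisionTreeRegressor(max_depth=4).fit(X, y)
tree_to_code(model, ["f0"])
```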
Here are some stumbling blocks that I see in other answers:
- Using tree_.threshold == -2 to decide whether a node is a leaf isn't a good idea. What if it's a real decision node with a threshold of -2? Instead, you should look at tree_.feature or tree_.children_*.
- The line features = [feature_names[i] for i in tree_.feature] crashes with my version of sklearn, because some values of tree.tree_.feature are -2 (specifically for leaf nodes).
- There is no need to have multiple if statements in the recursive function, just one is fine.
Each instance of RandomForestClassifier has an estimators_ attribute, which is a list of DecisionTreeClassifier instances. The documentation shows that an instance of DecisionTreeClassifier has a tree_ attribute, which is an instance of the (undocumented, I believe) Tree class. Some exploration in the interpreter shows that each Tree instance has a max_depth attribute which appears to be what you're looking for -- again, it's undocumented.
In any case, if forest is your instance of RandomForestClassifier, then:
>>> [estimator.tree_.max_depth for estimator in forest.estimators_]
[9, 10, 9, 11, 9, 9, 11, 7, 13, 10]
should do the trick.
Each estimator also has a get_depth() method that can be used to retrieve the same value with briefer syntax:
>>> [estimator.get_depth() for estimator in forest.estimators_]
[9, 10, 9, 11, 9, 9, 11, 7, 13, 10]
To avoid a mixup, it should be noted that each estimator (and not each estimator's tree_) also has an attribute called max_depth, which returns the setting of the parameter rather than the depth of the actual tree. How estimator.get_depth(), estimator.tree_.max_depth, and estimator.max_depth relate to each other is clarified in the example below:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(n_estimators=3, random_state=4, max_depth=6)
iris = load_iris()
clf.fit(iris['data'], iris['target'])
[(est.get_depth(), est.tree_.max_depth, est.max_depth) for est in clf.estimators_]
Out:
[(6, 6, 6), (3, 3, 6), (4, 4, 6)]
Setting max_depth to the default value None would allow the first tree to expand to depth 7, and the output would be:
[(7, 7, None), (3, 3, None), (4, 4, None)]
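That None case can be checked the same way. A small sketch (the exact fitted depths depend on the scikit-learn version, so only the relationships are asserted here):

```python
# With max_depth left at its default (None), each estimator.max_depth is None,
# while get_depth() and tree_.max_depth still report the actual fitted depth.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
clf = RandomForestClassifier(n_estimators=3, random_state=4)
clf.fit(iris['data'], iris['target'])

for est in clf.estimators_:
    assert est.max_depth is None                   # the parameter setting
    assert est.get_depth() == est.tree_.max_depth  # the actual tree depth
```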
Best Answer
This is explained in the documentation:
So, basically, a sub-optimal greedy algorithm is repeated a number of times using random selections of features and samples (a similar technique used in random forests). The random_state parameter allows controlling these random choices.
The interface documentation specifically states:
So, the random algorithm will be used in any case. Passing any value (whether a specific int, e.g. 0, or a RandomState instance) will not change that. The only rationale for passing in an int value (0 or otherwise) is to make the outcome consistent across calls: if you call this with random_state=0 (or any other value), then each and every time, you'll get the same result.
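As an illustration of that consistency, here is a sketch using RandomForestClassifier as a stand-in for whichever randomized estimator is under discussion; any estimator that accepts random_state behaves the same way:

```python
# Two fits with the same random_state produce identical models: the algorithm
# is still randomized, but its random choices are repeated exactly.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()

def fit_and_predict(seed):
    clf = RandomForestClassifier(n_estimators=10, random_state=seed)
    clf.fit(iris['data'], iris['target'])
    return clf.predict_proba(iris['data'])

a = fit_and_predict(0)
b = fit_and_predict(0)
assert np.array_equal(a, b)  # same seed -> same result on every call
```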