Function to save Random Forest model to txt file. #4

Open
crimson-luis opened this issue Jun 3, 2022 · 0 comments
crimson-luis commented Jun 3, 2022

After reading some of the issues, I realized it would be useful to share a function I wrote while working with your project.
The function takes a fitted sklearn Random Forest and walks each tree, building a list with the information for every node.

It may not be 100% correct.

# Requires `import numpy as np` and a `config` mapping that provides config['DATASET']['NAME'].
def model_to_txt(self, index, show: bool = True, save: bool = False):
    # https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#sphx-glr-auto-examples-tree-plot-unveil-tree-structure-py
    forest = self.estimators_
    model_info = list()
    model_info.append(
        f"DATASET_NAME: {config['DATASET']['NAME']}.train{index}.csv"
        f"\nENSEMBLE: RF"
        f"\nNB_TREES: {len(forest)}"
        f"\nNB_FEATURES: {forest[0].tree_.n_features}"
        f"\nNB_CLASSES: {forest[0].tree_.n_classes[0]}"
        f"\nMAX_TREE_DEPTH: {max(est.tree_.max_depth for est in forest)}"
        "\nFormat: node / node type (LN - leave node, IN - internal node) "
        "left child / right child / feature / threshold / node_depth / "
        "majority class (starts with index 0)"
    )
    for tree_idx, est in enumerate(forest):
        tree = est.tree_
        n_nodes = tree.node_count
        children_left = tree.children_left
        children_right = tree.children_right

        node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
        is_leaves = np.zeros(shape=n_nodes, dtype=bool)
        stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
        model_info.append(f"\n\n[TREE {tree_idx}]\nNB_NODES: {n_nodes}")
        while len(stack) > 0:
            node_id, depth = stack.pop()
            node_depth[node_id] = depth

            if children_left[node_id] != children_right[node_id]:
                stack.append((children_left[node_id], depth + 1))
                stack.append((children_right[node_id], depth + 1))
            else:
                is_leaves[node_id] = True
        for i in range(n_nodes):
            class_idx = np.argmax(tree.value[i][0])
            if is_leaves[i]:
                model_info.append(f"\n{i} LN -1 -1 -1 -1 {node_depth[i]} {class_idx}")
            else:
                model_info.append(
                    f"\n{i} IN {children_left[i]} {children_right[i]} "
                    f"{tree.feature[i]} {tree.threshold[i]} {node_depth[i]} -1"
                )
    model_info.append("\n\n")
    if show:
        # Join without a separator so the printed text matches the saved file.
        print("".join(model_info))
    if save:
        with open(
                f"./data/processed/forests/{config['DATASET']['NAME']}.RF{index}.txt",
                "w"
        ) as f:
            for item in model_info:
                f.write(item)