列表

详情


1569. 将子数组重新排序得到同一个二叉查找树的方案数

给你一个数组 nums 表示 1 到 n 的一个排列。我们按照元素在 nums 中的顺序依次插入一个初始为空的二叉查找树(BST)。请你统计将 nums 重新排序后,统计满足如下条件的方案数:重排后得到的二叉查找树与 nums 原本数字顺序得到的二叉查找树相同。

比方说,给你 nums = [2,1,3],我们得到一棵 2 为根,1 为左孩子,3 为右孩子的树。数组 [2,3,1] 也能得到相同的 BST,但 [3,2,1] 会得到一棵不同的 BST 。

请你返回重排 nums 后,与原数组 nums 得到相同二叉查找树的方案数。

由于答案可能会很大,请将结果对 10^9 + 7 取余数。

 

示例 1:

输入:nums = [2,1,3]
输出:1
解释:我们将 nums 重排, [2,3,1] 能得到相同的 BST 。没有其他得到相同 BST 的方案了。

示例 2:

输入:nums = [3,4,5,1,2]
输出:5
解释:下面 5 个数组会得到相同的 BST:
[3,1,2,4,5]
[3,1,4,2,5]
[3,1,4,5,2]
[3,4,1,2,5]
[3,4,1,5,2]

示例 3:

输入:nums = [1,2,3]
输出:0
解释:没有别的排列顺序能得到相同的 BST 。

示例 4:

输入:nums = [3,1,2,5,4,6]
输出:19

示例  5:

输入:nums = [9,4,2,1,3,6,5,7,8,14,11,10,12,13,16,15,17,18]
输出:216212978
解释:得到相同 BST 的方案数是 3216212999。将它对 10^9 + 7 取余后得到 216212978。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int numOfWays(vector<int>& nums) { } };

python3 解法, 执行用时: 112 ms, 内存消耗: 21.3 MB, 提交时间: 2023-10-10 23:46:12

class Solution:
    def numOfWays(self, nums: List[int]) -> int:
        def dac(arr):
            if len(arr) <= 1: return 1
            less, more = [], []
            for x in arr[1:]: (less if x < arr[0] else more).append(x)
            return comb(len(arr) - 1, len(less)) * dac(less) * dac(more) % MOD
        
        MOD = int(1e9) + 7
        comb = cache(math.comb)
        return (dac(nums) - 1 + MOD) % MOD

java 解法, 执行用时: 88 ms, 内存消耗: 57.4 MB, 提交时间: 2023-10-10 23:44:56

class Solution {
    static final int MOD = 1000000007;
    long[][] c;

    public int numOfWays(int[] nums) {
        int n = nums.length;
        if (n == 1) {
            return 0;
        }

        c = new long[n][n];
        c[0][0] = 1;
        for (int i = 1; i < n; ++i) {
            c[i][0] = 1;
            for (int j = 1; j < n; ++j) {
                c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % MOD;
            }
        }

        TreeNode root = new TreeNode(nums[0]);
        for (int i = 1; i < n; ++i) {
            int val = nums[i];
            insert(root, val);
        }

        dfs(root);
        return (root.ans - 1 + MOD) % MOD;
    }

    public void insert(TreeNode root, int value) {
        TreeNode cur = root;
        while (true) {
            ++cur.size;
            if (value < cur.value) {
                if (cur.left == null) {
                    cur.left = new TreeNode(value);
                    return;
                }
                cur = cur.left;
            } else {
                if (cur.right == null) {
                    cur.right = new TreeNode(value);
                    return;
                }
                cur = cur.right;
            }
        }
    }

    public void dfs(TreeNode node) {
        if (node == null) {
            return;
        }
        dfs(node.left);
        dfs(node.right);
        int lsize = node.left != null ? node.left.size : 0;
        int rsize = node.right != null ? node.right.size : 0;
        int lans = node.left != null ? node.left.ans : 1;
        int rans = node.right != null ? node.right.ans : 1;
        node.ans = (int) (c[lsize + rsize][lsize] % MOD * lans % MOD * rans % MOD);
    }
}

class TreeNode {
    TreeNode left;
    TreeNode right;
    int value;
    int size;
    int ans;

    TreeNode(int value) {
        this.value = value;
        this.size = 1;
        this.ans = 0;
    }
}

cpp 解法, 执行用时: 388 ms, 内存消耗: 141.4 MB, 提交时间: 2023-10-10 23:44:38

struct TNode {
    TNode* left;
    TNode* right;
    int value;
    int size;
    int ans;
    
    TNode(int val): left(nullptr), right(nullptr), value(val), size(1), ans(0) {}
};

class Solution {
private:
    static constexpr int mod = 1000000007;
    vector<vector<int>> c;

public:
    void insert(TNode* root, int val) {
        TNode* cur = root;
        while (true) {
            ++cur->size;
            if (val < cur->value) {
                if (!cur->left) {
                    cur->left = new TNode(val);
                    return;
                }
                cur = cur->left;
            }
            else {
                if (!cur->right) {
                    cur->right = new TNode(val);
                    return;
                }
                cur = cur->right;
            }
        }
    }

    void dfs(TNode* node) {
        if (!node) {
            return;
        }
        dfs(node->left);
        dfs(node->right);
        int lsize = node->left ? node->left->size : 0;
        int rsize = node->right ? node->right->size : 0;
        int lans = node->left ? node->left->ans : 1;
        int rans = node->right ? node->right->ans : 1;
        node->ans = (long long)c[lsize + rsize][lsize] % mod * lans % mod * rans % mod;
    }

    int numOfWays(vector<int>& nums) {
        int n = nums.size();
        if (n == 1) {
            return 0;
        }

        c.assign(n, vector<int>(n));
        c[0][0] = 1;
        for (int i = 1; i < n; ++i) {
            c[i][0] = 1;
            for (int j = 1; j < n; ++j) {
                c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
            }
        }

        TNode* root = new TNode(nums[0]);
        for (int i = 1; i < n; ++i) {
            int val = nums[i];
            insert(root, val);
        }

        dfs(root);
        return (root->ans - 1 + mod) % mod;
    }
};

上一题