列表

详情


928. 尽量减少恶意软件的传播 II

给定一个由 n 个节点组成的网络,用 n x n 个邻接矩阵 graph 表示。在节点网络中,只有当 graph[i][j] = 1 时,节点 i 能够直接连接到另一个节点 j

一些节点 initial 最初被恶意软件感染。只要两个节点直接连接,且其中至少一个节点受到恶意软件的感染,那么两个节点都将被恶意软件感染。这种恶意软件的传播将继续,直到没有更多的节点可以被这种方式感染。

假设 M(initial) 是在恶意软件停止传播之后,整个网络中感染恶意软件的最终节点数。

我们可以从 initial删除一个节点并完全移除该节点以及从该节点到任何其他节点的任何连接。

请返回移除后能够使 M(initial) 最小化的节点。如果有多个节点满足条件,返回索引 最小的节点

 

示例 1:

输出:graph = [[1,1,0],[1,1,0],[0,0,1]], initial = [0,1]
输入:0

示例 2:

输入:graph = [[1,1,0],[1,1,1],[0,1,1]], initial = [0,1]
输出:1

示例 3:

输入:graph = [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], initial = [0,1]
输出:1

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) { } };

javascript 解法, 执行用时: 80 ms, 内存消耗: 56.5 MB, 提交时间: 2024-04-17 09:17:15

/**
 * @param {number[][]} graph
 * @param {number[]} initial
 * @return {number}
 */
var minMalwareSpread = function(graph, initial) {
    const st = new Set(initial);
    const vis = Array(graph.length).fill(false);

    let nodeID, size;
    function dfs(x) {
        vis[x] = true;
        size++;
        for (let y = 0; y < graph[x].length; y++) {
            if (graph[x][y] === 0) {
                continue;
            }
            if (st.has(y)) {
                // 按照 924 题的状态机更新 nodeID
                // 注意避免重复统计,例如上图中的 0 有两条不同路径可以遇到 1
                if (nodeID !== -2 && nodeID !== y) {
                    nodeID = nodeID === -1 ? y : -2;
                }
            } else if (!vis[y]) {
                dfs(y);
            }
        }
    }

    const cnt = new Map();
    for (let i = 0; i < graph.length; i++) {
        if (vis[i] || st.has(i)) {
            continue;
        }
        nodeID = -1;
        size = 0;
        dfs(i);
        if (nodeID >= 0) { // 只找到一个在 initial 中的节点
            // 删除节点 nodeId 可以让 size 个点不被感染
            cnt.set(nodeID, (cnt.get(nodeID) ?? 0) + size);
        }
    }

    let maxCnt = 0;
    let minNodeID = 0;
    for (const [nodeID, c] of cnt) {
        if (c > maxCnt || c === maxCnt && nodeID < minNodeID) {
            maxCnt = c;
            minNodeID = nodeID;
        }
    }
    return cnt.size ? minNodeID : Math.min(...initial);
};

rust 解法, 执行用时: 11 ms, 内存消耗: 2.9 MB, 提交时间: 2024-04-17 09:16:49

impl Solution {
    pub fn min_malware_spread(graph: Vec<Vec<i32>>, initial: Vec<i32>) -> i32 {
        let n = graph.len();
        let mut vis = vec![false; n];
        let mut is_initial = vec![false; n];
        for &x in &initial {
            is_initial[x as usize] = true;
        }

        let mut cnt = vec![0; n];
        for i in 0..n {
            if vis[i] || is_initial[i] {
                continue;
            }
            let mut node_id = -1;
            let mut size = 0;
            Self::dfs(i, &graph, &mut vis, &is_initial, &mut node_id, &mut size);
            if node_id >= 0 { // 只找到一个在 initial 中的节点
                // 删除节点 node_id 可以让 size 个点不被感染
                cnt[node_id as usize] += size;
            }
        }

        let mut max_cnt = 0;
        let mut min_node_id = n;
        for (i, &c) in cnt.iter().enumerate() {
            if c > 0 && (c > max_cnt || c == max_cnt && i < min_node_id) {
                max_cnt = c;
                min_node_id = i;
            }
        }
        if min_node_id == n { *initial.iter().min().unwrap() } else { min_node_id as _ }
    }

    fn dfs(x: usize, graph: &Vec<Vec<i32>>, vis: &mut Vec<bool>, is_initial: &Vec<bool>, node_id: &mut i32, size: &mut i32) {
        vis[x] = true;
        *size += 1;
        for (y, &conn) in graph[x].iter().enumerate() {
            if conn == 0 {
                continue;
            }
            if is_initial[y] {
                // 按照 924 题的状态机更新 node_id
                // 注意避免重复统计,例如上图中的 0 有两条不同路径可以遇到 1
                if *node_id != -2 && *node_id != y as i32 {
                    *node_id = if *node_id == -1 { y as i32 } else { -2 };
                }
            } else if !vis[y] {
                Self::dfs(y, graph, vis, is_initial, node_id, size);
            }
        }
    }
}

cpp 解法, 执行用时: 79 ms, 内存消耗: 43.9 MB, 提交时间: 2024-04-17 09:16:29

class Solution {
public:
    int minMalwareSpread(vector<vector<int>>& graph, vector<int>& initial) {
        unordered_set<int> st(initial.begin(), initial.end());
        vector<int> vis(graph.size());
        int node_id, size;
        function<void(int)> dfs = [&](int x) {
            vis[x] = true;
            size++;
            for (int y = 0; y < graph[x].size(); y++) {
                if (graph[x][y] == 0) {
                    continue;
                }
                if (st.contains(y)) {
                    // 按照 924 题的状态机更新 node_id
                    // 注意避免重复统计,例如上图中的 0 有两条不同路径可以遇到 1
                    if (node_id != -2 && node_id != y) {
                        node_id = node_id == -1 ? y : -2;
                    }
                } else if (!vis[y]) {
                    dfs(y);
                }
            }
        };

        unordered_map<int, int> cnt;
        for (int i = 0; i < graph.size(); i++) {
            if (vis[i] || st.contains(i)) {
                continue;
            }
            node_id = -1;
            size = 0;
            dfs(i);
            if (node_id >= 0) { // 只找到一个在 initial 中的节点
                // 删除节点 node_id 可以让 size 个点不被感染
                cnt[node_id] += size;
            }
        }

        int max_cnt = 0;
        int min_node_id = 0;
        for (auto [node_id, c] : cnt) {
            if (c > max_cnt || c == max_cnt && node_id < min_node_id) {
                max_cnt = c;
                min_node_id = node_id;
            }
        }
        return cnt.empty() ? ranges::min(initial) : min_node_id;
    }
};

golang 解法, 执行用时: 104 ms, 内存消耗: 7 MB, 提交时间: 2023-10-07 11:19:22

func minMalwareSpread(graph [][]int, initial []int) int {
    n := len(graph)
    par := make([]int, n)
    size := make([]int, n)
    for i := 0; i < n; i++ {
        par[i] = i
        size[i] = 1
    }
    find := func(x int) int {
        for x != par[x] {
            x = par[x]
        }
        return x
    }
    union := func(i, j int) {
        x, y := find(i), find(j)
        if x != y {
            par[x] = y
            size[y] += size[x]
        }
    }
    clean := make([]bool, n)
    for _, u := range initial {
        clean[u] = true
    }
    for i := 0; i < n; i++ {
        if clean[i] {
            continue
        }
        for j := 0; j < n; j++ {
            if i != j && !clean[j] && graph[i][j] == 1 {
                union(i, j)
            }
        }
    }
    count := make([]int, n)
    m := make(map[int]map[int]int)
    for _, u := range initial {
        m[u] = make(map[int]int)
        for v := 0; v < n; v++ {
            if !clean[v] && graph[u][v] == 1 {
                m[u][find(v)] = 0
            }
        }
        for v := range m[u] {
            count[v]++
        }
    }
    res, resSize := -1, -1
    for u, vSet := range m {
        score := 0
        for v := range vSet {
            if count[v] == 1 {
                score += size[find(v)]
            }
        }
        if score > resSize || score == resSize && u < res {
            resSize = score
            res = u
        }
    }
    return res
}

python3 解法, 执行用时: 156 ms, 内存消耗: 19.6 MB, 提交时间: 2023-10-07 11:18:53

class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        initial = set(initial)
        def find(x):#并查集
            if x not in cache:
                cache[x] = x
            if cache[x] != x:
                cache[x] = find(cache[x])
            return cache[x]
        def merge(x, y):
            cache[find(x)] = find(y)
        n = len(graph)
        cache = {}
        for i in range(n):
            for j in range(n):
                if i != j and i not in initial and j not in initial and graph[i][j] == 1:
                    merge(i, j)
        d = {}#计算每个节点集合的数量
        for i in range(n):
            i = find(i)
            d[i] = d.get(i, 0) + 1
        s = set(d.keys())
        cnt = {}#存储每个节点集合能被多少个感染节点所连接
        for i in initial:
            c = 0
            s1 = set([find(j) for j in range(n) if j not in initial and graph[i][j] == 1])#计算的是感染节点数,避免重复计数采取集合
            for j in s1:
                cnt[j] = cnt.get(j, 0) + 1
        m = -1
        for i in sorted(initial):#升序遍历,如果有多个节点需要返回最小的
            c = 0
            s1 = set([find(j) for j in range(n) if j not in initial and graph[i][j] == 1])
            for j in s1:
                if cnt[j] == 1:#如果节点集合值为1,说明去除本节点后这个集合不会被感染,进行累加计数
                    c += d[j]
            if c > m:#当不会被感染的集合数量越大说明被感染节点总数越小
                m = c
                res = i
        return res

python3 解法, 执行用时: 292 ms, 内存消耗: 23.9 MB, 提交时间: 2023-10-07 11:12:07

class Solution:
    def minMalwareSpread(self, graph: List[List[int]], initial: List[int]) -> int:
        N = len(graph)
        clean = set(range(N)) - set(initial)
        def dfs(u: int, seen: set) -> None:
            for v, adj in enumerate(graph[u]):
                if adj and v in clean and v not in seen:
                    seen.add(v)
                    dfs(v, seen)

        # For each node u in initial, dfs to find
        # 'seen': all nodes not in initial that it can reach.
        infected_by = {v: [] for v in clean}
        for u in initial:
            seen = set()
            dfs(u, seen)

            # For each node v that was seen, u infects v.
            for v in seen:
                infected_by[v].append(u)

        # For each node u in initial, for every v not in initial
        # that is uniquely infected by u, add 1 to the contribution for u.
        contribution = collections.Counter()
        for v, neighbors in infected_by.items():
            if len(neighbors) == 1:
                contribution[neighbors[0]] += 1

        # Take the best answer.
        best = (-1, min(initial))
        for u, score in contribution.items():
            if score > best[0] or score == best[0] and u < best[1]:
                best = score, u
        return best[1]

java 解法, 执行用时: 46 ms, 内存消耗: 54.1 MB, 提交时间: 2023-10-07 11:09:43

/**
 * dfs
 * 首先构建一个图 G,其节点为所有不在 initial 中的剩余节点。
 * 对于不在 initial 中的节点 v,检查会被 initial 中哪些节点 u 感染。 
 * 之后再看哪些节点 v 只会被一个节点 u 感染。具体的算法可以看代码中的注释。
 */
class Solution {
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int N = graph.length;
        int[] clean = new int[N];
        Arrays.fill(clean, 1);
        for (int x: initial)
            clean[x] = 0;

        // For each node u in initial, dfs to find
        // 'seen': all nodes not in initial that it can reach.
        ArrayList<Integer>[] infectedBy = new ArrayList[N];
        for (int i = 0; i < N; ++i)
            infectedBy[i] = new ArrayList();

        for (int u: initial) {
            Set<Integer> seen = new HashSet();
            dfs(graph, clean, u, seen);
            for (int v: seen)
                infectedBy[v].add(u);
        }

        // For each node u in initial, for every v not in initial
        // that is uniquely infected by u, add 1 to the contribution for u.
        int[] contribution = new int[N];
        for (int v = 0; v < N; ++v)
            if (infectedBy[v].size() == 1)
                contribution[infectedBy[v].get(0)]++;

        // Take the best answer.
        Arrays.sort(initial);
        int ans = initial[0], ansSize = -1;
        for (int u: initial) {
            int score = contribution[u];
            if (score > ansSize || score == ansSize && u < ans) {
                ans = u;
                ansSize = score;
            }
        }
        return ans;
    }

    public void dfs(int[][] graph, int[] clean, int u, Set<Integer> seen) {
        for (int v = 0; v < graph.length; ++v)
            if (graph[u][v] == 1 && clean[v] == 1 && !seen.contains(v)) {
                seen.add(v);
                dfs(graph, clean, v, seen);
            }
    }
}

上一题