519. 随机翻转矩阵
给你一个 m x n
的二元矩阵 matrix
,且所有值被初始化为 0
。请你设计一个算法,随机选取一个满足 matrix[i][j] == 0
的下标 (i, j)
,并将它的值变为 1
。所有满足 matrix[i][j] == 0
的下标 (i, j)
被选取的概率应当均等。
尽量最少调用内置的随机函数,并且优化时间和空间复杂度。
实现 Solution
类:
Solution(int m, int n)
使用二元矩阵的大小 m
和 n
初始化该对象int[] flip()
返回一个满足 matrix[i][j] == 0
的随机下标 [i, j]
,并将其对应格子中的值变为 1
void reset()
将矩阵中所有的值重置为 0
示例:
输入 ["Solution", "flip", "flip", "flip", "reset", "flip"] [[3, 1], [], [], [], [], []] 输出 [null, [1, 0], [2, 0], [0, 0], null, [2, 0]] 解释 Solution solution = new Solution(3, 1); solution.flip(); // 返回 [1, 0],此时返回 [0,0]、[1,0] 和 [2,0] 的概率应当相同 solution.flip(); // 返回 [2, 0],因为 [1,0] 已经返回过了,此时返回 [2,0] 和 [0,0] 的概率应当相同 solution.flip(); // 返回 [0, 0],根据前面已经返回过的下标,此时只能返回 [0,0] solution.reset(); // 所有值都重置为 0 ,并可以再次选择下标返回 solution.flip(); // 返回 [2, 0],此时返回 [0,0]、[1,0] 和 [2,0] 的概率应当相同
提示:
1 <= m, n <= 104
flip
时,矩阵中至少存在一个值为 0 的格子。1000
次 flip
和 reset
方法。原站题解
cpp 解法, 执行用时: 16 ms, 内存消耗: 18.4 MB, 提交时间: 2023-06-05 14:23:08
class Solution { int r, c, k; // r: 矩阵的行数, c: 列数, k: 矩阵中元素的个数 unordered_map<int, int> dict; // 大脑中构建一个映射: x -> x, 默认用这个映射, 将不满足这个映射的特殊键值对存入哈希表 public: Solution(int m, int n) { r = m; c = n; k = r*c; } vector<int> flip() { /* 先把矩阵(二维数组)拉平成1维数组, 再进行随机处理 */ int key = rand() % k; int val = key; // 默认的映射规则: x -> x (x的范围是: 0 -> k-1) if (dict.count(key)) val = dict[key]; if (dict.count(k - 1)) // 当key处的kvp用过后, 用最后一个kvp(key = k-1)覆盖之, 然后删掉最后一个kvp, 就可以在剩下的数中实现随机化选择 { dict[key] = dict[k-1]; dict.erase(k - 1); } else dict[key] = k - 1; k--; // 表示删掉了末尾的一个数 int newRow = val / c; int newCol = val % c; return {newRow, newCol}; } void reset() { k = r*c; dict.clear(); } }; /** * Your Solution object will be instantiated and called as such: * Solution* obj = new Solution(m, n); * vector<int> param_1 = obj->flip(); * obj->reset(); */
javascript 解法, 执行用时: 288 ms, 内存消耗: 49 MB, 提交时间: 2023-06-05 14:22:36
/** * @param {number} m * @param {number} n */ var Solution = function(m, n) { this.m = m; this.n = n; this.total = m * n; this.bucketSize = Math.floor(Math.sqrt(this.total)); this.buckets = []; for (let i = 0; i < this.total; i += this.bucketSize) { this.buckets.push(new Set()); } }; /** * @return {number[]} */ Solution.prototype.flip = function() { const x = Math.floor(Math.random() * this.total); let sumZero = 0; let curr = 0; this.total--; for (const bucket of this.buckets) { if (sumZero + this.bucketSize - bucket.size > x) { for (let i = 0; i < this.bucketSize; ++i) { if (!bucket.has(curr + i)) { if (sumZero === x) { bucket.add(curr + i); return [Math.floor((curr + i) / this.n), (curr + i) % this.n]; } sumZero++; } } } curr += this.bucketSize; sumZero += this.bucketSize - bucket.size; } return undefined; }; /** * @return {void} */ Solution.prototype.reset = function() { this.total = this.m * this.n; for (const bucket of this.buckets) { bucket.clear(); } }; /** * Your Solution object will be instantiated and called as such: * var obj = new Solution(m, n) * var param_1 = obj.flip() * obj.reset() */
golang 解法, 执行用时: 100 ms, 内存消耗: 7.2 MB, 提交时间: 2023-06-05 14:21:50
type Solution struct { m, n, total, bucketSize int buckets []map[int]bool } func Constructor(m, n int) Solution { total := m * n bucketSize := int(math.Sqrt(float64(total))) buckets := make([]map[int]bool, (total+bucketSize-1)/bucketSize) for i := range buckets { buckets[i] = map[int]bool{} } return Solution{m, n, total, bucketSize, buckets} } func (s *Solution) Flip() []int { x := rand.Intn(s.total) s.total-- sumZero, curr := 0, 0 for _, bucket := range s.buckets { if sumZero+s.bucketSize-len(bucket) > x { for i := 0; i < s.bucketSize; i++ { if !bucket[curr+i] { if sumZero == x { bucket[curr+i] = true return []int{(curr + i) / s.n, (curr + i) % s.n} } sumZero++ } } } curr += s.bucketSize sumZero += s.bucketSize - len(bucket) } return nil } func (s *Solution) Reset() { s.total = s.m * s.n for i := range s.buckets { s.buckets[i] = map[int]bool{} } } /** * Your Solution object will be instantiated and called as such: * obj := Constructor(m, n); * param_1 := obj.Flip(); * obj.Reset(); */
python3 解法, 执行用时: 1516 ms, 内存消耗: 18.8 MB, 提交时间: 2023-06-05 14:21:32
class Solution: def __init__(self, m: int, n: int): self.m, self.n = m, n self.total = m * n self.bucketSize = math.floor(math.sqrt(m * n)) self.buckets = [set() for _ in range(0, self.total, self.bucketSize)] def flip(self) -> List[int]: x = random.randint(0, self.total - 1) self.total -= 1 sumZero = 0 curr = 0 for i in range(len(self.buckets)): if sumZero + self.bucketSize - len(self.buckets[i]) > x: for j in range(self.bucketSize): if (curr + j) not in self.buckets[i]: if sumZero == x: self.buckets[i].add(curr + j) return [(curr + j) // self.n, (curr + j) % self.n] sumZero += 1 curr += self.bucketSize sumZero += self.bucketSize - len(self.buckets[i]) return [] def reset(self) -> None: self.total = self.m * self.n for i in range(len(self.buckets)): self.buckets[i].clear() # Your Solution object will be instantiated and called as such: # obj = Solution(m, n) # param_1 = obj.flip() # obj.reset()
golang 解法, 执行用时: 8 ms, 内存消耗: 5.5 MB, 提交时间: 2023-06-05 14:21:07
type Solution struct { m, n, total int mp map[int]int } func Constructor(m, n int) Solution { return Solution{m, n, m * n, map[int]int{}} } func (s *Solution) Flip() (ans []int) { x := rand.Intn(s.total) s.total-- if y, ok := s.mp[x]; ok { // 查找位置 x 对应的映射 ans = []int{y / s.n, y % s.n} } else { ans = []int{x / s.n, x % s.n} } if y, ok := s.mp[s.total]; ok { // 将位置 x 对应的映射设置为位置 total 对应的映射 s.mp[x] = y } else { s.mp[x] = s.total } return } func (s *Solution) Reset() { s.total = s.m * s.n s.mp = map[int]int{} } /** * Your Solution object will be instantiated and called as such: * obj := Constructor(m, n); * param_1 := obj.Flip(); * obj.Reset(); */
javascript 解法, 执行用时: 84 ms, 内存消耗: 46.9 MB, 提交时间: 2023-06-05 14:20:50
/** * @param {number} m * @param {number} n */ var Solution = function(m, n) { this.m = m; this.n = n; this.total = m * n; this.map = new Map(); }; /** * @return {number[]} */ Solution.prototype.flip = function() { const x = Math.floor(Math.random() * this.total); this.total--; // 查找位置 x 对应的映射 const idx = this.map.get(x) || x; // 将位置 x 对应的映射设置为位置 total 对应的映射 this.map.set(x, this.map.get(this.total) || this.total); return [Math.floor(idx / this.n), idx % this.n]; }; /** * @return {void} */ Solution.prototype.reset = function() { this.total = this.m * this.n; this.map.clear(); }; /** * Your Solution object will be instantiated and called as such: * var obj = new Solution(m, n) * var param_1 = obj.flip() * obj.reset() */
python3 解法, 执行用时: 76 ms, 内存消耗: 16.7 MB, 提交时间: 2023-06-05 14:20:15
class Solution: def __init__(self, m: int, n: int): self.m = m self.n = n self.total = m * n self.map = {} def flip(self) -> List[int]: x = random.randint(0, self.total - 1) self.total -= 1 # 查找位置 x 对应的映射 idx = self.map.get(x, x) # 将位置 x 对应的映射设置为位置 total 对应的映射 self.map[x] = self.map.get(self.total, self.total) return [idx // self.n, idx % self.n] def reset(self) -> None: self.total = self.m * self.n self.map.clear() # Your Solution object will be instantiated and called as such: # obj = Solution(m, n) # param_1 = obj.flip() # obj.reset()
java 解法, 执行用时: 21 ms, 内存消耗: 43.7 MB, 提交时间: 2023-06-05 14:19:55
class Solution { Map<Integer, Integer> map = new HashMap<>(); int m, n, total; Random rand = new Random(); public Solution(int m, int n) { this.m = m; this.n = n; this.total = m * n; } public int[] flip() { int x = rand.nextInt(total); total--; // 查找位置 x 对应的映射 int idx = map.getOrDefault(x, x); // 将位置 x 对应的映射设置为位置 total 对应的映射 map.put(x, map.getOrDefault(total, total)); return new int[]{idx / n, idx % n}; } public void reset() { total = m * n; map.clear(); } } /** * Your Solution object will be instantiated and called as such: * Solution obj = new Solution(m, n); * int[] param_1 = obj.flip(); * obj.reset(); */