列表

详情


1825. 求出 MK 平均值

给你两个整数 m 和 k ,以及数据流形式的若干整数。你需要实现一个数据结构,计算这个数据流的 MK 平均值 。

MK 平均值 按照如下步骤计算:

  1. 如果数据流中的整数少于 m 个,MK 平均值 为 -1 ,否则将数据流中最后 m 个元素拷贝到一个独立的容器中。
  2. 从这个容器中删除最小的 k 个数和最大的 k 个数。
  3. 计算剩余元素的平均值,并 向下取整到最近的整数 。

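请你实现 MKAverage 类:

- MKAverage(int m, int k) 用一个空的数据流和两个整数 m 和 k 初始化 MKAverage 对象。
- void addElement(int num) 往数据流中插入一个新的元素 num。
- int calculateMKAverage() 对当前的数据流计算并返回 MK 平均值,结果需向下取整到最近的整数。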

示例 1:

输入:
["MKAverage", "addElement", "addElement", "calculateMKAverage", "addElement", "calculateMKAverage", "addElement", "addElement", "addElement", "calculateMKAverage"]
[[3, 1], [3], [1], [], [10], [], [5], [5], [5], []]
输出:
[null, null, null, -1, null, 3, null, null, null, 5]

解释:
MKAverage obj = new MKAverage(3, 1); 
obj.addElement(3);        // 当前元素为 [3]
obj.addElement(1);        // 当前元素为 [3,1]
obj.calculateMKAverage(); // 返回 -1 ,因为 m = 3 ,但数据流中只有 2 个元素
obj.addElement(10);       // 当前元素为 [3,1,10]
obj.calculateMKAverage(); // 最后 3 个元素为 [3,1,10]
                          // 删除最小以及最大的 1 个元素后,容器为 [3]
                          // [3] 的平均值等于 3/1 = 3 ,故返回 3
obj.addElement(5);        // 当前元素为 [3,1,10,5]
obj.addElement(5);        // 当前元素为 [3,1,10,5,5]
obj.addElement(5);        // 当前元素为 [3,1,10,5,5,5]
obj.calculateMKAverage(); // 最后 3 个元素为 [5,5,5]
                          // 删除最小以及最大的 1 个元素后,容器为 [5]
                          // [5] 的平均值等于 5/1 = 5 ,故返回 5

 

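提示:

- 3 <= m <= 10^5
- 1 <= k*2 < m
- 1 <= num <= 10^5
- addElement 与 calculateMKAverage 总操作次数不超过 10^5 次。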

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
// Starter template (C++) for the problem above, left unimplemented.
// NOTE(review): calculateMKAverage() has no return statement, so calling this
// stub as-is is undefined behavior; fill in the bodies before use.
class MKAverage {
public:
    // Stub: should initialize an empty stream with window size m and trim count k.
    MKAverage(int m, int k) {
    }

    // Stub: should append num to the stream.
    void addElement(int num) {
    }

    // Stub: should return the MK average of the last m elements, or -1 if fewer
    // than m elements have been added.
    int calculateMKAverage() {
    }
};

/**
 * Your MKAverage object will be instantiated and called as such:
 * MKAverage* obj = new MKAverage(m, k);
 * obj->addElement(num);
 * int param_2 = obj->calculateMKAverage();
 */

golang 解法, 执行用时: 268 ms, 内存消耗: 40.3 MB, 提交时间: 2023-01-18 11:06:28

// MKAverage keeps the last m stream values split across three multisets,
// each a red-black tree mapping value -> multiplicity:
//   - lo:  the (up to) k smallest values; size1 counts its elements,
//   - mid: the values in between; s is their running sum,
//   - hi:  the (up to) k largest values; size3 counts its elements.
//
// q holds the window in arrival order.
// NOTE: Constructor fills this struct with a positional literal, so the field
// declaration order below is load-bearing.
type MKAverage struct {
	lo, mid, hi  *redblacktree.Tree
	q            []int
	m, k, s      int
	size1, size3 int
}

// Constructor returns an empty MKAverage for a window of the last m values,
// trimming the k smallest and k largest of them before averaging. The running
// sum and both size counters start at their zero values.
func Constructor(m int, k int) MKAverage {
	return MKAverage{
		lo:  redblacktree.NewWithIntComparator(),
		mid: redblacktree.NewWithIntComparator(),
		hi:  redblacktree.NewWithIntComparator(),
		q:   []int{},
		m:   m,
		k:   k,
	}
}

// AddElement appends num to the stream and evicts the value that leaves the
// last-m window. It then rebalances the trees so that lo and hi each hold
// (up to) k values and s stays equal to the sum of mid.
func (this *MKAverage) AddElement(num int) {
	// bump adds delta to key's multiplicity in t and drops the key at zero.
	bump := func(t *redblacktree.Tree, key, delta int) {
		cur, found := t.Get(key)
		if !found {
			t.Put(key, delta)
			return
		}
		if n := cur.(int) + delta; n != 0 {
			t.Put(key, n)
		} else {
			t.Remove(key)
		}
	}
	// has reports whether key currently appears in t.
	has := func(t *redblacktree.Tree, key int) bool {
		_, found := t.Get(key)
		return found
	}

	// Route the new value: into lo if it does not exceed lo's maximum, into
	// hi if it is at least hi's minimum, otherwise into mid.
	switch {
	case this.lo.Empty() || num <= this.lo.Right().Key.(int):
		bump(this.lo, num, 1)
		this.size1++
	case this.hi.Empty() || num >= this.hi.Left().Key.(int):
		bump(this.hi, num, 1)
		this.size3++
	default:
		bump(this.mid, num, 1)
		this.s += num
	}

	// Slide the window: drop the oldest value once more than m are held.
	this.q = append(this.q, num)
	if len(this.q) > this.m {
		old := this.q[0]
		this.q = this.q[1:]
		switch {
		case has(this.lo, old):
			bump(this.lo, old, -1)
			this.size1--
		case has(this.hi, old):
			bump(this.hi, old, -1)
			this.size3--
		default:
			bump(this.mid, old, -1)
			this.s -= old
		}
	}

	// Shrink lo/hi to k by moving their innermost values into mid.
	for this.size1 > this.k {
		v := this.lo.Right().Key.(int)
		bump(this.lo, v, -1)
		bump(this.mid, v, 1)
		this.s += v
		this.size1--
	}
	for this.size3 > this.k {
		v := this.hi.Left().Key.(int)
		bump(this.hi, v, -1)
		bump(this.mid, v, 1)
		this.s += v
		this.size3--
	}

	// Grow lo/hi toward k by pulling mid's extreme values outward.
	for this.size1 < this.k && !this.mid.Empty() {
		v := this.mid.Left().Key.(int)
		bump(this.mid, v, -1)
		this.s -= v
		bump(this.lo, v, 1)
		this.size1++
	}
	for this.size3 < this.k && !this.mid.Empty() {
		v := this.mid.Right().Key.(int)
		bump(this.mid, v, -1)
		this.s -= v
		bump(this.hi, v, 1)
		this.size3++
	}
}

// CalculateMKAverage returns the sum of mid (the window minus its k smallest
// and k largest values) divided by m-2k using integer division. It returns
// -1 while fewer than m values have been added.
func (this *MKAverage) CalculateMKAverage() int {
	if len(this.q) >= this.m {
		return this.s / (this.m - 2*this.k)
	}
	return -1
}

/**
 * Your MKAverage object will be instantiated and called as such:
 * obj := Constructor(m, k);
 * obj.AddElement(num);
 * param_2 := obj.CalculateMKAverage();
 */

python3 解法, 执行用时: 1156 ms, 内存消耗: 45.2 MB, 提交时间: 2023-01-18 11:05:35

from sortedcontainers import SortedList

class MKAverage:
    """MK average of the last ``m`` stream elements.

    ``self.d`` is the window in arrival order and ``self.s`` holds the same
    values in sorted order. ``self.sums`` is maintained incrementally as the
    sum of the window with its ``k`` smallest and ``k`` largest values removed,
    so ``calculateMKAverage`` is a single division.
    """

    def __init__(self, m: int, k: int):
        # ``deque`` was previously used without being imported; it only worked
        # where the judge pre-imports it. Import it explicitly here.
        from collections import deque

        self.m, self.k = m, k
        self.d = deque()  # window, oldest element first
        self.s = SortedList()  # window, sorted
        self.sums = 0  # sum of the middle (untrimmed) part of self.s

    def addElement(self, num: int) -> None:
        """Append ``num``, evicting the oldest element once more than ``m`` are held."""
        d, s, k, m = self.d, self.s, self.k, self.m
        if len(d) < 2 * k:
            # Too few elements for a non-empty middle; just record num.
            d.append(num)
            s.add(num)
            return
        d.append(num)
        if len(d) > m:
            # Evict the oldest value p. The middle loses s[k] if p lies in the
            # low part, s[-k-1] if p lies in the high part, and p otherwise.
            p = d.popleft()
            low, high = s[k - 1], s[-k]
            if p <= low:
                self.sums -= s[k]
            elif p >= high:
                self.sums -= s[-k - 1]
            else:
                self.sums -= p
            s.remove(p)
        # Insert num. The middle gains whatever value num displaces from the
        # low or high part (the current k-th smallest / k-th largest), or num
        # itself when it lands in the middle.
        low, high = s[k - 1], s[-k]
        if num <= low:
            self.sums += low
        elif num > high:
            self.sums += high
        else:
            self.sums += num
        s.add(num)

    def calculateMKAverage(self) -> int:
        """Return ``sums // (m - 2k)``, or -1 while fewer than ``m`` elements have arrived."""
        if len(self.d) < self.m:
            return -1
        return self.sums // (self.m - 2 * self.k)

# Your MKAverage object will be instantiated and called as such:
# obj = MKAverage(m, k)
# obj.addElement(num)
# param_2 = obj.calculateMKAverage()

python3 解法, 执行用时: 8584 ms, 内存消耗: 99.9 MB, 提交时间: 2023-01-18 11:04:57

class Node:
    """Segment-tree node; ``val`` is the sum of the numbers recorded in its range."""

    __slots__ = ("left", "right", "val")

    def __init__(self):
        self.left = self.right = None
        self.val = 0


class segtree:
    """Sum segment tree over the integer value range ``[lo, hi]``.

    ``add(node, idx, val)`` records ``val`` copies of the number ``idx`` (``val``
    may be negative to remove copies). ``sums(node, left, right)`` returns the
    total of all recorded numbers whose value lies in ``[left, right]``.

    Children are created lazily on the first update that reaches them. The
    previous constructor eagerly built every node of ``[0, 10**5]`` (~2e5
    objects) up front, which dominated both run time and memory.
    """

    def __init__(self, lo=0, hi=10**5):
        self.lo, self.hi = lo, hi
        self.root = Node()

    def build(self, node, l, r):
        """Eagerly create the full subtree of ``node`` covering ``[l, r]``.

        The constructor no longer calls this; it is kept for backward compatibility.
        """
        if l == r:
            return
        mid = (l + r) // 2
        node.left = Node()
        node.right = Node()
        self.build(node.left, l, mid)
        self.build(node.right, mid + 1, r)

    def add(self, node, idx, val, l=None, r=None):
        """Record ``val`` copies of ``idx`` under ``node``, which covers ``[l, r]``.

        ``l``/``r`` default to the tree's full range.
        """
        if l is None:
            l = self.lo
        if r is None:
            r = self.hi
        delta = val * idx
        while True:
            node.val += delta
            if l == r:
                return
            mid = (l + r) // 2
            if idx <= mid:
                if node.left is None:
                    node.left = Node()
                node, r = node.left, mid
            else:
                if node.right is None:
                    node.right = Node()
                node, l = node.right, mid + 1

    def sums(self, node, left, right, l=None, r=None):
        """Return the sum of recorded numbers with value in ``[left, right]``.

        The query is clamped to ``node``'s range ``[l, r]``. Empty or inverted
        ranges and never-touched (``None``) subtrees contribute 0.
        """
        if l is None:
            l = self.lo
        if r is None:
            r = self.hi
        left, right = max(left, l), min(right, r)
        if node is None or left > right:
            return 0
        if l == left and r == right:
            return node.val
        mid = (l + r) // 2
        return (self.sums(node.left, left, right, l, mid)
                + self.sums(node.right, left, right, mid + 1, r))
        
from sortedcontainers import SortedList
class MKAverage:
    """MK average over a sliding window of the last ``m`` stream values.

    The window is kept in arrival order (``self.d``) and in sorted order
    (``self.s``). A value-indexed ``segtree`` (``self.t``) gives the sum of all
    window values within any value range in one query.
    """

    def __init__(self, m: int, k: int):
        # ``deque`` was previously used without being imported; it only worked
        # where the judge pre-imports it. Import it explicitly here.
        from collections import deque

        self.m = m
        self.k = k
        self.d = deque()  # window, oldest value first
        self.s = SortedList()  # window, sorted
        self.t = segtree()  # value-indexed sums of the window
        self.root = self.t.root

    def addElement(self, num: int) -> None:
        """Append ``num``; once more than ``m`` values are held, drop the oldest."""
        self.d.append(num)
        self.s.add(num)
        self.t.add(self.root, num, 1)
        if len(self.d) > self.m:
            p = self.d.popleft()
            self.s.remove(p)
            self.t.add(self.root, p, -1)

    def calculateMKAverage(self) -> int:
        """Return the trimmed mean of the window (integer division), or -1 if it is not full."""
        if len(self.d) < self.m:
            return -1
        m, k, s, t, root = self.m, self.k, self.s, self.t, self.root
        low, high = s[k - 1], s[-k]  # k-th smallest and k-th largest values
        if low == high:
            # Every middle value lies between two equal bounds.
            return low
        # Copies of `low` and `high` that belong to the middle, plus every
        # value strictly between them (all of which are in the middle).
        total = (s.bisect_right(low) - k) * low
        total += (m - k - s.bisect_left(high)) * high
        total += t.sums(root, low + 1, high - 1)
        return total // (m - 2 * k)

# Your MKAverage object will be instantiated and called as such:
# obj = MKAverage(m, k)
# obj.addElement(num)
# param_2 = obj.calculateMKAverage()

java 解法, 执行用时: 51 ms, 内存消耗: 91.8 MB, 提交时间: 2023-01-18 11:04:01

/**
 * Sliding-window MK average over the last {@code m} stream values.
 *
 * <p>The window is kept in three counted multisets (value to multiplicity):
 * {@code s1} holds the k smallest values, {@code s3} the k largest, and
 * {@code s2} everything in between. The total of {@code s2} is tracked in
 * {@code sum2}.
 */
class MKAverage {
    private int m, k;
    private Queue<Integer> q;
    private TreeMap<Integer, Integer> s1;
    private TreeMap<Integer, Integer> s2;
    private TreeMap<Integer, Integer> s3;
    private int size1, size2, size3;
    private long sum2;

    public MKAverage(int m, int k) {
        this.m = m;
        this.k = k;
        q = new ArrayDeque<>();
        s1 = new TreeMap<>();
        s2 = new TreeMap<>();
        s3 = new TreeMap<>();
        size1 = size2 = size3 = 0;
        sum2 = 0L;
    }

    /** Adds one occurrence of {@code key} to {@code bag}. */
    private static void push(TreeMap<Integer, Integer> bag, int key) {
        bag.merge(key, 1, Integer::sum);
    }

    /** Removes one occurrence of {@code key} (which must be present) from {@code bag}. */
    private static void pop(TreeMap<Integer, Integer> bag, int key) {
        int remaining = bag.get(key) - 1;
        if (remaining == 0) {
            bag.remove(key);
        } else {
            bag.put(key, remaining);
        }
    }

    /** Adds {@code value} to the middle multiset and its running total. */
    private void enterMiddle(int value) {
        push(s2, value);
        size2++;
        sum2 += value;
    }

    /** Removes one {@code value} from the middle multiset and its running total. */
    private void leaveMiddle(int value) {
        pop(s2, value);
        size2--;
        sum2 -= value;
    }

    public void addElement(int num) {
        q.offer(num);

        if (q.size() <= m) {
            // Filling the first window: everything goes to the middle. Once it
            // is full, peel off the k smallest into s1 and the k largest into s3.
            enterMiddle(num);
            if (q.size() == m) {
                while (size1 < k) {
                    int low = s2.firstKey();
                    leaveMiddle(low);
                    push(s1, low);
                    size1++;
                }
                while (size3 < k) {
                    int high = s2.lastKey();
                    leaveMiddle(high);
                    push(s3, high);
                    size3++;
                }
            }
            return;
        }

        // Insert num. If it lands in s1 or s3, that side's innermost value is
        // displaced into the middle so both sides keep exactly k elements.
        if (num < s1.lastKey()) {
            push(s1, num);
            int displaced = s1.lastKey();
            pop(s1, displaced);
            enterMiddle(displaced);
        } else if (num > s3.firstKey()) {
            push(s3, num);
            int displaced = s3.firstKey();
            pop(s3, displaced);
            enterMiddle(displaced);
        } else {
            enterMiddle(num);
        }

        // Evict the oldest value; refill s1/s3 from the middle if it came from there.
        int oldest = q.poll();
        if (s1.containsKey(oldest)) {
            pop(s1, oldest);
            int low = s2.firstKey();
            leaveMiddle(low);
            push(s1, low);
        } else if (s3.containsKey(oldest)) {
            pop(s3, oldest);
            int high = s2.lastKey();
            leaveMiddle(high);
            push(s3, high);
        } else {
            leaveMiddle(oldest);
        }
    }

    /** Returns sum2 / (m - 2k) using integer division, or -1 while fewer than {@code m} values have arrived. */
    public int calculateMKAverage() {
        return q.size() < m ? -1 : (int) (sum2 / (m - 2 * k));
    }
}

/**
 * Your MKAverage object will be instantiated and called as such:
 * MKAverage obj = new MKAverage(m, k);
 * obj.addElement(num);
 * int param_2 = obj.calculateMKAverage();
 */

cpp 解法, 执行用时: 860 ms, 内存消耗: 143.6 MB, 提交时间: 2023-01-18 11:03:27

// Three ordered multisets: s1 keeps the k smallest values of the window,
// s3 the k largest, and s2 the middle values, whose total is sum2.
class MKAverage {
private:
    int m, k;
    queue<int> q;
    multiset<int> s1, s2, s3;
    long long sum2;

    // Moves the smallest middle value into s1.
    void shiftMinToLow() {
        auto it = s2.begin();
        sum2 -= *it;
        s1.insert(*it);
        s2.erase(it);
    }

    // Moves the largest middle value into s3.
    void shiftMaxToHigh() {
        auto it = prev(s2.end());
        sum2 -= *it;
        s3.insert(*it);
        s2.erase(it);
    }

public:
    MKAverage(int m, int k) : m(m), k(k), sum2(0) {}

    void addElement(int num) {
        q.push(num);
        if ((int)q.size() <= m) {
            // Filling the first window: collect in s2, then peel off both ends.
            s2.insert(num);
            sum2 += num;
            if ((int)q.size() == m) {
                while ((int)s1.size() < k) shiftMinToLow();
                while ((int)s3.size() < k) shiftMaxToHigh();
            }
            return;
        }

        // Insert num. If it lands in s1 or s3, that side's innermost value is
        // displaced into the middle so both sides stay at exactly k elements.
        if (num < *s1.rbegin()) {
            s1.insert(num);
            auto top = prev(s1.end());
            s2.insert(*top);
            sum2 += *top;
            s1.erase(top);
        } else if (num > *s3.begin()) {
            s3.insert(num);
            auto bottom = s3.begin();
            s2.insert(*bottom);
            sum2 += *bottom;
            s3.erase(bottom);
        } else {
            s2.insert(num);
            sum2 += num;
        }

        // Evict the oldest value. find() is logarithmic; the previous
        // count(x) > 0 test is linear in the number of equal keys, which made
        // each call O(k) on streams with many duplicates.
        int x = q.front();
        q.pop();
        auto it = s1.find(x);
        if (it != s1.end()) {
            s1.erase(it);
            shiftMinToLow();
        } else if ((it = s3.find(x)) != s3.end()) {
            s3.erase(it);
            shiftMaxToHigh();
        } else {
            s2.erase(s2.find(x));
            sum2 -= x;
        }
    }

    // Returns sum2 / (m - 2k) using integer division, or -1 while fewer than m values have arrived.
    int calculateMKAverage() {
        if ((int)q.size() < m) {
            return -1;
        }
        return (int)(sum2 / (m - 2 * k));
    }
};

cpp 解法, 执行用时: 340 ms, 内存消耗: 123.7 MB, 提交时间: 2023-01-18 11:02:27

// Splay tree solution
#define INF 0x3f3f3f3f  // sentinel key, used as +INF/-INF bounds around the stream values
#define ll long long

// Size of the node pool tr[]: insert() takes a fresh slot per call and never recycles.
// NOTE(review): assumes at most ~1e5 insertions per instance (plus 2 sentinels) — confirm against limits.
const int N = 1e5 + 5;

// Pool node of the splay tree; index 0 plays the role of the null node.
struct Node
{
    int s[2];  // children: s[0] left (not greater keys), s[1] right (greater keys)
    int p;     // parent index (0 at the root)
    int v;     // key
    int sz;    // number of nodes in this subtree
    ll sum;    // sum of keys in this subtree

    // Resets this slot to a detached leaf holding key _v under parent _p.
    void init(int _v, int _p)
    {
        v = _v;
        sum = _v;
        p = _p;
        sz = 1;
        s[0] = s[1] = 0;
    }
}tr[N];

// Global splay-tree state: root is the current root index (0 = empty tree) and
// idx is the last tr[] slot handed out by insert(); slots are never recycled.
int root, idx;

void pushup(int x)
{
    tr[x].sum = (ll)tr[tr[x].s[0]].sum + tr[tr[x].s[1]].sum + tr[x].v;
    tr[x].sz = tr[tr[x].s[0]].sz + tr[tr[x].s[1]].sz + 1;
}

// Single rotation lifting x above its parent y while preserving in-order order.
// NOTE(review): when y is the root, z is 0 and the first line writes into the
// null slot tr[0]; only its s/p fields change, and pushup is never run on it here.
void rotate(int x)  // rotation
{
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;  // 1 if x is y's right child, 0 if left
    tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;  // x takes y's place under z
    tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;  // x's inner subtree moves under y
    tr[x].s[k ^ 1] = y, tr[y].p = x;  // y becomes x's child on the opposite side
    pushup(y), pushup(x);  // y is now below x, so refresh it first
}

// Rotates x upward until its parent is k; with k == 0 x becomes the new root.
void splay(int x, int k)  // splay operation
{
    while(tr[x].p != k)
    {
        int y = tr[x].p, z = tr[y].p;
        if(z != k)
        {
            // x and y on different sides (zig-zag): rotate x twice;
            // same side (zig-zig): rotate y first, then x.
            if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
    if(!k) root = x;
}

// Inserts key v as a fresh leaf (equal keys go left) and splays it to the root.
void insert(int v)
{
    int parent = 0;
    for (int cur = root; cur; cur = tr[cur].s[v > tr[cur].v])
        parent = cur;

    const int fresh = ++idx;
    tr[fresh].init(v, parent);
    if (parent)
        tr[parent].s[v > tr[parent].v] = fresh;
    splay(fresh, 0);
}

// Splays to the root the first node (in in-order) whose key is >= v.
// NOTE(review): res stays uninitialized if no key is >= v; this relies on the
// INF sentinel inserted by MKAverage's constructor being >= every queried v.
void get_v(int v) // smallest key >= v
{
    int u = root, res;
    while(u)
    {
        if(v <= tr[u].v) res = u, u = tr[u].s[0];  // candidate; keep looking left for an earlier one
        else u = tr[u].s[1];
    }
    splay(res, 0);
}

// Neighbour lookup around v.
//   f == 0: returns the in-order predecessor of the first node with key >= v.
//   f == 1: returns the first node with key > v if v is absent; otherwise the
//           node right after the first occurrence of v (it may also hold v).
// NOTE(review): assumes such a neighbour exists (the -INF/INF sentinels
// provide one); otherwise u would walk into the null slot tr[0].
int getpn(int v, int f) // 0 = predecessor, 1 = successor
{
    get_v(v);
    int u = root;
    if(tr[u].v < v && !f) return u;  // root holds the first key >= v, so this looks unreachable
    if(tr[u].v > v && f) return u;   // v absent: the root already is the first key > v
    u = tr[u].s[f];  // step into the root's f-side subtree ...
    while(tr[u].s[f ^ 1]) u = tr[u].s[f ^ 1];  // ... and walk to its extreme nearest the root
    return u;
}

void del(int v)
{
    int l = getpn(v, 0), r = getpn(v, 1);
    splay(l, 0), splay(r, l);
    tr[r].s[0] = 0;
}

// Returns the node with 1-based in-order rank k, counting every node in the tree.
int get_id_by_rank(int k)
{
    for (int u = root;;)
    {
        const int leftSize = tr[tr[u].s[0]].sz;
        if (k <= leftSize)
            u = tr[u].s[0];
        else if (k == leftSize + 1)
            return u;
        else
        {
            k -= leftSize + 1;
            u = tr[u].s[1];
        }
    }
}


class MKAverage
{
public:
    int m, k;
    queue<int> q;

    MKAverage(int m, int k) : m(m), k(k)
    {
        idx = root = 0;
        insert(INF), insert(-INF);
    }

    void addElement(int num)
    {
        q.push(num);
        insert(num);
        if(q.size() > m)
        {
            del(q.front());
            q.pop();
        }
    }

    int calculateMKAverage()
    {
        if(q.size() < m) return -1;
        else
        {
            int l = get_id_by_rank(k + 1), r = get_id_by_rank(m - k + 2);
            splay(l, 0), splay(r, l);
            return tr[tr[r].s[0]].sum / tr[tr[r].s[0]].sz;
        }
    }
};

上一题