小美在观察一棵美丽的无根树。
小团问小美:“小美,我考考你,如果我选一个点为根,你能不能找出子树大小不超过K的前提下,子树内最大值和最小值差最大的子树的根是哪个点?多个点的话你给我编号最小的那个点就行了。”
小美思索一番,说这个问题难不倒他。
小美在观察一棵美丽的无根树。
小团问小美:“小美,我考考你,如果我选一个点为根,你能不能找出子树大小不超过K的前提下,子树内最大值和最小值差最大的子树的根是哪个点?多个点的话你给我编号最小的那个点就行了。”
小美思索一番,说这个问题难不倒他。
第一行两个正整数N和K,表示全树有N个节点,要求子树大小不超过K。
第二行是N个正整数空格分隔,表示每个点的点权。以点编号从1到N的顺序给出点权。
接下来N-1行每行两个正整数表示哪两个点之间有边相连。
最后一行一个正整数root表示小团所选的根节点编号为root。
一行,一个正整数,含义如问题描述,输出在子树大小不超过K的前提下,子树内最大值和最小值差最大的子树的根的编号
5 2 1 3 2 4 5 1 2 2 3 3 4 4 5 3
2
对于30%的数据点有
对于100%的数据点有
各点上的权值有
,对于K有
import java.io.BufferedReader; import java.io.InputStreamReader; import java.io.IOException; import java.util.HashMap; import java.util.ArrayList; public class Main { static boolean[] visited; // 标记节点是否已经被访问 static int[] weight; // 节点权值 static int[] childNum; // 存储以节点i为根的树有多少个节点 static int[] max, min; // 存储以节点i为根的子树下的最大值和最小值 // 节点间的最大差值 static int maxDiff = -1; // 待求节点 static int node = -1; // 邻接表 static HashMap<Integer, ArrayList<Integer>> tree; public static void main(String[] args) throws IOException { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); String[] temp = br.readLine().trim().split(" "); int n = Integer.parseInt(temp[0]); int k = Integer.parseInt(temp[1]); temp = br.readLine().trim().split(" "); weight = new int[n + 1]; for(int i = 1; i <= n; i++) weight[i] = Integer.parseInt(temp[i - 1]); int x, y; // 构建树图的邻接表 tree = new HashMap<>(); ArrayList<Integer> list; for(int i = 1; i <= n - 1; i++){ temp = br.readLine().trim().split(" "); x = Integer.parseInt(temp[0]); y = Integer.parseInt(temp[1]); if(tree.containsKey(x)){ list = tree.get(x); list.add(y); tree.put(x, list); }else{ list = new ArrayList<>(); list.add(y); tree.put(x, list); } if(tree.containsKey(y)){ list = tree.get(y); list.add(x); tree.put(y, list); }else{ list = new ArrayList<>(); list.add(x); tree.put(y, list); } } int root = Integer.parseInt(br.readLine().trim()); visited = new boolean[n + 1]; max = new int[n + 1]; min = new int[n + 1]; childNum = new int[n + 1]; dfs(root, k); System.out.println(node); } // 求取节点parent下子节点的最值 private static void dfs(int parent, int k) { visited[parent] = true; // 初始化parent下的最值为parent的节点权重 max[parent] = weight[parent]; min[parent] = weight[parent]; // 初始情况下,该子树只有一个节点 childNum[parent] = 1; for(int i = 0; i < tree.get(parent).size(); i++){ int child = tree.get(parent).get(i); if(!visited[child]){ // 没访问过这个孩子节点就进行深搜 dfs(child, k); max[parent] = Math.max(max[parent], max[child]); min[parent] = Math.min(min[parent], min[child]); childNum[parent] += childNum[child]; } } if(childNum[parent] <= k && max[parent] - min[parent] >= maxDiff){ // 以parent为根节点的子树满足节点数小于等于k,且最大差值大于等于目前最大 if(max[parent] - min[parent] > maxDiff){ // 大于了直接更新,等于的话需要考虑哪个根节点的编号小 node = parent; maxDiff = max[parent] - min[parent]; }else{ // 如果node还没有赋值,就直接赋值为当前节点,否则取满足要求的节点中编号最小的 node = node == -1? parent: Math.min(node, parent); } } } }
广搜构造树,dfs后序遍历以使用子树中的结果来计算,动态规划思想(保存子树中的最大最小值来计算当前整树的最大最小值)。
本题算法要求中等,但是操作比较繁杂。
from collections import defaultdict, deque class TreeNode: def __init__(self, val=0, idx=0): self.val = val self.idx = idx self.child = dict() N, K = map(int, input().split()) val_list = list(map(int, input().split())) edges = defaultdict(set) for _ in range(N-1): i, j = map(int, input().split()) edges[i].add(j) edges[j].add(i) ro_idx = int(input()) root = TreeNode(val=val_list[ro_idx-1], idx=ro_idx) q = deque([root]) seen = set([ro_idx]) while q: node = q.popleft() cur_id = node.idx for i in edges[cur_id]: if i not in seen: new = TreeNode(val=val_list[i-1], idx=i) node.child[i] = new q.append(new) seen.add(i) message = [-1, N+1] # 0:最大差值 1:最小编号 def solve(node): l = r = node.val cnt = 1 for i in node.child: li, ri, cnti = solve(node.child[i]) l = min(li, l) r = max(ri, r) cnt += cnti if (r - l > message[0] or (r - l == message[0] and node.idx < message[1])) and cnt <= K: message[0] = r-l message[1] = node.idx return l, r, cnt solve(root) print(message[1])
import java.util.*; import java.io.*; class Node { // 编号 public int number; // 点权 public int weight; // 子节点 public List<Node> children; // 子树的最小值 public int min; // 子树的最大值 public int max; // 子树的节点数 public int count; public Node(int number, int weight) { this.number = number; this.weight = weight; this.children = new ArrayList<>(); this.min = weight; this.max = weight; this.count = 1; } } public class Main { private static BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); // 读取一个整数 private static int readInt() throws IOException { int num = 0; char ch; while ((ch = (char)reader.read()) != ' ' && ch != '\n') { num = num * 10 + (ch - '0'); } return num; } // 节点数目 private static int n; /// 节点数目不超过 private static int k; // 权重 private static int[] weights; // 邻接表 private static Map<Integer, Set<Integer>> friends = new HashMap<>(); // 最大差值 private static int maxDiff = Integer.MIN_VALUE; // 最大差值的节点 private static int result = Integer.MAX_VALUE; // 根据邻接表构建树(后序遍历),构造后判断是否符合题意 private static Node create(int rootNumber) { Node root = new Node(rootNumber, weights[rootNumber]); // 构造子树 for (int childNumber : friends.get(rootNumber)) { Set<Integer> set = friends.get(childNumber); if (set != null) { set.remove(rootNumber); } Node child = create(childNumber); root.children.add(child); root.min = child.min < root.min ? child.min : root.min; root.max = child.max > root.max ? child.max : root.max; root.count += child.count; } // 判断是否满足题意 if (root.count <= k) { int diff = root.max - root.min; // 大于 if (diff > maxDiff) { maxDiff = root.max - root.min; result = rootNumber; } // 等于 else if (diff == maxDiff) { result = rootNumber < result ? rootNumber : result; } } return root; } public static void main(String[] args) throws IOException { n = readInt(); k = readInt(); // 权重 weights = new int[n + 1]; for (int i = 1; i <= n; ++i) { weights[i] = readInt(); } // 邻接表 for (int i = 1; i < n; ++i) { int x = readInt(), y = readInt(); friends.computeIfAbsent(x, k -> new HashSet<>()).add(y); friends.computeIfAbsent(y, k -> new HashSet<>()).add(x); } // 构建子树 create(readInt()); System.out.println(result); } }
//java写法 import java.util.Arrays; import java.util.Scanner; public class Main { static int [] h = new int[400010],ne = new int[400010], e = new int[400010]; static int [] w = new int[400010]; static int maxValue = Integer.MIN_VALUE,ret; static int n,k; static int idx = 0; static void add(int a,int b) { e[idx] = b; ne[idx] = h[a]; h[a] = idx++; } //ans[0] 表示以u为根的最小值 ans[1]表示最大值 ans[2]表示节点个数 u表示根 p表示父节点 防止扩展时向上扩展 static int [] dfs (int u, int p) { boolean hasChild = false; int [] ans = new int[3]; ans[0] = w[u]; ans[1] = w[u]; ans[2] = 1; //枚举子节点 更新数组 for (int i = h[u];i != -1;i = ne[i]) { int j = e[i]; if (j == p) continue; hasChild = true; int [] childAns = dfs(j, u); ans[0] = Math.min(ans[0], childAns[0]); ans[1] = Math.max(ans[1], childAns[1]); ans[2] += childAns[2]; } if (ans[2] <= k) { if (Math.abs(ans[0] - ans[1]) > maxValue) { maxValue = Math.abs(ans[0] - ans[1]); ret = u; } else if (Math.abs(ans[0] - ans[1]) == maxValue) { if (u < ret) { ret = u; } } } if (!hasChild) { return new int[]{w[u], w[u], 1}; } return ans; } public static void main(String[] args) { Arrays.fill(h, - 1); Scanner sc = new Scanner(System.in); n = sc.nextInt(); k = sc.nextInt(); for (int i = 1;i <= n;i ++) w[i] = sc.nextInt(); for (int i = 1;i < n;i ++) { int a = sc.nextInt(); int b = sc.nextInt(); add(a, b); add(b, a); } int u = sc.nextInt(); dfs(u, -1); System.out.println(ret); } }
import collections class TreeNode: def __init__(self, order, val) -> None: self.order = order self.val = val self.childs = [] pass n, k = map(lambda x: int(x), input().split()) d = {} val_list = list(map(lambda x: int(x), input().split())) for i, val in enumerate(val_list): d[i+1] = TreeNode(i+1, val) order_childs = [[] for _ in range(n+1)] for _ in range(n-1): u, v = map(lambda x: int(x), input().split()) order_childs[u].append(v) order_childs[v].append(u) root_order = int(input()) vis = [False]*(n+1) def create_tree(): que = collections.deque() que.append(root_order) vis[root_order] = True while que: par_order = que.popleft() nxt_orders = order_childs[par_order] for order in nxt_orders: if not vis[order]: d[par_order].childs.append(d[order]) vis[order] = True que.append(order) create_tree() # def print_tree(root): # # print(root.order) # for ch in root.childs: # print(str(root.order)+' -> ' +str(ch.order)) # print_tree(ch) # print_tree(d[root_order]) # exit() max_data = float('-inf') ans_node = -1 def dfs(node: TreeNode): global max_data global ans_node min_v = node.val max_v = node.val node_num = 1 for ch in node.childs: ch_num, ch_minv, ch_maxv = dfs(ch) node_num += ch_num min_v = min(ch_minv, min_v) max_v = max(ch_maxv, max_v) data = max_v-min_v if data>=max_data and node_num<=k: if data>max_data: ans_node = node.order max_data = data else: if ans_node>node.order: ans_node = node.order return node_num, min_v, max_v root = d[root_order] dfs(root) print(ans_node)
import java.util.*; public class Main{ //记录当前子树中所有权值的最大值 private static int[] max; //记录当前子树中所有权值的最小值 private static int[] min; //记录每个根节点的权值 private static int[] weight; //记录每个根节点中所有的节点个数,包括当前节点 private static int[] node_count; //记录差值的最大值 private static int maxf=0; //记录符合题目的根节点 private static int small_root=-1; //记录已经被访问过的根节点 private static boolean[] is_visted; //生成树的关系图 private static List<Integer>[] tree; public static void main(String[] arg){ //初始化数据 Scanner sc=new Scanner(System.in); int n=sc.nextInt(); int k=sc.nextInt(); is_visted =new boolean[n]; weight=new int[n]; max=new int[n]; min=new int[n]; node_count=new int[n]; tree=new ArrayList[n]; for(int i=0;i<n;i++){ weight[i]=sc.nextInt(); } for(int i=0;i<n-1;i++){ int x=sc.nextInt()-1; int y=sc.nextInt()-1; if(tree[x]==null){ ArrayList<Integer> list=new ArrayList<>(); list.add(y); tree[x]=list; }else{ tree[x].add(y); } if(tree[y]==null){ ArrayList<Integer> list=new ArrayList<>(); list.add(x); tree[y]=list; }else{ tree[y].add(x); } } //完成初始化工作,把选择作为根的节点标记已访问 int root=sc.nextInt()-1; is_visted[root]=true; //遍历根节点的所有子树 for(Integer seed:tree[root]){ dfs(seed,k); } System.out.print(small_root+1); } public static void dfs(int seed,int k){ //每进到一颗子树就打上标记 is_visted[seed]=true; //一开始这棵树的最大值和最小值都是自己 max[seed]=weight[seed]; min[seed]=weight[seed]; //一开始节点个数只有自己 node_count[seed]=1; //开始遍历seed的子树 for(Integer childen : tree[seed]){ //如果标记访问过的,说明是children的父节点或父父节点等,跳过 if(is_visted[childen]){ continue; } //从这里开始递归,递归完成后max、min和node_count中都包含了子树的信息 dfs(childen,k); //最终的节点个数就是加完所有子树的节点个数 node_count[seed]+=node_count[childen]; //和每个子树的max、min作比较,找出seed的max、min max[seed]=Math.max(max[seed],max[childen]); min[seed]=Math.min(min[seed],min[childen]); } //最终比较当前seed的节点个数是否满足题目要求,不满足直接返回 if(node_count[seed]<=k){ //满足的话再看max-min是否大于之前记录的满足题目要求的最大差值maxf if(max[seed]-min[seed]>maxf){ //大于的话直接更新当前符合题目的节点,和最大差值 small_root=seed; maxf=max[seed]-min[seed]; } //如果相等的话,看看small_root是不是第一次赋值 ,是的话直接赋值,不是的话比较两个small_root,更新最小值 if(max[seed]-min[seed]==maxf){ small_root=small_root==-1?seed:Math.min(small_root,seed); } } } }
import java.util.*; public class Main{ static int max=0,index=0;//最大差值,最大差值对应的最小编号 static int[][] dp;//dp[i][0/1/2]i为根对应的最大值、最小值,节点数量 static int k=0; public static void main(String[] argvs){ Scanner in=new Scanner(System.in); int n=in.nextInt(); k=in.nextInt(); int[] v=new int[n+1]; dp=new int[n+1][3]; for(int i=1;i<=n;i++) { v[i]=in.nextInt(); //dp[i][1]=Integer.MAX_VALUE; } Map<Integer,Set<Integer>> g=new HashMap<>(); for(int i=0;i<n-1;i++){ int x=in.nextInt(),y=in.nextInt(); Set<Integer> s1=g.computeIfAbsent(x,(a)->new HashSet<>()); Set<Integer> s2=g.computeIfAbsent(y,(a)->new HashSet<>()); s1.add(y); s2.add(x); } int r=in.nextInt(); dfs(r,g,v); System.out.print(index); } public static void dfs(int cur,Map<Integer,Set<Integer>> g,int[] v){ dp[cur][0]=v[cur]; dp[cur][1]=v[cur]; dp[cur][2]=1; if(!g.containsKey(cur) || g.get(cur).size()==0) return; Set<Integer> son=g.get(cur); for(int s:son){ Set<Integer> tmp=g.get(s); tmp.remove(cur); dfs(s,g,v); dp[cur][0]=Math.max(dp[cur][0],dp[s][0]); dp[cur][1]=Math.min(dp[cur][1],dp[s][1]); dp[cur][2]+=dp[s][2]; } int dis=dp[cur][0]-dp[cur][1]; if(dp[cur][2]<=k){ if(dis>max){ max=dis; index=cur; }else if(max==dis){ index=Math.min(index,cur); } } } }
import java.util.*; import java.io.*; public class Main{ static int poor ; //记录在递归过程中出现的符合要求最大差值 static int index ;//最大差值的下标 public static void main(String[] args) throws IOException{ BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); String[] one = br.readLine().trim().split(" "); int n = Integer.parseInt(one[0]); int k = Integer.parseInt(one[1]); String[] value = br.readLine().trim().split(" "); Node[] nodes = new Node[n]; //节点数组 //构造节点 for(int i = 0; i < n ; i++){ int num = Integer.parseInt(value[i]); nodes[i] = new Node(num, i + 1 ,new ArrayList<Node>()); } //把节点连接起来 for(int i = 0; i < n - 1; i++){ String[] line = br.readLine().trim().split(" "); int n1 = Integer.parseInt(line[0]); int n2 = Integer.parseInt(line[1]); nodes[n1 - 1].list.add(nodes[n2 - 1]); nodes[n2 - 1].list.add(nodes[n1 - 1]); } poor = 0; index = -1; Node node = nodes[Integer.parseInt(br.readLine()) - 1];//根节点 //遍历根节点 for(Node nn : node.list){ dfs(nn, node, k); } System.out.println(index); } public static Info dfs(Node node,Node father,int k){ int maxS = node.value; int minS = node.value; int flo = 1;//node节点中个数,flo = 1 是 node这个1 for(Node n : node.list){//遍历连接点 if(n != father){//父节点不遍历 Info info = dfs(n, node, k); flo += info.flo;//子树中个数相加 maxS = Math.max(maxS, info.max); minS = Math.min(minS, info.min); } } //如果节点个数小于等于k,并且最大值和最小值差比最大差大,那么替换 if(flo <= k && poor < maxS - minS){ poor = maxS - minS; index = node.index; } //如果节点个数小于等于k,并且最大值和最小值差等于最大差而且下标更小,那么替换下标 if( flo <= k && poor == maxS - minS && node.index < index){ index = node.index; } return new Info(maxS, minS, flo); } } //子树返回的信息 class Info{ int max;//子树中的最大值 int min;//子树中的最小值 int flo;//子树中节点个数(在确定根节点的情况下) public Info(int max, int min, int flo){ this.max = max; this.min = min; this.flo = flo; } } //节点类 class Node{ int value;//权值 int index;//节点的下标 List<Node> list;//连接的点 public Node(int value, int index, List<Node> list){ this.value = value; this.index = index; this.list = list; } }
#include <bits/stdc++.h> using namespace std; vector<int> arr; vector<vector<int>> G; vector<int> MAX; vector<int> MIN; int res = 10000000; int dif = 0; vector<bool> visit; int K; int dfs(int r){ int k = 1; visit[r] = true; MIN[r] = MAX[r] = arr[r]; for(int x : G[r]){ if(!visit[x]){ k += dfs(x); if(MIN[r] > MIN[x]) MIN[r] = MIN[x]; if(MAX[r] < MAX[x]) MAX[r] = MAX[x]; } } int d = MAX[r] - MIN[r]; if(k <= K){ if(dif == d){ if(res > r) res = r; } if(dif < d){ dif = d; res = r; } } visit[r] = false; return k; } int main(){ int N; cin >> N >> K; arr.resize(N); for(int i = 0; i < N; i++) cin >> arr[i]; int x, y; G.resize(N); visit.resize(N, false); MAX.resize(N); MIN.resize(N); for(int i = 0; i < N - 1; ++i){ cin >> x >> y; G[x - 1].push_back(y - 1); G[y - 1].push_back(x - 1); } int r; cin >> r; dfs(r - 1); cout << res + 1; }
import java.util.*; import java.io.*; public class Main { static int[] h; static int[] e; static int[] ne; static int[] w; static int n ; static int k ; static int idx = 0; static int node = -1; static int value = Integer.MIN_VALUE; static int[] min; static int[] max; static boolean[] vs; public static void add(int x, int y){ ne[idx] = h[x]; e[idx] = y; h[x] = idx++; } public static void main(String[] args) throws IOException { BufferedReader buf = new BufferedReader(new InputStreamReader(System.in)); String[] line = buf.readLine().split(" "); n = Integer.parseInt(line[0]); k = Integer.parseInt(line[1]); h = new int[n+10]; e= new int[2 * n+10]; ne = new int[2 * n+10]; min = new int[n+10]; max = new int[n+10]; vs = new boolean[n+10]; w = new int[n+10]; Arrays.fill(h,-1); line = buf.readLine().split(" "); for (int i = 1; i <= n; i++){ w[i] = Integer.parseInt(line[i-1]); min[i] = w[i]; max[i] = w[i]; } for (int i = 0; i < n - 1; i++){ line = buf.readLine().split(" "); int x = Integer.parseInt(line[0]); int y = Integer.parseInt(line[1]); add(x,y); add(y,x); } String r = buf.readLine(); int root =Integer.parseInt(r); getMinMax(root); System.out.println(node); } public static int getMinMax(int x){ vs[x] = true; int res = 1; int p = w[x]; int q = w[x]; int y = x; for (int i = h[x]; i != -1; i = ne[i]){ int child = e[i]; if (!vs[child]){ res +=getMinMax(child); p = Math.min(p,min[child]); q = Math.max(q,max[child]); } } if (res <= k && value < q - p){ node = y; value = q - p; }else if(res <= k && value == q - p){ if(node == -1 || y < node){ node = y; value = q - p; } } min[x] = p; max[x] = q; return res; } }