题解 | #矩阵乘法计算量估算#
矩阵乘法计算量估算
https://www.nowcoder.com/practice/15e41630514445719a942e004edc0a5b
栈的解法
思路看注释就够了。
这里只说一下新建的类,
矩阵的行列---MatrixInfo,包括行、列,
计算量---Amount,包括矩阵、计算量,用于矩阵计算的返回值,
栈的元素类型---Node:
由于只用1个栈,感觉1个栈用来判断括号更清晰一些。栈中要存矩阵、'('两种类型,所以建了一个同一类型Node,矩阵和'('只会有1个不为null。
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.*; /** * HJ70 矩阵乘法计算量估算 */ public class Main { /** * 借助栈 * 过程 * i从左到右遍历字符串 * 1. 如果a[i]==字母 * (1)如果栈顶不为矩阵,入栈 * (2)如果栈顶为矩阵,栈顶出栈,和a[i]计算,累加数量,结果矩阵入栈 * 2. 如果a[i]=='(' * 入栈 * 由于第1条,所以如果有一连串矩阵,随着i的移动,栈中'('以上只会有1个矩阵,栈类似这样: * | 矩阵 | * | ( | * | 矩阵 | * | ( | * | ..... | * | ..... | * |_______| * 所以对于')'的规则是: * 3. 如果a[i]==')' * 栈顶矩阵1出栈,'('出栈, * (1)如果此时栈顶是矩阵2,那么矩阵2和矩阵1计算,累加数量,结果矩阵入栈 * (2)如果此时栈顶是'(',那么矩阵1入栈 * 栈中只会存矩阵、'('两种类型,不会存')' * @param matrixInfos * @param orderStr * @return */ private static int calculateAmount(List<MatrixInfo> matrixInfos, String orderStr) { // 字母和矩阵对应关系 Map<Character, MatrixInfo> matrixInfoMap = new HashMap<>(matrixInfos.size()); String letterStr = orderStr.replaceAll("[()]", ""); char[] chars = letterStr.toCharArray(); // 按照字母顺序排序 Arrays.sort(chars); for (int i = 0; i < chars.length; i++) { matrixInfoMap.put(chars[i], matrixInfos.get(i)); } // 栈 Deque<Node> deque = new ArrayDeque<>(); char[] orderChars = orderStr.toCharArray(); // 累计计算量 int total = 0; for (char c : orderChars) { // 如果是字母,判断栈顶是否为矩阵,如果是,运算,运算结果放入栈顶 if (c >= 'A' && c <= 'Z') { Node pNode = deque.peekLast(); // 如果栈为空,或者栈顶为'(',入栈 if (pNode == null || (pNode.bracket != null && pNode.bracket == '(')) { MatrixInfo matrixInfo = matrixInfoMap.get(c); Node node = new Node(matrixInfo); deque.addLast(node); } // 栈不为空,且栈顶为矩阵,栈顶出栈,和a[i]计算,入栈 else { Node node = deque.pollLast(); MatrixInfo cInfo = matrixInfoMap.get(c); Amount r = calculateAmount(node.matrixInfo, cInfo); // 结果累加 total = total + r.calAmount; // 创建新的节点,入栈 Node tNode = new Node(r.matrixInfo); deque.addLast(tNode); } } // 如果是'(',入栈 else if (c == '(') { deque.addLast(new Node('(')); } // 如果是')' else { // 栈顶矩阵1出栈,'('出栈 Node matrixNode = deque.pollLast(); Node bracketNode = deque.pollLast(); // (1)如果此时栈顶是矩阵2,那么矩阵2出栈,和矩阵1计算,结果矩阵入栈 if (!deque.isEmpty() && deque.peekLast().matrixInfo != null) { Node node = deque.pollLast(); Amount amount = calculateAmount(node.matrixInfo, matrixNode.matrixInfo); // 累加结果 total = total + amount.calAmount; // 创建新节点,入栈 deque.addLast(new Node(amount.matrixInfo)); } // (2)如果此时栈为空(实际不会有这种情况)或者栈顶是'(',那么矩阵1入栈 else { deque.addLast(matrixNode); } } } // i走完 return total; } /** * 2个矩阵的计算量,a*b * @param a * @param b * @return */ private static Amount calculateAmount(MatrixInfo a, MatrixInfo b) { // 结果矩阵行列数 MatrixInfo r = new MatrixInfo(a.x, b.y); // 计算量。假设a是x行y列,b是y行z列,结果是x行z列 // 结果是x行z列,那么一共x*z个元素 // 计算每个元素需要的乘法数量:a的一行*b的一列 --- y个数和y个数的乘积和 --- y次乘法 // 所以乘法数量=y*(x*z) int amount = a.y * (a.x * b.y); return new Amount(r, amount); } /** * 计算量 */ static class Amount { MatrixInfo matrixInfo; // 计算量 int calAmount = 0; public Amount(MatrixInfo matrixInfo) { this.matrixInfo = matrixInfo; } public Amount(MatrixInfo matrixInfo, int calAmount) { this.matrixInfo = matrixInfo; this.calAmount = calAmount; } } /** * 栈中存储的节点 * 如果是矩阵,那么matrixInfo!=null,bracket==null * 如果是括号,那么bracket!=null,matrixInfo==null */ static class Node { // 矩阵 MatrixInfo matrixInfo; // 括号 Character bracket; public Node(MatrixInfo matrixInfo) { this.matrixInfo = matrixInfo; } public Node(Character bracket) { this.bracket = bracket; } } static class MatrixInfo { // 矩阵行数 int x; // 矩阵列数 int y; public MatrixInfo(int x, int y) { this.x = x; this.y = y; } } public static void main(String[] args) { // List<MatrixInfo> matrixInfos = new ArrayList<>(); // matrixInfos.add(new MatrixInfo(50, 10)); // matrixInfos.add(new MatrixInfo(10, 20)); // matrixInfos.add(new MatrixInfo(20, 5)); // String orderStr = "(A(BC))"; // int r = calculateAmount(matrixInfos, orderStr); // System.out.println(r); // 预期3500 // List<MatrixInfo> matrixInfos = new ArrayList<>(); // matrixInfos.add(new MatrixInfo(8, 6)); // matrixInfos.add(new MatrixInfo(6, 14)); // String orderStr = "(AB)"; // int r = calculateAmount(matrixInfos, orderStr); // System.out.println(r); // 预期672 try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in))) { String ns = bufferedReader.readLine(); int n = Integer.parseInt(ns); List<MatrixInfo> list = new ArrayList<>(); for (int i = 0; i < n; i++) { String matrixS = bufferedReader.readLine(); String[] a = matrixS.split(" "); list.add(new MatrixInfo(Integer.parseInt(a[0]), Integer.parseInt(a[1]))); } String orderStr = bufferedReader.readLine(); int result = calculateAmount(list, orderStr); System.out.println(result); } catch (IOException e) { throw new RuntimeException(e); } } }