题解 | 矩阵乘法计算量估算

import sys
from typing import List

def get_m_size(C:str) -> List[int]:
    return matrix_sizes[ord(C) - ord('A')]

def matrix_multiply_times(size_lst:List[int]) -> List[int]:
    size = size_lst[0][0], size_lst[-1][-1]
    res = 0
    for i in range(1, len(size_lst)):
        res += size_lst[i-1][0] * size_lst[i][0] * size_lst[i][1]
    return size, res
   

raw_input = []
for i,line in enumerate(sys.stdin):
    raw_input.append(line.strip())

N = int(raw_input[0])
matrix_sizes = [[int(i) for i in row.split(' ')] for row in raw_input[1:1+N]]
eval_cmd = raw_input[1+N]
queue = []
res = 0
for c in eval_cmd:
    if c == '(':
        queue.append(c)
    elif c == ')':
        size_lst = []
        while queue[-1] != '(':
            size_lst.append(queue.pop())
            size, calc = matrix_multiply_times(size_lst[::-1])
            res += calc
        queue.pop()
        queue.append(size)
    else:
        queue.append(get_m_size(c))
print(res)

全部评论

相关推荐

2024-12-12 19:01
西安交通大学 Java
安克创新 java岗 年包30
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务