首页 > 试题广场 >

矩阵乘法计算量估算

[编程题]矩阵乘法计算量估算
  • 热度指数:77080 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 32M,其他语言64M
  • 算法知识视频讲解

矩阵乘法的运算量与矩阵乘法的顺序强相关。
例如:

A是一个50×10的矩阵,B是10×20的矩阵,C是20×5的矩阵

计算A*B*C有两种顺序:((AB)C)或者(A(BC)),前者需要计算15000次乘法,后者只需要3500次。

编写程序计算不同的计算顺序需要进行的乘法次数。

数据范围:矩阵个数:,行列数:保证给出的字符串表示的计算顺序唯一
进阶:时间复杂度:,空间复杂度:



输入描述:

输入多行,先输入要计算乘法的矩阵个数n,每个矩阵的行数,列数,总共2n的数,最后输入要计算的法则
计算的法则为一个字符串,仅由左右括号和大写字母('A'~'Z')组成,保证括号是匹配的且输入合法!



输出描述:

输出需要进行的乘法次数

示例1

输入

3
50 10
10 20
20 5
(A(BC))

输出

3500
#n接收矩阵个数
n = int(input())
#juzhen接收矩阵行列数据,里面的每一个元素是一个有2个元素的列表,表示每一个矩阵的行列数
juzhen = []
#循环接收输入的每一个矩阵的行列数据
for i in range(n):
    l = list(map(int,input().split()))
    juzhen.append(l)
#s接收计算的法则
s = input()
#res存储计算量的结果
res = 0
#tmp用于模拟栈
tmp = []
#从头开始往后遍历计算法则字符串
#遇到字母就将对应的矩阵入栈
#遇到)就将最后两个入栈的矩阵出栈用于计算
#计算完成后将得到的新矩阵入栈
for i in range(len(s)):
    #如果遇到字母
    if s[i].isalpha():
        #计算是哪个字母
        j = ord(s[i])-65
        #将对应字母所表示的矩阵行列数据放进栈中
        tmp.append(juzhen[j])
    #如果遇到),说明该进行计算了
    elif s[i] == ')':
        #将栈的最后放进去的两个数据,也就是栈尾的两个数据用于计算
        #两个矩阵的计算量存放在res中
        res += tmp[-2][0]*tmp[-2][1]*tmp[-1][1]
        #两个矩阵计算完成后得到一个新的矩阵
        #将新的矩阵替换之前的两个矩阵
        #也就是说,用于计算的两个矩阵数据要出栈
        #新的矩阵要入栈
        tmp[-2] = [tmp[-2][0],tmp[-1][1]]
        tmp.pop(-1)
print(res)

发表于 2022-06-15 21:21:14 回复(0)
while True:
    try:
        s=int(input())
        a=[]
        for i in range(s):
            a.append(list(map(int, input().split())))
        b=input()
        c=[]
        j=0
        count=0
        for i in b:
            if i.isalpha():
                c.append(a[j])
                j+=1
            elif i==')' and len(c)>=2:
                a2=c.pop()
                a1=c.pop()
                count += a1[0]*a1[1]*a2[1]
                c.append([a1[0],a2[1]])
        print(count)
    except:
        break
发表于 2021-05-08 22:08:48 回复(0)
案例有错误,左右括号不齐
while True:
    try:
        num=int(input())
        list1=[[0, 0] for i in range(num)]
        for i in range(num):
            list1[i]=list(map(int,input().split()))
        str1=input()
        list2=[]
        c=0
        for i in str1:
            if i.isalpha():
                list2.append(i)
            if i==')' and len(list2)>=2:
                k=len(list2)
                w=list1[k-2][0]*list1[k-2][1]*list1[k-1][1]
                list1[k-2][1]=list1[k-1][1]
                list2.pop()
                del list1[len(list2)]
                c+=w
        print(c)
    except:
        break

发表于 2021-03-25 21:07:15 回复(0)
自己想了一下
还是参考了答案,比较巧妙,原来列表可以直接谈栈的,保存数据自己写驾轻就熟
关键在谈栈那块代码,代码-65也很巧妙,表示第几个数

if __name__ == '__main__':
    
    while True:
        try:
            n = input()
            order = []
            res = 0
            arr = [[0 for i in xrange(2)] for i in xrange(n)]
            for i in xrange(n):
                a_b = raw_input()
                a = int(a_b.split(' ')[0])
                b = int(a_b.split(' ')[1])
                arr[i][0]  = a
                arr[i][1]  = b
            s = raw_input()
            for i in s:
                if i.isalpha():
                    order.append(arr[ord(i)-65])
                elif i==')'  and len(order) >=2:
                    a=order.pop()
                    b = order.pop()
#                     print b
                    res+=b[0]*b[1]*a[1]
                    order.append([b[0],a[1]])
            print res
        except:
            break


发表于 2021-03-13 22:57:25 回复(0)
不要相信题目中说的“保证括号合法”这句话,有一个测试用例多了一个右括号
from collections import deque


while True:
    try:
        n = int(input())
        ms = []
        for i in range(n):
            ms.append(list(map(int, input().split())))
        stack, count = deque(), 0
        for c in input():
            if c.isalpha():
                stack.append(ms.pop(0))
            elif c == ")":
                if len(stack) < 2:
                    continue
                b, a = stack.pop(), stack.pop()
                count += a[0]*a[1]*b[1]
                stack.append([a[0], b[1]])
        print(count)
    except:
        break


发表于 2020-12-18 15:03:17 回复(0)
import re
letterlist=['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']

while 1:
    try:
        a=int(input())
        dimmat=[]
        for each in range(a):
            dimmat.append(list(map(int, input().split())))
        rule=input()


        reg=re.compile('\\(\w{2}\\)')
        prosum=0
        while len(rule)>2:
            curr=reg.findall(rule)
            for item in curr:
                currstr=item
                currL1=item[1]
                ind=letterlist.index(currL1)
                val1=dimmat[ind][0]
                val2=dimmat[ind][1]
                val3=dimmat[ind+1][1]
                prosum=prosum+val1*val2*val3
                newmatdim=[val1,val3]
                dimmat[ind]=newmatdim
                del dimmat[ind+1]
                del letterlist[ind+1]
                rule=rule.replace(currstr,currL1)
        print(prosum)
    except:
        break
        
有没有朋友帮我看看这个哪里错了?在本地都可以通过 谢谢了🙏
发表于 2020-12-16 16:05:02 回复(0)
from functools import reduce

def cal(a,b):
    """
    计算两个行列式,并更新计算数
    并将更新后的行列式更新到记录,供返回后使用
    """
    global count
    global record
    global max_idx
    ar1 = record[a]
    ar2 = record[b]
    count += (ar1[0]*ar1[1]*ar2[1])
    max_idx += 1
    #需要单个字符的标记
    idx = chr(ord("A")+max_idx)
    record[idx] = (ar1[0],ar2[1])
    return idx
    return idx

def cal_count(s):
    #返回行列式的索引
    if "(" not in s:
        idx = reduce(cal, s)
        return idx
    else:
        left = s.index("(")
        right = s.rindex(")")
        tem = cal_count(s[left+1:right])
        return cal_count(s[:left]+tem+s[right+1:])

if __name__ == '__main__':
    while True:
        try:
            count = 0
            n = int(input().strip())
            record = {}
            max_idx = 0
            for i in range(n):
                max_idx += 1
                record[chr(ord('A')+i)] = tuple(map(int, input().strip().split(" ")))
            cal_count(input().strip())
            print(count)
        except:
            break
有一个用例左右括号根本不一样吧
发表于 2020-12-14 00:02:38 回复(0)
# (A((BC)(DE)))
# 涉及()运算,可以用栈的方式,非)入栈,)出栈
def matrix_multiply(data,rule):
    stack,cal=[],0
    for i in range(len(rule)):
        if rule[i] != ')':
            stack.append(rule[i]) # 建立栈存放运算规则rule
        else:
            a = stack.pop() # 右扩号前1
            b = stack.pop() # 右扩号前2
            stack.pop() # 删掉左扩号
            cal += data[b][0]*data[b][1]*data[a][1] # (m,n)*(n,o) = m*n*o,计算算量
            if stack:
                # 如果stack不为空,说明还未算完,b*a添加至stack
                data[str(i)] = [data[b][0],data[a][1]]
                stack.append(str(i))
            else:
                return cal
while True:
    try:
        n = int(input())
        matrix = {}
发表于 2020-12-06 00:26:29 回复(0)
测试用例似乎全部是错误的,错误点在于给的计算公式全部都多了一个')',在程序中强行删除掉这个多余的括号后,运行AC通过。
发表于 2020-11-09 20:18:29 回复(0)
def get_sum(s, m, r):
    if s[0] == '(' and s[3] == ')':  # 推出条件 (AZ)
        return r + m[0][0] * m[0][1] * m[1][1]  # [[1,2],[2,3]] 1*2*3
    for index, item in enumerate(s):
        if s[index] == '(' and s[index + 3] == ')':  # 题目隐含条件,一对最近的括号中有且仅有两个字母 eg: (A(BC))
            mi = index // 2  # 找到'('对应矩阵的matrix index
            s1 = s[:index] + 'Z' + s[index + 4:]  # (A(BC)) => (AZ)
            m1 = m[:mi] + [[m[mi][0], m[mi + 1][1]]]  # [[1,2],[2,3],[3,4]] => [[1,2],[2,4]]
            r1 = r + m[mi][0] * m[mi][1] * m[mi + 1][1]  # r + 2*3*4
            return get_sum(s1, m1, r1)


while True:
    try:
        n = int(input())
        m = [[int(_) for _ in input().split()] for _ in range(n)]  # 生成记录矩阵 [[1,2],[2,3]]
        s = input()  # 计算法则 (A(BC))
        print(get_sum(s, m, 0))
    except:
        break


编辑于 2020-08-31 16:16:06 回复(0)
计算器翻版,有一个数字栈和一个符号栈,这里是遇到相邻两个字母直接计算
计算时从数字栈里弹出两个size,从符号栈里弹出两个符号,再把算好的size加进数字栈,把一个新的符号加进符号栈
遇左括号则向符号栈加左括号
遇右括号则看栈顶然后计算,注意计算过程中有A(B)的情况,这个时候直接从符号栈里弹掉左括号和B,再把B加进来,算一下
本方法兼容不加括号的连乘,如(A(BCD)E(FG)H)
另外看评论区有测试用例右括号多于左括号的情况(测试的输入不正确),难怪提交了第一次返回非零
def matrix_mul_times(_sizes, _seq):
    var_stack = []
    number_stack = []
    number_times = 0
    
    def _update():
        local_times = 0
        while len(var_stack) >= 2 and ord('A') <= ord(var_stack[-1][0]) and ord('A') <= ord(var_stack[-2][0]):
            current_var = var_stack.pop()
            last_var = var_stack.pop()
            current_matrix_size = number_stack.pop()
            last_matrix_size = number_stack.pop()
            number_stack.append((last_matrix_size[0], current_matrix_size[1]))
            if last_matrix_size[1] != current_matrix_size[0]:
                raise ValueError('Cannot multiply {} * {} and {} * {}'.format(
                    last_matrix_size[0], last_matrix_size[1], current_matrix_size[0], current_matrix_size[1]))
            local_times += last_matrix_size[0] * last_matrix_size[1] * current_matrix_size[1]
            var_stack.append(last_var + current_var)
        return local_times
    
    for _c in _seq:
        if ord('A') <= ord(_c):
            current_matrix_size = _sizes[ord(_c) - ord('A')]
            number_stack.append(current_matrix_size)
            var_stack.append(_c)
        if _c == '(':
            var_stack.append(_c)
        if _c == ')':
            if var_stack:
                if var_stack[-1] == '(':
                    var_stack.pop()
                elif ord('A') <= ord(var_stack[-1][0]):
                    number_times += _update()
                    last_var = var_stack.pop()
                    if var_stack:  # where an incorrect input triggered an exception
                        _left_par = var_stack.pop()
                        if _left_par != '(':
                            var_stack.append(_left_par)
                    var_stack.append(last_var)
                    number_times += _update()
            
    return number_times


发表于 2020-08-10 17:33:52 回复(0)
import re

while True:
    try:
        n = int(input())
        jz = [[int(i) for i in input().split()] for j in range(n)]
        sx = input()
        m = {}
        stack = []
        new = []
        count = 0

        s = re.findall(r'\w', sx)
        for k, value in enumerate(s):
            m[value] = jz[k]

        for l in sx:
            if l == ')':
                if len(stack) < 3:
                    continue
                new.append(stack.pop())
                new.append(stack.pop())
                stack.pop()
                count += m[new[0]][1] * m[new[1]][0] * m[new[1]][1]
                n = new[1] + new[0]
                rc = [m[new[1]][0], m[new[0]][1]]
                m[n] = rc
                stack.append(n)
                new.clear()
                continue
            stack.append(l)

        print(count)



    except:
        break




发表于 2020-08-10 16:25:56 回复(0)
# 测试用例右括号会多一个,如(A(B(C(D(E(F(GH))))))))
# python 栈 解法分享
while True:
    try:
        # (A(BC))
        N = int(input())
        lstMat = []
        for _ in range(N): # 输入N个矩阵
            lstMat.append(list(map(int, input().split())))
        law = list(input()) # 计算规则 ['(', 'A', '(', 'B', 'C', ')', ')']
        for i in law:
            if i.isalpha(): # 把law中的ABC字符替换为实际的ABC列表
                law[law.index(i)] = lstMat.pop(0)
        lst = [] # 栈
        cal = [] # 用于计算的临时列表
        times = 0 # 总的计算次数

        for i in law:
            if i == ')': # 遇到右括号时,进行计算
                if not '(' in lst: # 如果栈中无左括号(这一步if主要是因为测试用例中多了一个**右括号,要不就不用加了)
                    continue
                # 此时 lst = ['(', 'A', '(', 'B', 'C']
                # 出栈操作,直到遇见左括号 (
                while True:
                    x = lst.pop() # 出栈
                    if x == '(':
                        break
                    else:
                        cal.insert(0, x) # 给计算列表添加元素,考虑到计算顺序和出栈顺序(计算顺序的逆序),需要insert(0)
#                 print(cal) 中间的几个打印是在调试时找问题用的~
                # 遇到左括号结束循环,此时 lst = ['(', 'A'], cal = [A, B]
                # 计算弹出内容需要计算的次数
                times += cal[0][0] * cal[0][1] * cal[1][1] # 矩阵1的行数 * 矩阵1的列数 * 矩阵2的列数
#                 print(times)
                # 相乘得到新的矩阵
                Matrix = [cal[0][0], cal[1][1]]
#                 print(Matrix)
                # 把新矩阵入栈
                lst.append(Matrix)
                # 清空计算列表
                cal = []
            else:
                lst.append(i) # 入栈
        print(times)
    except:
        break

发表于 2020-07-16 18:01:51 回复(1)
#看来看去python的答案都一样的,自己写一波
#遇到')'弹栈
def cal_val(A,B):
	if A[1]==B[0]:
		return A[0]*B[1]*A[1]
	else:
		return -1

def cal(A,B):
	if A[1]==B[0]:
		return [A[0],B[1]]
	else:
		return -1

while True:
	try:
		num=int(input())
		array=[]
		for i in range(num):
			matrix=list(map(int,input().split()))
			array.append(matrix)
		array=array[::-1]
		express=str(input().strip())
		stack=[]
		cal_value=0
		for i in range(len(express)):
			if express[i].isalpha():
				stack.append(array.pop())
			elif express[i]==')':
				if len(stack)>1:
					A=stack.pop()
					B=stack.pop()
					print('A,B:',A,B)
					cal_value+=cal_val(B,A)
					stack.append(cal(B,A))
		stack=stack[::-1]
                while len(stack)>1:
			A=stack.pop()
			B=stack.pop()
			cal_value+=cal_val(A,B)
			stack.append(cal(A,B))
		print(cal_value)
	except:
		break

编辑于 2020-06-07 16:25:32 回复(2)
测试用例我也是醉了,多了一个右括号,正常的话最后应该是s仅剩一个字母,所以判断条件为while len(s) !=1,然而只能改为!=2了。。
l = list(map(lambda x:chr(x),range(ord('A'),ord('Z')+1)))
while True:
    try:
        N = eval(input())
        ls = []
        for i in range(N):
            ls.append(list(map(eval,input().split())))
        s = input()
        num = 0
        while len(s) != 2:
            lst = []
            index_ls = -1
            index_s = -1
            for k in s:
                if k == '(':
                    index_s += 1
                    lst.clear()
                elif k in l:
                    index_s += 1
                    index_ls += 1
                    lst.append(k)
                elif k == ')':
                    break
            a = ls[index_ls-1][0]
            b = ls[index_ls][1]
            num += ls[index_ls-1][0]*ls[index_ls-1][1]*ls[index_ls][1]
            ls.insert(index_ls-1,[a,b])
            ls.remove(ls[index_ls])
            ls.remove(ls[index_ls])
            s = s[:index_s-2]+lst[-1]+s[index_s+2:]
        print(num)
    except:
        break

编辑于 2020-05-25 12:04:44 回复(0)
看见很多人说测试用例多了个括号,这个通过判定遇见右括号的时候栈内元素数可以避免溢出问题,大于等于2才弹栈。
利用list实现栈的操作,遇见字符压栈,遇见')'判断栈内元素个数,大于等于2弹栈两次,分别赋值jz2和jz1,计算jz1*jz2要做的乘法次数。计算完乘法次数后,将jz1*jz2的矩阵的行数和列数组成的list压栈,再循环就好了。
def cal_mul(m_list, law):
    stack = []
    jz1 = ''
    jz2 = ''
    cal_all = 0
    for i in law:
        if (i != '(') and (i != ')'):
            stack.append(i)
        elif len(stack) >= 2 and i == ')':
            jz2 = stack.pop(-1)
            jz1 = stack.pop(-1)
            if (isinstance(jz2, list)) and (isinstance(jz1, list)):
                cal_all += jz1[1] * jz2[1] * jz1[0]
                stack.append([jz1[0], jz2[1]])
            elif (isinstance(jz2, list)) and (isinstance(jz1, str)):
                cal_all += m_list[ord(jz1) - ord('A')][1] * jz2[1] * m_list[ord(jz1) - ord('A')][0]
                stack.append([m_list[ord(jz1) - ord('A')][0], jz2[1]])
            elif (isinstance(jz2, str)) and (isinstance(jz1, list)):
                cal_all += jz1[1] * m_list[ord(jz2) - ord('A')][1] * jz1[0]
                stack.append([jz1[0], m_list[ord(jz2) - ord('A')][1]])
            elif (isinstance(jz2, str)) and (isinstance(jz1, str)):
                cal_all += m_list[ord(jz1) - ord('A')][1] * m_list[ord(jz2) - ord('A')][1] * m_list[ord(jz1) - ord('A')][0]
                stack.append([m_list[ord(jz1) - ord('A')][0], m_list[ord(jz2) - ord('A')][1]])
        elif len(stack) < 2 and i == ')':
            return cal_all
    return cal_all


while True:
    try:
        n = int(input())
        m_list = []
        for i in range(n):
            m_list.append([int(x) for x in input().split()])
        law = input()
        print(cal_mul(m_list, law))
    except:
        break


发表于 2020-02-25 21:50:55 回复(0)
while True:
    try:
        n=int(input())
        dic,charA={},65
        for _ in range(n):
            dic[chr(charA)]=list(map(int,input().split()))#将输入的矩阵大小以ABCD...为键存入字典
            charA+=1
        rule=input()
        if rule.count('(')!=rule.count(')'):(738)###有个测试集多了一个右括号
            rule=rule[:-1]
        stack,count=[],0
        for i in rule:
            if i!=')':
                stack.append(i)
            else:(739)#若匹配到右括号
                a=stack.pop()
                b=stack.pop()
                dic['d']=[dic[b][0],dic[a][1]]#将新计算的矩阵形状放入字典中
                stack.pop()(740)#弹出左括号
                if stack:#若栈中没有矩阵,即计算到最后一步,则不添加新的矩阵,否则添加
                    stack.append('d')
                count+=dic[b][0]*dic[b][1]*dic[a][1](741)#两个矩阵(m,n),(n,q)的计算量为m*n*q
        print(count)
    except:
        break

发表于 2020-02-20 21:58:16 回复(1)
#将括号内整体替代为‘z'
def new(x,y): 
    x1=di[x]
    y1=di[y]
    z='z'
    di[z]=[x1[0],y1[1]]
    return (z)
#计算两个矩阵的次数和
def sum(x,y):
    x1=di[x]
    y1=di[y]
    return x1[0]*y1[1]*x1[1]
#更改不正确的命令,遇见’)‘ 溯回找(,在括号范围内整体计算并替换,用替换后的再次运行,直到只剩下’z'
def al(k,count):
    if k.count('(')<k.count(')'):
        k=k.replace(')','',1)
    if len(k)==1:
        return count
    for i in range(len(k)):
        if k[i]==')':
            for j in range(i,-1,-1):
                if k[j]=='(':
                    count+=sum(k[j+1],k[j+2])
                    k=k.replace(k[j:i+1],str(new(k[j+1],k[j+2])))
                    return al(k,count)
while True:
    try:
        num=int(input())
        record=[]
        di={}
        for i in range(num):
            box=input()
            box=list(map(int,box.split()))
            record.append(box)
        cu=input()
#将矩阵对应大写字母
        for i in range(len(record)):
            di[chr(65+i)]=record[i]
        print(al(cu,0))
    except:
        break
发表于 2020-02-09 18:14:53 回复(0)
# coding:utf-8

import sys

def parse(s):
    r = []
    a = []
    for i, ch in enumerate(s):
        if ch.isalpha():
            a.append([i, ch])
        if ch == ')':
            while len(a) > 0:
                ch = a.pop()
                r.append(ch)
    return r

try:
    while True:
        line = sys.stdin.readline().strip()
        if line == '':
            break
        n=int(line)
        a=[]
        for i in range(n):
            line = sys.stdin.readline().strip()
            line=line.split()
            line=[int(ch) for ch in line]
            a.append(line)
        line = sys.stdin.readline().strip()
        r=parse(line)
        n=0
        item1=r.pop(0)
        item2=r.pop(0)
        if item1[0]>item2[0]:
            item1,item2=item2,item1
        m1=a[ord(item1[1])-ord('A')]
        m2=a[ord(item2[1])-ord('A')]
        n+=m1[0]*m1[1]*m2[1]
        tmp=[item1[0],[m1[0],m2[1]]]
        while len(r)>0:
            item=r.pop(0)
            if item[0]>tmp[0]:
                m1=tmp[1]
                m2=a[ord(item[1])-ord('A')]
            else:
                m1=a[ord(item[1])-ord('A')]
                m2=tmp[1]
            n+=m1[0]*m1[1]*m2[1]
            tmp=[item1[0],[m1[0],m2[1]]]
        print(n)
except Exception as e:
    print(e)

编辑于 2019-08-12 13:23:19 回复(0)
while True:
    try:
        n=int(input().strip())
        num_list=[]
        for i in range(n):
            num_list.append(list(map(int,input().strip().split(' '))))
        #print(num_list)
        faze=str(input().strip())
        #print(faze)
        def data(faze,num_list):
            result=0
            for i in range(len(faze)):
                if faze[i]=='(' and faze[i+3]==')':
                    result=num_list[i//2][0]*num_list[i//2][1]*num_list[i//2+1][1]
                    num_list=num_list[:i//2]+[[num_list[i//2][0],num_list[i//2+1][1]]]+num_list[i//2+1:]
                    #print(num_list[:i//2])
                    faze=faze[:i]+'v'+faze[i+4:]
                    return result,num_list,faze
        num=0
        for i in range(n-1):
            result,num_list,faze=data(faze,num_list)
            num+=result
        print(num)
    except:
        break

参考大神的思路

发表于 2019-07-12 00:54:20 回复(0)