题解 | #矩阵乘法计算量估算#二叉树解决

矩阵乘法计算量估算

https://www.nowcoder.com/practice/15e41630514445719a942e004edc0a5b

#二叉树那个没整形数据,我整形了下可以用了
from string import ascii_uppercase

input_seq = [
    '3', '50 10', '10 20', '20 5', '((AB)C)',
    '3', '50 10', '10 20', '20 5', '(A(BC)',
]


class BTNode:
    def __init__(self, v):
        self._data = v
        self._left = None
        self._right = None

    def insert_left(self, v):
        if self._left is None:
            self._left = BTNode(v)
        else:
            t = BTNode(v)
            t._left = self._left
            self._left = t

    def insert_right(self, v):
        if self._right is None:
            self._right = BTNode(v)
        else:
            t = BTNode(v)
            t._right = self._right
            self._right = t

    def preorder(self):
        print(self._data)
        if self._left:
            self._left.preorder()
        if self._right:
            self._right.preorder()

    def inorder(self):
        if self._left:
            self._left.inorder()
        print(self._data)
        if self._right:
            self._right.inorder()

    def postorder(self):
        if self._left:
            self._left.postorder()
        if self._right:
            self._right.postorder()
        print(self._data)

    @property
    def left(self):
        return self._left

    @property
    def right(self):
        return self._right

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, v):
        self._data = v


# 通过tree的全括号表达式解析算法解析计算顺序
def build_parse_tree(expression):
    exp_list = list(expression)
    parse_stack = []
    root_node = BTNode('')
    parse_stack.append(root_node)
    current_node = root_node

    i = 0
    while i < len(exp_list):
        token = exp_list[i]
        if token == '(':
            current_node.insert_left('')
            parse_stack.append(current_node)
            current_node = current_node.left
        elif token == ')':
            current_node = parse_stack.pop()
            try:
                if exp_list[i+1] != ')':
                    current_node.insert_right('')
                    parse_stack.append(current_node)
                    current_node = current_node.right
            except IndexError:
                pass
        else:
            current_node.data = token
            current_node = parse_stack.pop()
            # current_node.data = token
            try:
                if exp_list[i+1] != ')':
                    current_node.insert_right('')
                    parse_stack.append(current_node)
                    current_node = current_node.right
            except IndexError:
                pass
        i += 1
    return root_node


def evaluate(node, matrixes, total):
    left = node.left
    right = node.right

    if left and right:
        left_arr = evaluate(left, matrixes, total)
        right_arr = evaluate(right, matrixes, total)
        total[0] += left_arr[0] * right_arr[0] * right_arr[1]
        return [left_arr[0], right_arr[1]]
    else:
        return matrixes[node.data]


def restore_exp(node):
    exp = ''
    if node.left and node.right:
        exp = '(' + restore_exp(node.left)
        exp += str(node.data)
        exp += restore_exp(node.right) + ')'
    else:
        exp += str(node.data)

    return exp


def execute(matrixes, expression):
    node = build_parse_tree(expression)
    # node.preorder()
    # print('---------------------------')
    # node.inorder()
    # print('---------------------------')
    # node.postorder()
    # print('---------------------------')
#     print(restore_exp(node))
    total = [0]
    evaluate(node, matrixes, total)
    print(total[0])


def matrix_multiplication_computation(seq):
    i = 0
    while i < len(seq):
        n = int(seq[i])
        matrixes = {}
        count = 0
        for j in range(i+1, i+1+n):
            matrixes[ascii_uppercase[count]] = [int(n) for n in seq[j].split(' ')]
            count += 1
        expression = seq[i+1+n]
        execute(matrixes, expression)
        i += 2+n

n = int(input())
arr = []
arr.append(str(n))
for i in range(n+1):
    arr.append(input())
n1= matrix_multiplication_computation(arr)

全部评论

相关推荐

头像
11-26 15:46
已编辑
中南大学 后端
字节国际 电商后端 24k-35k
点赞 评论 收藏
分享
评论
1
收藏
分享
牛客网
牛客企业服务