pypto-case-matvec

Stars: 254

Matrix-vector multiplication: when K > 65535, use elementwise mul + sum instead of matmul

By mindspore-ai. Updated 3/2/2026

name: pypto-case-matvec
description: "Matrix-vector multiplication: when K > 65535, use elementwise mul + sum instead of matmul"
category: case
version: "1.0.0"
metadata:
  backend: ascend
  dsl: pypto
  operator_patterns: "matrix_vector,matvec,large_k"

Matrix-Vector Multiplication (K > 65535)

A: (256, 131072), B: (131072, 1) -> C: (256, 1)

K = 131072 exceeds the pypto.matmul limit (last dimension <= 65535), so the kernel computes sum(a * b_row, dim=1) instead.
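The substitution works because each output element of a matrix-vector product is a dot product of one row of A with the vector. The following is a minimal sketch in plain Python (no pypto, no torch) that checks this on a tiny example. The helper names `needs_sum_fallback`, `matvec_matmul`, and `matvec_mul_sum` are hypothetical and are not part of pypto; the constant is simply the limit quoted above.

```python
# Limit quoted in this document for the last dimension of pypto.matmul.
MATMUL_MAX_K = 65535


def needs_sum_fallback(k):
    """Return True when K is too large for the matmul path described above."""
    return k > MATMUL_MAX_K


def matvec_matmul(a, b_col):
    """Reference matmul: a is (M, K), b_col is (K, 1), result is (M, 1)."""
    k_dim = len(b_col)
    return [[sum(row[k] * b_col[k][0] for k in range(k_dim))] for row in a]


def matvec_mul_sum(a, b_row):
    """Multiply every row of a by the (1, K) row, then sum over dim=1 (keepdim)."""
    products = [[x * w for x, w in zip(row, b_row[0])] for row in a]
    return [[sum(row)] for row in products]


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # (M, K) = (2, 3)
b_col = [[1.0], [0.5], [2.0]]            # (K, 1)
b_row = [[r[0] for r in b_col]]          # (K, 1) -> (1, K)

# Both formulations give the same (M, 1) result.
assert matvec_matmul(a, b_col) == matvec_mul_sum(a, b_row)
```

For the shapes in this case, `needs_sum_fallback(131072)` is true, which is why the kernel takes the multiply-and-sum path.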

import pypto
import torch

def create_matvec_sum_kernel(a_shape, b_shape):
    # One dot product per row of A, so the output is (M, 1).
    out_shape = (a_shape[0], 1)

    @pypto.frontend.jit(...)
    def matvec_sum_kernel(
            a: pypto.Tensor(a_shape, pypto.DT_FP32),
            b_row: pypto.Tensor(b_shape, pypto.DT_FP32),
    ) -> pypto.Tensor(out_shape, pypto.DT_FP32):
        output = pypto.tensor(list(out_shape), pypto.DT_FP32)
        pypto.set_vec_tile_shapes(1, 8192)
        output[:] = pypto.sum(a * b_row, dim=1, keepdim=True)
        return output
    return matvec_sum_kernel

class ModelNew(torch.nn.Module):
    def forward(self, A, B):
        assert A.dim() == 2
        assert tuple(A.shape) == (256, 131072)
        assert B.dim() == 2
        assert tuple(B.shape) == (131072, 1)
        A = A.contiguous()
        # B: (K, 1) -> (1, K) for the broadcast multiply
        B_row = B.contiguous().reshape(1, -1)
        return create_matvec_sum_kernel(tuple(A.shape), tuple(B_row.shape))(A, B_row)

Key point: in forward, B.reshape(1, -1) turns the column vector into a row vector so that a * b_row can broadcast.
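To see why the reshape matters, the sketch below works out the result shape of an elementwise multiply, assuming pypto follows the NumPy/PyTorch broadcasting convention (that is an assumption here, not something the source states). `broadcast_shape` is a hypothetical helper written for illustration only.

```python
def broadcast_shape(lhs, rhs):
    """Result shape of an elementwise op under NumPy-style broadcasting."""
    # Left-pad the shorter shape with 1s so both have the same rank.
    rank = max(len(lhs), len(rhs))
    lhs = (1,) * (rank - len(lhs)) + tuple(lhs)
    rhs = (1,) * (rank - len(rhs)) + tuple(rhs)

    out = []
    for l, r in zip(lhs, rhs):
        if l == r or r == 1:
            out.append(l)
        elif l == 1:
            out.append(r)
        else:
            raise ValueError(f"shapes {lhs} and {rhs} do not broadcast")
    return tuple(out)


# Row vector (1, K): broadcasts across all 256 rows of A.
assert broadcast_shape((256, 131072), (1, 131072)) == (256, 131072)

# Column vector (K, 1): the leading dimensions 256 and 131072 clash.
try:
    broadcast_shape((256, 131072), (131072, 1))
    column_broadcasts = True
except ValueError:
    column_broadcasts = False
assert column_broadcasts is False
```

With the (1, K) layout the product keeps the shape of A, and summing over dim=1 with keepdim=True then yields the (256, 1) output.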

Install via CLI
npx skills add https://github.com/mindspore-ai/akg --skill pypto-case-matvec
Repository Details
Forks: 48
Branch: main
Path: SKILL.md