Grouped GEMM

A lightweight library exposing grouped GEMM kernels in PyTorch.

Installation

Run pip install grouped_gemm to install the package.

Compiling from source

By default, the installed package runs in conservative (cuBLAS) mode: it launches one GEMM kernel per batch element instead of using a single grouped GEMM kernel for the whole batch.
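The pure-Python sketch below shows what a grouped GEMM computes and how the conservative strategy issues one GEMM per group. The data layout is an assumption made for illustration: the rows of a are split into consecutive groups by batch_sizes, and each group is multiplied by its own k x n weight matrix, as in mixture-of-experts layers. matmul and grouped_gemm_reference are hypothetical helpers, not part of the grouped_gemm API.

```python
# Hypothetical reference semantics for a grouped GEMM (not the library's API).
# Assumed layout: rows of `a` are partitioned into consecutive groups, and
# group i is multiplied by weights[i].


def matmul(x, w):
    """Plain GEMM on nested lists: (m x k) @ (k x n) -> (m x n)."""
    k = len(w)
    n = len(w[0]) if k else 0
    return [[sum(row[t] * w[t][j] for t in range(k)) for j in range(n)] for row in x]


def grouped_gemm_reference(a, weights, batch_sizes):
    """Multiply consecutive row slices of `a` by their own weight matrix.

    a:           list of rows, sum(batch_sizes) x k
    weights:     one k x n matrix per group
    batch_sizes: number of rows of `a` belonging to each group
    """
    if len(batch_sizes) != len(weights):
        raise ValueError("need one weight matrix per group")
    if sum(batch_sizes) != len(a):
        raise ValueError("batch_sizes must sum to the number of rows in a")
    out, start = [], 0
    for size, w in zip(batch_sizes, weights):
        # Conservative strategy: one independent GEMM per group.
        out.extend(matmul(a[start:start + size], w))
        start += size
    return out


a = [[1, 2], [3, 4], [5, 6]]
weights = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # identity, then 2 * identity
print(grouped_gemm_reference(a, weights, [1, 2]))
```

Per the description above, CUTLASS mode replaces this per-group loop with a single grouped kernel launch for the whole batch.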

To use grouped GEMM kernels, build the library in CUTLASS mode by setting the GROUPED_GEMM_CUTLASS environment variable to 1. For example, to build in CUTLASS mode for Ampere (SM 8.0), clone the repository and run:

$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
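TORCH_CUDA_ARCH_LIST=8.0 tells PyTorch's extension builder to compile for compute capability 8.0 (sm_80). The sketch below is a simplified, hypothetical illustration of how the two environment variables in the command could be interpreted at build time. cutlass_mode_enabled and gencode_flags are invented names, and this is not the package's actual setup script. It handles only numeric entries such as 8.0 or 8.0+PTX.

```python
import os


def cutlass_mode_enabled(env=None):
    """Hypothetical check: GROUPED_GEMM_CUTLASS=1 selects CUTLASS mode, anything else cuBLAS."""
    env = os.environ if env is None else env
    return env.get("GROUPED_GEMM_CUTLASS", "0") == "1"


def gencode_flags(arch_list):
    """Map a TORCH_CUDA_ARCH_LIST-style string (e.g. "8.0;9.0+PTX") to nvcc -gencode flags."""
    flags = []
    # Entries may be separated by semicolons or spaces.
    for entry in arch_list.replace(" ", ";").split(";"):
        if not entry:
            continue
        ptx = entry.endswith("+PTX")
        num = entry.replace("+PTX", "").replace(".", "")  # "8.0" -> "80"
        # Native machine code (SASS) for this architecture.
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if ptx:
            # Also embed PTX so newer GPUs can JIT-compile the kernels.
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags


print(cutlass_mode_enabled({"GROUPED_GEMM_CUTLASS": "1"}), gencode_flags("8.0"))
```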

See this comment for some performance measurements on A100 and H100.

Benchmark example

$ python benchmark.py
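When comparing runs, a measured time is commonly converted into achieved throughput. The helpers below are a sketch that assumes every group shares the same k and n and only the number of rows per group varies. They are not taken from benchmark.py, and the function names are hypothetical.

```python
def grouped_gemm_flops(batch_sizes, k, n):
    """FLOPs for a grouped GEMM where group i is (batch_sizes[i] x k) @ (k x n).

    Each output element needs k multiplies and k adds, hence the factor 2.
    """
    return sum(2 * m * k * n for m in batch_sizes)


def tflops(batch_sizes, k, n, seconds):
    """Convert a measured runtime in seconds into achieved TFLOP/s."""
    return grouped_gemm_flops(batch_sizes, k, n) / seconds / 1e12


print(tflops([128] * 8, 1024, 1024, 1e-3))
```

For example, eight groups of 128 rows with k = n = 1024 perform 2 * 1024^3 FLOPs (about 2.15 GFLOP), so a 1 ms runtime corresponds to roughly 2.15 TFLOP/s.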

Upcoming features

  • Hopper-optimized grouped GEMM kernels.

Download files

Source Distribution

grouped_gemm-0.3.0.tar.gz (981.1 kB)

File details

Details for the file grouped_gemm-0.3.0.tar.gz.

File metadata

  • Download URL: grouped_gemm-0.3.0.tar.gz
  • Size: 981.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.10

File hashes

Hashes for grouped_gemm-0.3.0.tar.gz
  • SHA256: f0555da33a975610e9160f52449082c5e699525524581a0e5f990b14b02676a3
  • MD5: 4df6a6c73e48d9931dd1a10c76f4c9bd
  • BLAKE2b-256: 1a90d255544a8da444fdfab7287850316d2c7961003586a35d3042787982e66c
