PyTorch bindings for CUDA-Warp RNN-Transducer

def rnnt_loss(
        log_probs,  # type: torch.FloatTensor
        labels,  # type: torch.IntTensor
        frames_lengths,  # type: torch.IntTensor
        labels_lengths,  # type: torch.IntTensor
        average_frames=False,  # type: bool
        reduction=None,  # type: Optional[AnyStr]
        blank=0,  # type: int
):
    # type: (...) -> torch.Tensor
    """The CUDA-Warp RNN-Transducer loss.

    Args:
      log_probs (torch.FloatTensor): Input tensor with shape (T, N, U, V),
        where T is the maximum number of input frames, N is the minibatch
        size, U is one more than the maximum number of output labels, and
        V is the vocabulary size (including the blank).
      labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
        reference labels for all samples in the minibatch.
      frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
        number of frames for each sample in the minibatch.
      labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
        length of the transcription for each sample in the minibatch.
      average_frames (bool, optional): Specifies whether the loss of each
        sample should be divided by its number of frames. Default: ``False``.
      reduction (string, optional): Specifies the type of reduction.
        Default: None.
      blank (int, optional): Label used to represent the blank symbol.
        Default: 0.
    """
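
To make the meaning of the arguments concrete, here is a minimal pure-Python sketch of the standard RNN-Transducer forward recursion for a single sample. It is not the package's CUDA implementation: the names `rnnt_loss_reference` and `logsumexp2` are hypothetical, and the input is assumed to be a nested list `log_probs[t][u][v]` for one sample rather than a batched tensor.

```python
import math


def logsumexp2(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss_reference(log_probs, labels, blank=0, average_frames=False):
    """Negative log-likelihood of `labels` for ONE sample (illustrative only).

    log_probs[t][u][v] holds log-probabilities for T frames,
    U = len(labels) + 1 label positions and V symbols.
    """
    T = len(log_probs)
    U = len(labels) + 1
    # alpha[t][u]: log-probability of having emitted the first u labels
    # after consuming t frames.
    alpha = [[-math.inf] * U for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            # Arrive from (t-1, u) by emitting a blank.
            stay = (alpha[t - 1][u] + log_probs[t - 1][u][blank]
                    if t > 0 else -math.inf)
            # Arrive from (t, u-1) by emitting the next reference label.
            emit = (alpha[t][u - 1] + log_probs[t][u - 1][labels[u - 1]]
                    if u > 0 else -math.inf)
            alpha[t][u] = logsumexp2(stay, emit)
    # Every alignment ends with a final blank from the last lattice node.
    loss = -(alpha[T - 1][U - 1] + log_probs[T - 1][U - 1][blank])
    return loss / T if average_frames else loss
```

With `average_frames=True` the sketch divides the loss by the number of frames `T`, mirroring the description of that flag above. A loop like this is only practical for checking small cases by hand.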

Requirements

  • C++11 compiler (tested with GCC 5.4).
  • Python: 3.5, 3.6, 3.7 (tested with version 3.6).
  • PyTorch >= 1.0.0 (tested with version 1.1.0).
  • CUDA Toolkit (tested with version 10.0).
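
Because the package is compiled locally, it can help to check the requirements above before installing. The following is a rough sketch using only the standard library. The helper name `check_build_environment` is hypothetical, and looking for `nvcc` on the `PATH` is only an assumed heuristic for detecting the CUDA Toolkit.

```python
import importlib.util
import shutil
import sys


def check_build_environment():
    """Report which of the listed requirements appear to be satisfied."""
    version = sys.version_info[:2]
    return {
        # Versions listed under Requirements; others are untested, not
        # necessarily unsupported.
        "python_listed": version in {(3, 5), (3, 6), (3, 7)},
        # find_spec locates the module without importing it.
        "torch_found": importlib.util.find_spec("torch") is not None,
        # Assumption: an nvcc executable on PATH indicates a CUDA Toolkit.
        "nvcc_found": shutil.which("nvcc") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_build_environment().items():
        print("{}: {}".format(name, "yes" if ok else "no"))
```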

Install

No precompiled binaries are currently available. Both installation methods below compile the package locally from source.

From PyPI

pip install warp_rnnt

From GitHub

git clone https://github.com/1ytic/warp-rnnt
cd warp-rnnt/pytorch_binding
python setup.py install
