ConvBwdDataImplicitGemmOutTransform Struct Reference

ConvBwdDataImplicitGemmOutTransform Struct Reference

Composable Kernel: ck::ConvBwdDataImplicitGemmOutTransform Struct Reference

Transformation struct for convolution backward data output indices to GEMM indices (see Detailed Description below).

#include <multi_index_transform.hpp>

Public Types

using LowerIndex = MultiIndex<4>
using UpperIndex = MultiIndex<3>

Public Member Functions

__host__ __device__ ConvBwdDataImplicitGemmOutTransform ()=default
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform (index_t N, index_t Ho, index_t Wo, index_t K, index_t XDot, index_t HTilde, index_t WTilde, index_t WTildeSlice, index_t HWTildeSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t HRatio, index_t WRatio, index_t XDotSlice_K, index_t K0, index_t MPadded, index_t K1, index_t MPad, index_t KPad)
__host__ __device__ constexpr const auto & GetUpperLengths () const
template<typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexN (const UpIdx &idx_up) const
template<typename UpIdx>
__host__ __device__ constexpr auto CalculateLowerIndexK (const UpIdx &idx_up) const
template<typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex (LowIdx &idx_low, const UpIdx &idx_up) const
template<typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
__host__ __device__ void UpdateLowerIndex (LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
template<typename UpIdx>
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex (const UpIdx &idx_up) const
__host__ __device__ void Print () const

Static Public Member Functions

__host__ static __device__ constexpr index_t GetNumOfLowerDimension ()
__host__ static __device__ constexpr index_t GetNumOfUpperDimension ()
__host__ static __device__ constexpr bool IsLinearTransform ()
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex ()
__host__ static __device__ constexpr bool IsKnownAtCompileTime ()

Public Attributes

index_t N_
index_t Ho_
index_t Wo_
index_t K_
index_t XDot_
index_t HTilde_
index_t WTilde_
index_t WTildeSlice_
index_t TildeSlice_
index_t IHTildeSliceBegin_
index_t IWTildeSliceBegin_
index_t HRatio_
index_t WRatio_
index_t XDotSlice_K_
index_t MPad_
index_t KPad_
Tuple< index_t, index_t, index_t > up_lengths_
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_multiplier_
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_shift_

Static Public Attributes

static constexpr auto I0 = Number<0>{}
static constexpr auto I1 = Number<1>{}
static constexpr auto I2 = Number<2>{}
static constexpr auto I3 = Number<3>{}

Detailed Description

Transformation struct for convolution backward data output indices to GEMM indices.

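This struct relates two views of the same data. The lower view is the 4-D output tensor index (N, Ho, Wo, K) of the convolution backward data operation (LowerIndex = MultiIndex<4>). The upper view is the 3-D index (K0, M, K1) used in the implicit GEMM computation (UpperIndex = MultiIndex<3>). Given an upper (GEMM) index, CalculateLowerIndex computes the corresponding lower (tensor) index. The struct encapsulates the parameters and transformation logic needed to perform this index conversion efficiently.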

Member Typedef Documentation

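◆ LowerIndex

using ck::ConvBwdDataImplicitGemmOutTransform::LowerIndex = MultiIndex<4>

◆ UpperIndex

using ck::ConvBwdDataImplicitGemmOutTransform::UpperIndex = MultiIndex<3>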

Constructor & Destructor Documentation

◆ ConvBwdDataImplicitGemmOutTransform() [1/2]

__host__ __device__ ck::ConvBwdDataImplicitGemmOutTransform::ConvBwdDataImplicitGemmOutTransform ( )
default

◆ ConvBwdDataImplicitGemmOutTransform() [2/2]

__host__ __device__ constexpr ck::ConvBwdDataImplicitGemmOutTransform::ConvBwdDataImplicitGemmOutTransform ( index_t N,
index_t Ho,
index_t Wo,
index_t K,
index_t XDot,
index_t HTilde,
index_t WTilde,
index_t WTildeSlice,
index_t HWTildeSlice,
index_t IHTildeSliceBegin,
index_t IWTildeSliceBegin,
index_t HRatio,
index_t WRatio,
index_t XDotSlice_K,
index_t K0,
index_t MPadded,
index_t K1,
index_t MPad,
index_t KPad )
inline constexpr

Member Function Documentation

◆ CalculateLowerIndex()

template<typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void ck::ConvBwdDataImplicitGemmOutTransform::CalculateLowerIndex ( LowIdx & idx_low,
const UpIdx & idx_up ) const
inline constexpr

◆ CalculateLowerIndexK()

template<typename UpIdx>
__host__ __device__ constexpr auto ck::ConvBwdDataImplicitGemmOutTransform::CalculateLowerIndexK ( const UpIdx & idx_up) const
inline constexpr

◆ CalculateLowerIndexN()

template<typename UpIdx>
__host__ __device__ constexpr auto ck::ConvBwdDataImplicitGemmOutTransform::CalculateLowerIndexN ( const UpIdx & idx_up) const
inline constexpr

◆ GetNumOfLowerDimension()

__host__ static __device__ constexpr index_t ck::ConvBwdDataImplicitGemmOutTransform::GetNumOfLowerDimension ( )
inline static constexpr

◆ GetNumOfUpperDimension()

__host__ static __device__ constexpr index_t ck::ConvBwdDataImplicitGemmOutTransform::GetNumOfUpperDimension ( )
inline static constexpr

◆ GetUpperLengths()

__host__ __device__ constexpr const auto & ck::ConvBwdDataImplicitGemmOutTransform::GetUpperLengths ( ) const
inline constexpr

◆ IsKnownAtCompileTime()

__host__ static __device__ constexpr bool ck::ConvBwdDataImplicitGemmOutTransform::IsKnownAtCompileTime ( )
inline static constexpr

◆ IsLinearTransform()

__host__ static __device__ constexpr bool ck::ConvBwdDataImplicitGemmOutTransform::IsLinearTransform ( )
inline static constexpr

◆ IsValidUpperIndexAlwaysMappedToValidLowerIndex()

__host__ static __device__ constexpr bool ck::ConvBwdDataImplicitGemmOutTransform::IsValidUpperIndexAlwaysMappedToValidLowerIndex ( )
inline static constexpr

◆ IsValidUpperIndexMappedToValidLowerIndex()

template<typename UpIdx>
__host__ __device__ constexpr bool ck::ConvBwdDataImplicitGemmOutTransform::IsValidUpperIndexMappedToValidLowerIndex ( const UpIdx & idx_up) const
inline constexpr

◆ Print()

__host__ __device__ void ck::ConvBwdDataImplicitGemmOutTransform::Print ( ) const
inline

◆ UpdateLowerIndex()

template<typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
__host__ __device__ void ck::ConvBwdDataImplicitGemmOutTransform::UpdateLowerIndex ( LowIdxDiff & idx_diff_low,
const UpIdxDiff & ,
LowIdx & idx_low,
const UpIdx & idx_up,
Number< Hack >  ) const
inline

Member Data Documentation

◆ Ho_

index_t ck::ConvBwdDataImplicitGemmOutTransform::Ho_

◆ HRatio_

index_t ck::ConvBwdDataImplicitGemmOutTransform::HRatio_

◆ HTilde_

index_t ck::ConvBwdDataImplicitGemmOutTransform::HTilde_

◆ I0

auto ck::ConvBwdDataImplicitGemmOutTransform::I0 = Number<0>{}
static constexpr

◆ I1

auto ck::ConvBwdDataImplicitGemmOutTransform::I1 = Number<1>{}
static constexpr

◆ I2

auto ck::ConvBwdDataImplicitGemmOutTransform::I2 = Number<2>{}
static constexpr

◆ I3

auto ck::ConvBwdDataImplicitGemmOutTransform::I3 = Number<3>{}
static constexpr

◆ IHTildeSliceBegin_

index_t ck::ConvBwdDataImplicitGemmOutTransform::IHTildeSliceBegin_

◆ IWTildeSliceBegin_

index_t ck::ConvBwdDataImplicitGemmOutTransform::IWTildeSliceBegin_

◆ K_

index_t ck::ConvBwdDataImplicitGemmOutTransform::K_

◆ KPad_

index_t ck::ConvBwdDataImplicitGemmOutTransform::KPad_

◆ low_lengths_magic_divisor_multiplier_

Tuple<index_t, index_t, index_t, index_t> ck::ConvBwdDataImplicitGemmOutTransform::low_lengths_magic_divisor_multiplier_

◆ low_lengths_magic_divisor_shift_

Tuple<index_t, index_t, index_t, index_t> ck::ConvBwdDataImplicitGemmOutTransform::low_lengths_magic_divisor_shift_

◆ MPad_

index_t ck::ConvBwdDataImplicitGemmOutTransform::MPad_

◆ N_

index_t ck::ConvBwdDataImplicitGemmOutTransform::N_

◆ TildeSlice_

index_t ck::ConvBwdDataImplicitGemmOutTransform::TildeSlice_

◆ up_lengths_

Tuple<index_t, index_t, index_t> ck::ConvBwdDataImplicitGemmOutTransform::up_lengths_

◆ Wo_

index_t ck::ConvBwdDataImplicitGemmOutTransform::Wo_

◆ WRatio_

index_t ck::ConvBwdDataImplicitGemmOutTransform::WRatio_

◆ WTilde_

index_t ck::ConvBwdDataImplicitGemmOutTransform::WTilde_

◆ WTildeSlice_

index_t ck::ConvBwdDataImplicitGemmOutTransform::WTildeSlice_

◆ XDot_

index_t ck::ConvBwdDataImplicitGemmOutTransform::XDot_

◆ XDotSlice_K_

index_t ck::ConvBwdDataImplicitGemmOutTransform::XDotSlice_K_

The documentation for this struct was generated from the following file:

multi_index_transform.hpp