@@ -846,11 +846,17 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
846846 auto processOperatorExpr = [&](const Expr *Obj) {
847847 std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
848848 Obj->getType ().getCanonicalType ());
849- if (OpType == " cub::Sum" || OpType == " cuda::std::plus<void>" ) {
849+ if (OpType.find (" cub::Sum" ) != std::string::npos ||
850+ OpType.find (" cuda::std::plus" ) != std::string::npos ||
851+ OpType.find (" thrust::plus" ) != std::string::npos) {
850852 OpRepl = MapNames::getClNamespace () + " plus<>()" ;
851- } else if (OpType == " cub::Max" || OpType == " cuda::maximum<void>" ) {
853+ } else if (OpType.find (" cub::Max" ) != std::string::npos ||
854+ OpType.find (" cuda::maximum" ) != std::string::npos ||
855+ OpType.find (" thrust::maximum" ) != std::string::npos) {
852856 OpRepl = MapNames::getClNamespace () + " maximum<>()" ;
853- } else if (OpType == " cub::Min" || OpType == " cuda::minimum<void>" ) {
857+ } else if (OpType.find (" cub::Min" ) != std::string::npos ||
858+ OpType.find (" cuda::minimum" ) != std::string::npos ||
859+ OpType.find (" thrust::minimum" ) != std::string::npos) {
854860 OpRepl = MapNames::getClNamespace () + " minimum<>()" ;
855861 }
856862 };
@@ -861,17 +867,21 @@ std::string CubRule::getOpRepl(const Expr *Operator) {
861867 } else {
862868 auto CtorArg = Op->getArg (0 )->IgnoreImplicitAsWritten ();
863869 if (auto DRE = dyn_cast<DeclRefExpr>(CtorArg)) {
864- auto D = DRE->getDecl ();
865- if (!D)
866- return OpRepl;
867- std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
868- D->getType ().getCanonicalType ());
869- if (OpType == " cub::Sum" || OpType == " cub::Max" ||
870- OpType == " cub::Min" || OpType == " cuda::std::plus<void>" ||
871- OpType == " cuda::maximum<void>" ||
872- OpType == " cuda::minimum<void>" ) {
873- ExprAnalysis EA (Operator);
874- OpRepl = EA.getReplacedString ();
870+ if (auto D = DRE->getDecl ()) {
871+ std::string OpType = DpctGlobalInfo::getUnqualifiedTypeName (
872+ D->getType ().getCanonicalType ());
873+ if (OpType.find (" cub::Sum" ) != std::string::npos ||
874+ OpType.find (" cub::Max" ) != std::string::npos ||
875+ OpType.find (" cub::Min" ) != std::string::npos ||
876+ OpType.find (" cuda::std::plus" ) != std::string::npos ||
877+ OpType.find (" cuda::maximum" ) != std::string::npos ||
878+ OpType.find (" cuda::minimum" ) != std::string::npos ||
879+ OpType.find (" thrust::plus" ) != std::string::npos ||
880+ OpType.find (" thrust::maximum" ) != std::string::npos ||
881+ OpType.find (" thrust::minimum" ) != std::string::npos) {
882+ ExprAnalysis EA (Operator);
883+ OpRepl = EA.getReplacedString ();
884+ }
875885 }
876886 } else if (auto CXXTempObj = dyn_cast<CXXTemporaryObjectExpr>(CtorArg)) {
877887 processOperatorExpr (CXXTempObj);
0 commit comments