Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
package io.codemodder.codemods;

import com.contrastsecurity.sarif.Region;
import com.github.javaparser.ast.visitor.ModifierVisitor;
import io.codemodder.ChangeConstructorTypeVisitor;
import static io.codemodder.JavaParserUtils.addImportIfMissing;

import com.contrastsecurity.sarif.Result;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import io.codemodder.Codemod;
import io.codemodder.CodemodInvocationContext;
import io.codemodder.FileWeavingContext;
import io.codemodder.ReviewGuidance;
import io.codemodder.RuleSarif;
import io.codemodder.providers.sarif.semgrep.SemgrepJavaParserChanger;
import io.codemodder.providers.sarif.semgrep.SemgrepScan;
import java.util.List;
import java.security.SecureRandom;
import javax.inject.Inject;

/** Turns {@link java.util.Random} into {@link java.security.SecureRandom}. */
@Codemod(
id = "pixee:java/secure-random",
author = "arshan@pixee.ai",
reviewGuidance = ReviewGuidance.MERGE_WITHOUT_REVIEW)
public final class SecureRandomCodemod extends SemgrepJavaParserChanger {
public final class SecureRandomCodemod extends SemgrepJavaParserChanger<ObjectCreationExpr> {

@Inject
public SecureRandomCodemod(
@SemgrepScan(pathToYaml = "/secure-random.yaml", ruleId = "secure-random")
final RuleSarif sarif) {
super(sarif);
super(sarif, ObjectCreationExpr.class);
}

@Override
public ModifierVisitor<FileWeavingContext> createVisitor(
final CodemodInvocationContext context, final List<Region> regions) {
return new ChangeConstructorTypeVisitor(
regions, "java.security.SecureRandom", context.codemodId());
public void onSemgrepResultFound(
final CodemodInvocationContext context,
final CompilationUnit cu,
final ObjectCreationExpr objectCreationExpr,
final Result result) {
objectCreationExpr.setType("SecureRandom");
addImportIfMissing(cu, SecureRandom.class.getName());
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.codemodder;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.visitor.ModifierVisitor;
import com.google.inject.AbstractModule;
import com.google.inject.Guice;
import com.google.inject.Injector;
Expand All @@ -12,7 +11,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -64,12 +62,17 @@ public CodemodInvoker(

// validate and instantiate the codemods
Injector injector = Guice.createInjector(allModules);
Set<String> codemodIds = new HashSet<>();
List<Changer> codemods = new ArrayList<>();
for (Class<? extends Changer> type : codemodTypes) {
Codemod codemodAnnotation = type.getAnnotation(Codemod.class);
validateRequiredFields(codemodAnnotation);
Changer changer = injector.getInstance(type);
String codemodId = codemodAnnotation.id();
if (codemodIds.contains(codemodId)) {
throw new UnsupportedOperationException("multiple codemods under id: " + codemodId);
}
codemodIds.add(codemodId);
if (ruleContext.isRuleAllowed(codemodId)) {
codemods.add(changer);
changers.add(new IdentifiedChanger(codemodId, changer));
Expand Down Expand Up @@ -101,19 +104,13 @@ public void execute(final Path path, final CompilationUnit cu, final FileWeaving
for (Changer changer : codemods) {
if (changer instanceof JavaParserChanger) {
JavaParserChanger javaParserChanger = (JavaParserChanger) changer;
Optional<ModifierVisitor<FileWeavingContext>> modifierVisitor =
javaParserChanger.createModifierVisitor(
new DefaultCodemodInvocationContext(
new DefaultCodeDirectory(repositoryDir),
path,
changers.stream()
.filter(ic -> ic.changer == changer)
.findFirst()
.orElseThrow()
.id,
context));
modifierVisitor.ifPresent(
changeContextModifierVisitor -> cu.accept(changeContextModifierVisitor, context));
CodemodInvocationContext invocationContext =
new DefaultCodemodInvocationContext(
new DefaultCodeDirectory(repositoryDir),
path,
changers.stream().filter(ic -> ic.changer == changer).findFirst().orElseThrow().id,
context);
javaParserChanger.visit(invocationContext, cu);
} else {
throw new UnsupportedOperationException("unknown or not");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
package io.codemodder;

import com.github.javaparser.ast.visitor.ModifierVisitor;
import java.util.Optional;
import com.github.javaparser.ast.CompilationUnit;

/** {@inheritDoc} Uses JavaParser to change Java source files. */
public interface JavaParserChanger extends Changer {

/**
* Creates a visitor for a given Java source file, or not. It's up to the implementing type to
* determine if and how source file should be changed.
*/
Optional<ModifierVisitor<FileWeavingContext>> createModifierVisitor(
CodemodInvocationContext context);
/**
 * Visits the parsed Java source file. It's up to the implementing type to determine if and how
 * the source file should be changed.
 *
 * @param context the invocation context supplied by the caller for the file being processed
 * @param cu the parsed JavaParser model of that file
 */
void visit(final CodemodInvocationContext context, final CompilationUnit cu);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.codemodder;

import com.contrastsecurity.sarif.Region;
import com.github.javaparser.Position;
import com.github.javaparser.Range;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.ImportDeclaration;
Expand All @@ -25,7 +26,8 @@ public static void addImportIfMissing(final CompilationUnit cu, final String cla
if (imports.contains(newImport)) {
return;
}
for (ImportDeclaration existingImport : imports) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for this since we have ASTTransforms::addImportIfMissing.

for (int i = 0; i < imports.size(); i++) {
ImportDeclaration existingImport = imports.get(i);
if (existingImport.getNameAsString().compareToIgnoreCase(className) > 0) {
imports.addBefore(newImport, existingImport);
return;
Expand All @@ -51,4 +53,16 @@ public static boolean regionMatchesNode(final Node node, final Region region) {
Range observedRange = node.getRange().get();
return observedRange.overlapsWith(sarifRange);
}

/**
 * Returns true if the {@link Node} and the {@link Region} start at the same location.
 *
 * <p>A node that carries no source range cannot be matched against a region, so this returns
 * false for it rather than throwing.
 *
 * @param node the AST node to compare
 * @param region the SARIF region to compare
 * @return true, if the two locations have equivalent start line and columns; false otherwise,
 *     including when the node has no range
 */
public static boolean regionMatchesNodeStart(final Node node, final Region region) {
  // The node's range is optional; guard instead of calling get() blindly.
  if (!node.getRange().isPresent()) {
    return false;
  }
  final Position begin = node.getRange().get().begin;
  return region.getStartLine() == begin.line && region.getStartColumn() == begin.column;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.codemodder;

import com.contrastsecurity.sarif.Region;
import com.contrastsecurity.sarif.Result;
import com.contrastsecurity.sarif.SarifSchema210;
import java.nio.file.Path;
import java.util.List;
Expand All @@ -16,6 +17,14 @@ public interface RuleSarif {
*/
List<Region> getRegionsFromResultsByRule(Path path);

/**
* Get all the SARIF results with the matching path
*
* @param path the file being scanned
* @return the results associated with the given file
*/
List<Result> getResultsByPath(Path path);

/** Return the entire SARIF as a model in case more comprehensive inspection is needed. */
SarifSchema210 rawDocument();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import static io.codemodder.CodemodInvoker.isValidCodemodId;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.nio.file.Path;
import java.util.List;
import org.hamcrest.CoreMatchers;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

final class CodemodInvokerTest {

Expand All @@ -18,13 +22,13 @@ static class InvalidCodemodName implements Changer {}
id = "test_mod",
author = " ",
reviewGuidance = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW)
class EmptyCodemodAuthor implements Changer {}
static class EmptyCodemodAuthor implements Changer {}

@Codemod(
id = "pixee:java/id",
author = "valid@valid.com",
reviewGuidance = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW)
final class ValidCodemod implements Changer {}
static class ValidCodemod implements Changer {}

@Test
void it_validates_codemod_ids() {
Expand All @@ -34,4 +38,13 @@ void it_validates_codemod_ids() {
assertThat(isValidCodemodId("missing:token"), CoreMatchers.is(false));
assertThat(isValidCodemodId("missing:separator/"), CoreMatchers.is(false));
}

/** Registering the same codemod type twice must be rejected because both share one id. */
@Test
void it_blows_up_on_duplicate_codemod_ids(@TempDir Path tmpDir) {
  assertThrows(
      UnsupportedOperationException.class,
      () -> new CodemodInvoker(List.of(ValidCodemod.class, ValidCodemod.class), tmpDir));
}
}
Original file line number Diff line number Diff line change
@@ -1,40 +1,64 @@
package io.codemodder.providers.sarif.semgrep;

import com.contrastsecurity.sarif.Region;
import com.github.javaparser.ast.visitor.ModifierVisitor;
import com.contrastsecurity.sarif.Result;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import io.codemodder.CodemodInvocationContext;
import io.codemodder.FileWeavingContext;
import io.codemodder.JavaParserChanger;
import io.codemodder.JavaParserUtils;
import io.codemodder.RuleSarif;
import io.codemodder.Weave;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Provides base functionality for making JavaParser-based changes with Semgrep. */
public abstract class SemgrepJavaParserChanger implements JavaParserChanger {
public abstract class SemgrepJavaParserChanger<T extends Node> implements JavaParserChanger {

protected final RuleSarif sarif;
private final Class<? extends Node> nodeType;

protected SemgrepJavaParserChanger(final RuleSarif semgrepSarif) {
protected SemgrepJavaParserChanger(
final RuleSarif semgrepSarif, final Class<? extends Node> nodeType) {
this.sarif = Objects.requireNonNull(semgrepSarif);
this.nodeType = Objects.requireNonNull(nodeType);
}

@Override
public Optional<ModifierVisitor<FileWeavingContext>> createModifierVisitor(
final CodemodInvocationContext context) {
List<Region> regions = sarif.getRegionsFromResultsByRule(context.path());
return !regions.isEmpty() ? Optional.of(createVisitor(context, regions)) : Optional.empty();
public final void visit(final CodemodInvocationContext context, final CompilationUnit cu) {

List<Result> results = sarif.getResultsByPath(context.path());
List<? extends Node> allNodes = cu.findAll(nodeType);

for (Result result : results) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine this way but a small suggestion using streams:

List<Region> regions = results.stream().map(r -> r.getLocations().get(0).getPhysicalLocation().getRegion()).collect(Collectors.toList())
var allMatchedNodes = allNodes.stream().filter(n -> regions.stream().anyMatch(r -> JavaParserUtils.regionMatchesNode(n, r)));      

Then iterate/filter/collect the resulting Stream allMatchedNodes.

for (Node node : allNodes) {
Region region = result.getLocations().get(0).getPhysicalLocation().getRegion();
if (!node.getClass().isAssignableFrom(nodeType)) {
logger.error("Unexpected node encountered in {}:{}", context.path(), region);
return;
}
if (JavaParserUtils.regionMatchesNodeStart(node, region)) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may match undesired nodes since it only checks for overlapping. For example in n1().n2().n3() the range of n1() overlaps n1().n2().n3(). This may be a particular problem if n1(), n2() and n3() are calls to the same method. I'd rather do exact matching.

Also you need to check if semgrep's region in the SARIF matches JavaParser ranges exactly. I know CodeQL does not (it includes an extra column).

onSemgrepResultFound(context, cu, (T) node, result);
FileWeavingContext changeRecorder = context.changeRecorder();
changeRecorder.addWeave(Weave.from(region.getStartLine(), context.codemodId()));
}
}
}
}

/**
* Creates a visitor for the given context and locations.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the Javadoc here; it does not create a visitor anymore.

*
* @param context the context of this files transformation
* @param regions the places in this file that have been identified as needing change by the
* static analysis
* @return a visitor that will perform the necessary changes in the given source code file
* positions
* @param cu the parsed model of the file being transformed
* @param node the node to act on
* @param result the given SARIF result to act on
*/
public abstract ModifierVisitor<FileWeavingContext> createVisitor(
CodemodInvocationContext context, List<Region> regions);
public abstract void onSemgrepResultFound(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small suggestion: make it return CompilationUnit. This makes it clear that the main goal is to change the AST.

CodemodInvocationContext context, CompilationUnit cu, T node, Result result);

private static final Logger logger = LoggerFactory.getLogger(SemgrepJavaParserChanger.class);
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,29 @@ public String getRule() {

@Override
public List<Region> getRegionsFromResultsByRule(final Path path) {
List<Result> resultsFilteredByRule =
sarif.getRuns().get(0).getResults().stream()
.filter(result -> result.getRuleId().endsWith("." + ruleId))
.collect(Collectors.toUnmodifiableList());
List<Result> resultsFilteredByRuleAndPath =
resultsFilteredByRule.stream()
.filter(
result -> {
String uri =
result
.getLocations()
.get(0)
.getPhysicalLocation()
.getArtifactLocation()
.getUri();
try {
return Files.isSameFile(path, Path.of(uri));
} catch (IOException e) { // this should never happen
logger.error("Problem inspecting SARIF to find code regions", e);
return false;
}
})
.collect(Collectors.toUnmodifiableList());
return resultsFilteredByRuleAndPath.stream()
return getResultsByPath(path).stream()
.filter(result -> result.getRuleId().endsWith("." + ruleId))
.map(result -> result.getLocations().get(0).getPhysicalLocation().getRegion())
.collect(Collectors.toUnmodifiableList());
}

/**
 * Collects the results from the first SARIF run whose rule id ends with this rule's id and whose
 * first location refers to the same file as the given path.
 *
 * @param path the file being scanned
 * @return an unmodifiable list of the matching results
 */
@Override
public List<Result> getResultsByPath(final Path path) {
  final String ruleSuffix = "." + ruleId;
  return sarif.getRuns().get(0).getResults().stream()
      .filter(result -> result.getRuleId().endsWith(ruleSuffix))
      .filter(result -> isReportedAgainst(result, path))
      .collect(Collectors.toUnmodifiableList());
}

/**
 * Tells whether the first location of the result refers to the same file as the given path. An
 * {@link IOException} raised by the comparison is logged and treated as "no match".
 */
private static boolean isReportedAgainst(final Result result, final Path path) {
  final String uri =
      result.getLocations().get(0).getPhysicalLocation().getArtifactLocation().getUri();
  boolean sameFile;
  try {
    sameFile = Files.isSameFile(path, Path.of(uri));
  } catch (IOException e) {
    logger.error("Problem inspecting SARIF to find code results", e);
    sameFile = false;
  }
  return sameFile;
}

private static final Logger logger = LoggerFactory.getLogger(SemgrepRuleSarif.class);
}
Loading