diff --git a/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyClassItem.java b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyClassItem.java new file mode 100644 index 0000000000..5752c563f0 --- /dev/null +++ b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyClassItem.java @@ -0,0 +1,42 @@ +package org.utbot.intellij.plugin.language.python.table; + +import com.intellij.icons.AllIcons; +import com.jetbrains.python.psi.PyClass; +import com.jetbrains.python.psi.PyElement; + +import javax.swing.*; + +public class UtPyClassItem implements UtPyTableItem { + private final PyClass pyClass; + private boolean isChecked; + + public UtPyClassItem(PyClass clazz) { + pyClass = clazz; + isChecked = false; + } + + @Override + public PyElement getContent() { + return pyClass; + } + + @Override + public String getIdName() { + return pyClass.getQualifiedName(); + } + + @Override + public Icon getIcon() { + return AllIcons.Nodes.Class; + } + + @Override + public boolean isChecked() { + return isChecked; + } + + @Override + public void setChecked(boolean valueToBeSet) { + isChecked = valueToBeSet; + } +} diff --git a/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyFunctionItem.java b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyFunctionItem.java new file mode 100644 index 0000000000..d3c9fdf512 --- /dev/null +++ b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyFunctionItem.java @@ -0,0 +1,42 @@ +package org.utbot.intellij.plugin.language.python.table; + +import com.intellij.icons.AllIcons; +import com.jetbrains.python.psi.PyElement; +import com.jetbrains.python.psi.PyFunction; + +import javax.swing.*; + +public class UtPyFunctionItem implements UtPyTableItem { + private final PyFunction pyFunction; + private boolean isChecked; + + public UtPyFunctionItem(PyFunction function) { + pyFunction = function; + isChecked = false; + } + + @Override + public PyElement getContent() { + return pyFunction; + } + + @Override + public String getIdName() { + return pyFunction.getQualifiedName(); + } + + @Override + public Icon getIcon() { + return AllIcons.Nodes.Function; + } + + @Override + public boolean isChecked() { + return isChecked; + } + + @Override + public void setChecked(boolean valueToBeSet) { + isChecked = valueToBeSet; + } +} diff --git a/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyMemberSelectionTable.java b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyMemberSelectionTable.java new file mode 100644 index 0000000000..c1d88eae78 --- /dev/null +++ b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyMemberSelectionTable.java @@ -0,0 +1,224 @@ +package org.utbot.intellij.plugin.language.python.table; + +import com.intellij.openapi.actionSystem.CommonDataKeys; +import com.intellij.openapi.actionSystem.DataProvider; +import com.intellij.refactoring.ui.EnableDisableAction; +import com.intellij.ui.*; +import com.intellij.ui.icons.RowIcon; +import com.intellij.ui.table.JBTable; +import com.intellij.util.containers.ContainerUtil; +import com.intellij.util.ui.JBUI; +import org.jetbrains.annotations.NonNls; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import javax.swing.*; +import javax.swing.table.AbstractTableModel; +import javax.swing.table.TableColumn; +import javax.swing.table.TableColumnModel; +import java.awt.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class UtPyMemberSelectionTable extends JBTable implements DataProvider { + protected static final int CHECKED_COLUMN = 0; + protected static final int DISPLAY_NAME_COLUMN = 1; + protected static final int ICON_POSITION = 0; + + protected List myItems; + protected MyTableModel myTableModel; + + public UtPyMemberSelectionTable(Collection items) { + myItems = new ArrayList<>(items); + myTableModel = new MyTableModel<>(this); + setModel(myTableModel); + + TableColumnModel model = getColumnModel(); + model.getColumn(DISPLAY_NAME_COLUMN).setCellRenderer(new MyTableRenderer<>(this)); + TableColumn checkBoxColumn = model.getColumn(CHECKED_COLUMN); + TableUtil.setupCheckboxColumn(checkBoxColumn); + checkBoxColumn.setCellRenderer(new MyBooleanRenderer<>(this)); + setPreferredScrollableViewportSize(JBUI.size(400, -1)); + setVisibleRowCount(12); + getSelectionModel().setSelectionMode(ListSelectionModel.MULTIPLE_INTERVAL_SELECTION); + setShowGrid(false); + setIntercellSpacing(new Dimension(0, 0)); + new MyEnableDisableAction().register(); + } + + public void setItems(Collection items) { + myItems = new ArrayList<>(items); + } + + @Override + public @Nullable Object getData(@NotNull @NonNls String dataId) { + if (CommonDataKeys.PSI_ELEMENT.is(dataId)) { + return ContainerUtil.getFirstItem(getSelectedMemberInfos()); + } + return null; + } + + public Collection getSelectedMemberInfos() { + ArrayList list = new ArrayList<>(myItems.size()); + for (T info : myItems) { + if (info.isChecked()) { + list.add(info); + } + } + return list; + } + + private class MyEnableDisableAction extends EnableDisableAction { + + @Override + protected JTable getTable() { + return UtPyMemberSelectionTable.this; + } + + @Override + protected void applyValue(int[] rows, boolean valueToBeSet) { + for (int row : rows) { + final T memberInfo = myItems.get(row); + memberInfo.setChecked(valueToBeSet); + } + final int[] selectedRows = getSelectedRows(); + final ListSelectionModel selectionModel = getSelectionModel(); + for (int selectedRow : selectedRows) { + selectionModel.addSelectionInterval(selectedRow, selectedRow); + } + } + + @Override + protected boolean isRowChecked(final int row) { + return myItems.get(row).isChecked(); + } + } + + private static class MyBooleanRenderer extends BooleanTableCellRenderer { + private final UtPyMemberSelectionTable myTable; + + MyBooleanRenderer(UtPyMemberSelectionTable table) { + myTable = table; + } + + @Override + public Component getTableCellRendererComponent(JTable table, Object value, boolean isSelected, boolean hasFocus, int row, int column) { + Component component = super.getTableCellRendererComponent(table, value, isSelected, hasFocus, row, column); + if (component instanceof JCheckBox) { + int modelColumn = myTable.convertColumnIndexToModel(column); + T itemInfo = myTable.myItems.get(row); + component.setEnabled(modelColumn == CHECKED_COLUMN || itemInfo.isChecked()); + } + return component; + } + } + + private static class MyTableRenderer extends ColoredTableCellRenderer { + private final UtPyMemberSelectionTable myTable; + + MyTableRenderer(UtPyMemberSelectionTable table) { + myTable = table; + } + + @Override + public void customizeCellRenderer(@NotNull JTable table, final Object value, + boolean isSelected, boolean hasFocus, final int row, final int column) { + + final int modelColumn = myTable.convertColumnIndexToModel(column); + final T item = myTable.myItems.get(row); + if (modelColumn == DISPLAY_NAME_COLUMN) { + Icon itemIcon = item.getIcon(); + RowIcon icon = IconManager.getInstance().createRowIcon(3); + icon.setIcon(itemIcon, ICON_POSITION); + setIcon(icon); + } + else { + setIcon(null); + } + setIconOpaque(false); + setOpaque(false); + + if (value == null) return; + append((String)value); + } + + } + + protected static class MyTableModel extends AbstractTableModel { + private final UtPyMemberSelectionTable myTable; + private Boolean removePrefix; + + public MyTableModel(UtPyMemberSelectionTable table) { + myTable = table; + } + + private void initRemovePrefix() { + List names = new ArrayList<>(); + for (UtPyTableItem item: myTable.myItems) { + names.add(item.getIdName()); + } + removePrefix = Utils.haveCommonPrefix(names); + } + + @Override + public int getColumnCount() { + return 2; + } + + @Override + public int getRowCount() { + return myTable.myItems.size(); + } + + @Override + public Class getColumnClass(int columnIndex) { + if (columnIndex == CHECKED_COLUMN) { + return Boolean.class; + } + return super.getColumnClass(columnIndex); + } + + @Override + public Object getValueAt(int rowIndex, int columnIndex) { + if (removePrefix == null) { + initRemovePrefix(); + } + final T itemInfo = myTable.myItems.get(rowIndex); + if (columnIndex == CHECKED_COLUMN) { + return itemInfo.isChecked(); + } else if (columnIndex == DISPLAY_NAME_COLUMN) { + if (removePrefix) { + return Utils.getSuffix(itemInfo.getIdName()); + } + return itemInfo.getIdName(); + } else { + throw new RuntimeException("Incorrect column index"); + } + } + + @Override + public String getColumnName(int column) { + if (column == CHECKED_COLUMN) { + return " "; + } else if (column == DISPLAY_NAME_COLUMN) { + return "Members"; + } else { + throw new RuntimeException("Incorrect column index"); + } + } + + @Override + public boolean isCellEditable(int rowIndex, int columnIndex) { + return columnIndex == CHECKED_COLUMN; + } + + + @Override + public void setValueAt(final Object aValue, final int rowIndex, final int columnIndex) { + if (columnIndex == CHECKED_COLUMN) { + myTable.myItems.get(rowIndex).setChecked((Boolean) aValue); + } + } + } +} diff --git a/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyTableItem.java b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyTableItem.java new file mode 100644 index 0000000000..9d3da4e6f1 --- /dev/null +++ b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/UtPyTableItem.java @@ -0,0 +1,18 @@ +package org.utbot.intellij.plugin.language.python.table; + +import com.jetbrains.python.psi.PyElement; + +import javax.swing.*; + +public interface UtPyTableItem { + + public PyElement getContent(); + + public String getIdName(); + + public Icon getIcon(); + + boolean isChecked(); + + void setChecked(boolean valueToBeSet); +} diff --git a/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/Utils.java b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/Utils.java new file mode 100644 index 0000000000..7667a77c50 --- /dev/null +++ b/utbot-intellij-python/src/main/java/org/utbot/intellij/plugin/language/python/table/Utils.java @@ -0,0 +1,27 @@ +package org.utbot.intellij.plugin.language.python.table; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class Utils { + public static Boolean haveCommonPrefix(List strings) { + Set prefixes = new HashSet<>(); + for (String str: strings) { + prefixes.add(getPrefix(str)); + } + return prefixes.size() <= 1; + } + + public static String getPrefix(String str) { + String suffix = getSuffix(str); + int len = str.length(); + return str.substring(0, len-suffix.length()-1); + } + + public static String getSuffix(String str) { + String[] parts = str.split("\\."); + int len = parts.length; + return parts[len-1]; + } +} diff --git a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogProcessor.kt b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogProcessor.kt index f8aa3ffb6d..2dd06c4951 100644 --- a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogProcessor.kt +++ b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogProcessor.kt @@ -6,6 +6,7 @@ import com.intellij.openapi.application.readAction import com.intellij.openapi.editor.Editor import com.intellij.openapi.fileEditor.FileDocumentManager import com.intellij.openapi.module.Module +import com.intellij.openapi.module.ModuleUtilCore import com.intellij.openapi.progress.ProgressIndicator import com.intellij.openapi.progress.ProgressManager import com.intellij.openapi.progress.Task.Backgroundable @@ -19,12 +20,15 @@ import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiDirectory import com.intellij.psi.PsiFileFactory import com.jetbrains.python.psi.PyClass +import com.jetbrains.python.psi.PyElement import com.jetbrains.python.psi.PyFile import com.jetbrains.python.psi.PyFunction +import com.jetbrains.python.psi.resolve.QualifiedNameFinder import kotlinx.coroutines.runBlocking import org.jetbrains.kotlin.idea.util.application.runWriteAction import org.jetbrains.kotlin.idea.util.module import org.jetbrains.kotlin.idea.util.projectStructure.sdk +import org.jetbrains.kotlin.j2k.getContainingClass import org.utbot.common.PathUtil.toPath import org.utbot.common.appendHtmlLine import org.utbot.framework.UtSettings @@ -49,10 +53,8 @@ const val DEFAULT_TIMEOUT_FOR_RUN_IN_MILLIS = 2000L object PythonDialogProcessor { fun createDialogAndGenerateTests( project: Project, - functionsToShow: Set, - containingClass: PyClass?, - focusedMethod: PyFunction?, - file: PyFile, + elementsToShow: Set, + focusedElement: PyElement?, editor: Editor? = null, ) { editor?.let{ @@ -62,7 +64,7 @@ object PythonDialogProcessor { } } } - val pythonPath = getPythonPath(functionsToShow) + val pythonPath = getPythonPath(elementsToShow) if (pythonPath == null) { showErrorDialogLater( project, @@ -72,10 +74,8 @@ object PythonDialogProcessor { } else { val dialog = createDialog( project, - functionsToShow, - containingClass, - focusedMethod, - file, + elementsToShow, + focusedElement, pythonPath, ) if (!dialog.showAndGet()) { @@ -85,54 +85,51 @@ object PythonDialogProcessor { } } - private fun getPythonPath(functionsToShow: Set): String? { - return findSrcModule(functionsToShow).sdk?.homePath + private fun getPythonPath(elementsToShow: Set): String? { + return findSrcModules(elementsToShow).first().sdk?.homePath } private fun createDialog( project: Project, - functionsToShow: Set, - containingClass: PyClass?, - focusedMethod: PyFunction?, - file: PyFile, + elementsToShow: Set, + focusedElement: PyElement?, pythonPath: String, ): PythonDialogWindow { - val srcModule = findSrcModule(functionsToShow) - val testModules = srcModule.testModules(project) - val (directoriesForSysPath, moduleToImport) = getDirectoriesForSysPath(srcModule, file) + val srcModules = findSrcModules(elementsToShow) + val testModules = srcModules.flatMap {it.testModules(project)} + val focusedElements = focusedElement + ?.let { setOf(focusedElement.toUtPyTableItem()).filterNotNull() } + ?.toSet() return PythonDialogWindow( PythonTestsModel( project, - srcModule, + srcModules.first(), testModules, - functionsToShow, - containingClass, - if (focusedMethod != null) setOf(focusedMethod) else null, - file, - directoriesForSysPath, - moduleToImport, + elementsToShow, + focusedElements, UtSettings.utBotGenerationTimeoutInMillis, DEFAULT_TIMEOUT_FOR_RUN_IN_MILLIS, cgLanguageAssistant = PythonCgLanguageAssistant, pythonPath = pythonPath, + names = elementsToShow.associateBy { Pair(it.fileName()!!, it.name!!) }, ) ) } - private fun findSelectedPythonMethods(model: PythonTestsModel): List { + private fun findSelectedPythonMethods(model: PythonTestLocalModel): List { return runBlocking { readAction { - val allFunctions: List = - if (model.containingClass == null) { - model.file.topLevelFunctions - } else { - val classes = model.file.topLevelClasses - val myClass = classes.find { it.name == model.containingClass.name } - ?: error("Didn't find containing class") - myClass.methods.filterNotNull() + model.selectedElements + .filter { model.selectedElements.contains(it) } + .flatMap { + when (it) { + is PyFunction -> listOf(it) + is PyClass -> it.methods.toList() + else -> emptyList() + } } - val shownFunctions: Set = allFunctions + .filter { fineFunction(it) } .mapNotNull { val functionName = it.name ?: return@mapNotNull null val moduleFilename = it.containingFile.virtualFile?.canonicalPath ?: "" @@ -144,70 +141,117 @@ object PythonDialogProcessor { ) } .toSet() + .toList() + } + } + } - model.selectedFunctions.map { pyFunction -> - shownFunctions.find { pythonMethod -> - pythonMethod.name == pyFunction.name - } ?: error("Didn't find PythonMethod ${pyFunction.name}") - } + private fun groupPyElementsByModule(model: PythonTestsModel): Set { + return runBlocking { + readAction { + model.selectedElements + .groupBy { it.containingFile } + .flatMap { fileGroup -> + fileGroup.value + .groupBy { it is PyClass }.values + } + .filter { it.isNotEmpty() } + .map { + val realElements = it.map { member -> model.names[Pair(member.fileName(), member.name)]!! } + val file = realElements.first().containingFile as PyFile + val srcModule = getSrcModule(realElements.first()) + + val (directoriesForSysPath, moduleToImport) = getDirectoriesForSysPath(srcModule, file) + PythonTestLocalModel( + model.project, + model.timeout, + model.timeoutForRun, + model.cgLanguageAssistant, + model.pythonPath, + model.testSourceRootPath, + model.testFramework, + realElements.toSet(), + model.runtimeExceptionTestsBehaviour, + directoriesForSysPath, + moduleToImport, + file, + realElements.first().getContainingClass() as PyClass? + ) + } + .toSet() } } } - private fun getOutputFileName(model: PythonTestsModel) = - "test_${model.currentPythonModule.camelToSnakeCase().replace('.', '_')}.py" + private fun getOutputFileName(model: PythonTestLocalModel): String { + val moduleName = model.currentPythonModule.camelToSnakeCase().replace('.', '_') + return if (model.containingClass == null) { + "test_$moduleName.py" + } else { + val className = model.containingClass.name?.camelToSnakeCase()?.replace('.', '_') + "test_${moduleName}_$className.py" + } + } - private fun createTests(project: Project, model: PythonTestsModel) { + private fun createTests(project: Project, baseModel: PythonTestsModel) { ProgressManager.getInstance().run(object : Backgroundable(project, "Generate python tests") { override fun run(indicator: ProgressIndicator) { if (!LockFile.lock()) { return } try { - val methods = findSelectedPythonMethods(model) - val requirementsList = requirements.toMutableList() - if (!model.testFramework.isInstalled) { - requirementsList += model.testFramework.mainPackage - } + groupPyElementsByModule(baseModel).forEach { model -> + val methods = findSelectedPythonMethods(model) + val requirementsList = requirements.toMutableList() + if (!model.testFramework.isInstalled) { + requirementsList += model.testFramework.mainPackage + } - processTestGeneration( - pythonPath = model.pythonPath, - pythonFilePath = model.file.virtualFile.path, - pythonFileContent = getContentFromPyFile(model.file), - directoriesForSysPath = model.directoriesForSysPath, - currentPythonModule = model.currentPythonModule, - pythonMethods = methods, - containingClassName = model.containingClass?.name, - timeout = model.timeout, - testFramework = model.testFramework, - timeoutForRun = model.timeoutForRun, - writeTestTextToFile = { generatedCode -> - writeGeneratedCodeToPsiDocument(generatedCode, model) - }, - pythonRunRoot = Path(model.testSourceRootPath), - isCanceled = { indicator.isCanceled }, - checkingRequirementsAction = { indicator.text = "Checking requirements" }, - installingRequirementsAction = { indicator.text = "Installing requirements..." }, - requirementsAreNotInstalledAction = { - askAndInstallRequirementsLater(model.project, model.pythonPath, requirementsList) - PythonTestGenerationProcessor.MissingRequirementsActionResult.NOT_INSTALLED - }, - startedLoadingPythonTypesAction = { indicator.text = "Loading information about Python types" }, - startedTestGenerationAction = { indicator.text = "Generating tests" }, - notGeneratedTestsAction = { - showErrorDialogLater( - project, - message = "Cannot create tests for the following functions: " + it.joinToString(), - title = "Python test generation error" - ) - }, - processMypyWarnings = { - val message = it.fold(StringBuilder()) { acc, line -> acc.appendHtmlLine(line) } - WarningTestsReportNotifier.notify(message.toString()) - }, - runtimeExceptionTestsBehaviour = model.runtimeExceptionTestsBehaviour, - startedCleaningAction = { indicator.text = "Cleaning up..." } - ) + val content = runBlocking { + readAction { + getContentFromPyFile(model.file) + } + } + + processTestGeneration( + pythonPath = model.pythonPath, + pythonFilePath = model.file.virtualFile.path, + pythonFileContent = content, + directoriesForSysPath = model.directoriesForSysPath, + currentPythonModule = model.currentPythonModule, + pythonMethods = methods, + containingClassName = model.containingClass?.name, + timeout = model.timeout, + testFramework = model.testFramework, + timeoutForRun = model.timeoutForRun, + writeTestTextToFile = { generatedCode -> + writeGeneratedCodeToPsiDocument(generatedCode, model) + }, + pythonRunRoot = Path(model.testSourceRootPath), + isCanceled = { indicator.isCanceled }, + checkingRequirementsAction = { indicator.text = "Checking requirements" }, + installingRequirementsAction = { indicator.text = "Installing requirements..." }, + requirementsAreNotInstalledAction = { + askAndInstallRequirementsLater(model.project, model.pythonPath, requirementsList) + PythonTestGenerationProcessor.MissingRequirementsActionResult.NOT_INSTALLED + }, + startedLoadingPythonTypesAction = { indicator.text = "Loading information about Python types" }, + startedTestGenerationAction = { indicator.text = "Generating tests" }, + notGeneratedTestsAction = { + showErrorDialogLater( + project, + message = "Cannot create tests for the following functions: " + it.joinToString(), + title = "Python test generation error" + ) + }, + processMypyWarnings = { + val message = it.fold(StringBuilder()) { acc, line -> acc.appendHtmlLine(line) } + WarningTestsReportNotifier.notify(message.toString()) + }, + runtimeExceptionTestsBehaviour = model.runtimeExceptionTestsBehaviour, + startedCleaningAction = { indicator.text = "Cleaning up..." } + ) + } } finally { LockFile.unlock() } @@ -221,7 +265,7 @@ object PythonDialogProcessor { return getDirectoriesFromRoot(root, path.parent) + listOf(path.fileName.toString()) } - private fun createPsiDirectoryForTestSourceRoot(model: PythonTestsModel): PsiDirectory { + private fun createPsiDirectoryForTestSourceRoot(model: PythonTestLocalModel): PsiDirectory { val root = getContentRoot(model.project, model.file.virtualFile) val paths = getDirectoriesFromRoot( Paths.get(root.path), @@ -233,7 +277,7 @@ object PythonDialogProcessor { } } - private fun writeGeneratedCodeToPsiDocument(generatedCode: String, model: PythonTestsModel) { + private fun writeGeneratedCodeToPsiDocument(generatedCode: String, model: PythonTestLocalModel) { invokeLater { runWriteAction { val testDir = createPsiDirectoryForTestSourceRoot(model) @@ -288,13 +332,16 @@ object PythonDialogProcessor { } } -fun findSrcModule(functions: Collection): Module { - val srcModules = functions.mapNotNull { it.module }.distinct() - return when (srcModules.size) { - 0 -> error("Module for source classes not found") - 1 -> srcModules.first() - else -> error("Can not generate tests for classes from different modules") - } +fun findSrcModules(elements: Collection): List { + return elements.mapNotNull { it.module }.distinct() +} + +fun getSrcModule(element: PyElement): Module { + return ModuleUtilCore.findModuleForPsiElement(element) ?: error("Module for source class or function not found") +} + +fun getFullName(element: PyElement): String { + return QualifiedNameFinder.getQualifiedName(element) ?: error("Name for source class or function not found") } fun getContentFromPyFile(file: PyFile) = file.viewProvider.contents.toString() @@ -306,70 +353,76 @@ fun getDirectoriesForSysPath( srcModule: Module, file: PyFile ): Pair, String> { - val sources = ModuleRootManager.getInstance(srcModule).getSourceRoots(false).toMutableList() - val ancestor = ProjectFileIndex.getInstance(file.project).getContentRootForFile(file.virtualFile) - if (ancestor != null) - sources.add(ancestor) + return runBlocking { + readAction { + val sources = ModuleRootManager.getInstance(srcModule).getSourceRoots(false).toMutableList() + val ancestor = ProjectFileIndex.getInstance(file.project).getContentRootForFile(file.virtualFile) + if (ancestor != null) + sources.add(ancestor) - // Collect sys.path directories with imported modules - val importedPaths = emptyList().toMutableList() + // Collect sys.path directories with imported modules + val importedPaths = emptyList().toMutableList() - // 1. import - file.importTargets.forEach { importTarget -> - importTarget.multiResolve().forEach { - val element = it.element - if (element != null) { - val directory = element.parent - if (directory is PsiDirectory) { - // If we have `import a.b.c` we need to add syspath to module `a` only - val additionalLevel = importTarget.importedQName?.componentCount?.dec() ?: 0 - directory.topParent(additionalLevel)?.let { dir -> - importedPaths.add(dir.virtualFile) + // 1. import + file.importTargets.forEach { importTarget -> + importTarget.multiResolve().forEach { + val element = it.element + if (element != null) { + val directory = element.parent + if (directory is PsiDirectory) { + // If we have `import a.b.c` we need to add syspath to module `a` only + val additionalLevel = importTarget.importedQName?.componentCount?.dec() ?: 0 + directory.topParent(additionalLevel)?.let { dir -> + importedPaths.add(dir.virtualFile) + } + } } } } - } - } - // 2. from import ... - file.fromImports.forEach { importTarget -> - importTarget.resolveImportSourceCandidates().forEach { - val directory = it.parent - val isRelativeImport = importTarget.relativeLevel > 0 // If we have `from . import a` we don't need to add syspath - if (directory is PsiDirectory && !isRelativeImport) { - // If we have `from a.b.c import d` we need to add syspath to module `a` only - val additionalLevel = importTarget.importSourceQName?.componentCount?.dec() ?: 0 - directory.topParent(additionalLevel)?.let { dir -> - importedPaths.add(dir.virtualFile) + // 2. from import ... + file.fromImports.forEach { importTarget -> + importTarget.resolveImportSourceCandidates().forEach { + val directory = it.parent + val isRelativeImport = + importTarget.relativeLevel > 0 // If we have `from . import a` we don't need to add syspath + if (directory is PsiDirectory && !isRelativeImport) { + // If we have `from a.b.c import d` we need to add syspath to module `a` only + val additionalLevel = importTarget.importSourceQName?.componentCount?.dec() ?: 0 + directory.topParent(additionalLevel)?.let { dir -> + importedPaths.add(dir.virtualFile) + } + } } } - } - } - // Select modules only from this project but not from installation directory - importedPaths.forEach { - val path = it.toNioPath() - val hasSitePackages = (0 until(path.nameCount)).any { i -> path.subpath(i, i+1).toString() == "site-packages"} - if (it.isProjectSubmodule(ancestor) && !hasSitePackages) { - sources.add(it) - } - } + // Select modules only from this project but not from installation directory + importedPaths.forEach { + val path = it.toNioPath() + val hasSitePackages = + (0 until (path.nameCount)).any { i -> path.subpath(i, i + 1).toString() == "site-packages" } + if (it.isProjectSubmodule(ancestor) && !hasSitePackages) { + sources.add(it) + } + } - val fileName = file.name.removeSuffix(".py") - val importPath = ancestor?.let { - VfsUtil.getParentDir( - VfsUtilCore.getRelativeLocation(file.virtualFile, it) - ) - } ?: "" - val importStringPath = listOf( - importPath.toPath().joinToString("."), - fileName - ) - .filterNot { it.isEmpty() } - .joinToString(".") + val fileName = file.name.removeSuffix(".py") + val importPath = ancestor?.let { + VfsUtil.getParentDir( + VfsUtilCore.getRelativeLocation(file.virtualFile, it) + ) + } ?: "" + val importStringPath = listOf( + importPath.toPath().joinToString("."), + fileName + ) + .filterNot { it.isEmpty() } + .joinToString(".") - return Pair( - sources.map { it.path }.toSet(), - importStringPath - ) + Pair( + sources.map { it.path }.toSet(), + importStringPath + ) + } + } } \ No newline at end of file diff --git a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogWindow.kt b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogWindow.kt index 7f15e519f7..5cfc2c3555 100644 --- a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogWindow.kt +++ b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonDialogWindow.kt @@ -11,11 +11,12 @@ import com.intellij.ui.layout.Row import com.intellij.ui.layout.panel import com.intellij.util.ui.JBUI import com.jetbrains.python.psi.* -import com.jetbrains.python.refactoring.classes.PyMemberInfoStorage -import com.jetbrains.python.refactoring.classes.membersManager.PyMemberInfo -import com.jetbrains.python.refactoring.classes.ui.PyMemberSelectionTable import org.utbot.framework.UtSettings import org.utbot.framework.codegen.domain.ProjectType +import org.utbot.intellij.plugin.language.python.table.UtPyClassItem +import org.utbot.intellij.plugin.language.python.table.UtPyFunctionItem +import org.utbot.intellij.plugin.language.python.table.UtPyMemberSelectionTable +import org.utbot.intellij.plugin.language.python.table.UtPyTableItem import org.utbot.intellij.plugin.settings.Settings import java.awt.BorderLayout import java.util.concurrent.TimeUnit @@ -27,18 +28,19 @@ import javax.swing.* private const val WILL_BE_INSTALLED_LABEL = " (will be installed)" private const val MINIMUM_TIMEOUT_VALUE_IN_SECONDS = 1 +private const val STEP_TIMEOUT_VALUE_IN_SECONDS = 5 private const val ACTION_GENERATE = "Generate Tests" class PythonDialogWindow(val model: PythonTestsModel) : DialogWrapper(model.project) { - private val functionsTable = PyMemberSelectionTable(emptyList(), null, false) - private val testSourceFolderField = TestSourceDirectoryChooser(model, model.file.virtualFile) + private val pyElementsTable = UtPyMemberSelectionTable(emptyList()) + private val testSourceFolderField = TestSourceDirectoryChooser(model) private val timeoutSpinnerForTotalTimeout = JBIntSpinner( TimeUnit.MILLISECONDS.toSeconds(UtSettings.utBotGenerationTimeoutInMillis).toInt(), MINIMUM_TIMEOUT_VALUE_IN_SECONDS, Int.MAX_VALUE, - MINIMUM_TIMEOUT_VALUE_IN_SECONDS + STEP_TIMEOUT_VALUE_IN_SECONDS ) private val testFrameworks = ComboBox(DefaultComboBoxModel(model.cgLanguageAssistant.getLanguageTestFrameworkManager().testFrameworks.toTypedArray())) @@ -77,11 +79,11 @@ class PythonDialogWindow(val model: PythonTestsModel) : DialogWrapper(model.proj } row("Generate test methods for:") {} row { - scrollPane(functionsTable) + scrollPane(pyElementsTable) } } - updateFunctionsTable() + updatePyElementsTable() updateTestFrameworksList() return panel } @@ -90,47 +92,27 @@ class PythonDialogWindow(val model: PythonTestsModel) : DialogWrapper(model.proj testFrameworks.renderer = createTestFrameworksRenderer(WILL_BE_INSTALLED_LABEL) } - private fun globalPyFunctionsToPyMemberInfo( - project: Project, - functions: Collection - ): List> { - val generator = PyElementGenerator.getInstance(project) - val fakeClassName = generateRandomString(15) - val newClass = generator.createFromText( - LanguageLevel.getDefault(), - PyClass::class.java, - "class __FakeWrapperUtBotClass_$fakeClassName:\npass" - ) - functions.forEach { - newClass.add(it) - } - val storage = PyMemberInfoStorage(newClass) - return storage.getClassMemberInfos(newClass) - } - - private fun pyFunctionsToPyMemberInfo( - project: Project, - functions: Collection, - containingClass: PyClass? - ): List> { - if (containingClass == null) { - return globalPyFunctionsToPyMemberInfo(project, functions) + private fun updatePyElementsTable() { + val functions = model.elementsToDisplay.filterIsInstance() + val classes = model.elementsToDisplay.filterIsInstance() + val functionItems = functions + .groupBy { it.containingClass } + .flatMap { (_, pyFuncs) -> + pyFuncs.map {UtPyFunctionItem(it)} + } + val classItems = classes.map { + UtPyClassItem(it) } - return PyMemberInfoStorage(containingClass).getClassMemberInfos(containingClass) - .filter { it.member is PyFunction && fineFunction(it.member as PyFunction) } - } - - private fun updateFunctionsTable() { - val items = pyFunctionsToPyMemberInfo(model.project, model.functionsToDisplay, model.containingClass) + val items = classItems + functionItems updateMethodsTable(items) - val height = functionsTable.rowHeight * (items.size.coerceAtMost(12) + 1) - functionsTable.preferredScrollableViewportSize = JBUI.size(-1, height) + val height = pyElementsTable.rowHeight * (items.size.coerceAtMost(12) + 1) + pyElementsTable.preferredScrollableViewportSize = JBUI.size(-1, height) } - private fun updateMethodsTable(allMethods: Collection>) { - val focusedNames = model.focusedMethod?.map { it.name } + private fun updateMethodsTable(allMethods: Collection) { + val focusedNames = model.focusedElements?.map { it.idName } val selectedMethods = allMethods.filter { - focusedNames?.contains(it.member.name) ?: false + focusedNames?.contains(it.idName) ?: false } if (selectedMethods.isEmpty()) { @@ -139,10 +121,10 @@ class PythonDialogWindow(val model: PythonTestsModel) : DialogWrapper(model.proj checkMembers(selectedMethods) } - functionsTable.setMemberInfos(allMethods) + pyElementsTable.setItems(allMethods) } - private fun checkMembers(members: Collection>) = members.forEach { it.isChecked = true } + private fun checkMembers(members: Collection) = members.forEach { it.isChecked = true } private fun Row.makePanelWithHelpTooltip( mainComponent: JComponent, @@ -167,8 +149,8 @@ class PythonDialogWindow(val model: PythonTestsModel) : DialogWrapper(model.proj override fun getOKAction() = okOptionAction override fun doOKAction() { - val selectedMembers = functionsTable.selectedMemberInfos - model.selectedFunctions = selectedMembers.mapNotNull { it.member as? PyFunction }.toSet() + val selectedMembers = pyElementsTable.selectedMemberInfos + model.selectedElements = selectedMembers.mapNotNull { it.content }.toSet() model.testFramework = testFrameworks.item model.timeout = TimeUnit.SECONDS.toMillis(timeoutSpinnerForTotalTimeout.number.toLong()) model.testSourceRootPath = testSourceFolderField.text diff --git a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonLanguageAssistant.kt b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonLanguageAssistant.kt index 6624121f28..2c0f5fec4b 100644 --- a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonLanguageAssistant.kt +++ b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonLanguageAssistant.kt @@ -3,15 +3,20 @@ package org.utbot.intellij.plugin.language.python import com.intellij.lang.Language import com.intellij.openapi.actionSystem.AnActionEvent import com.intellij.openapi.actionSystem.CommonDataKeys +import com.intellij.openapi.actionSystem.PlatformDataKeys import com.intellij.openapi.editor.Editor +import com.intellij.openapi.project.Project +import com.intellij.openapi.vfs.VirtualFile +import com.intellij.psi.PsiDirectory import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile +import com.intellij.psi.PsiFileSystemItem +import com.intellij.psi.util.PsiTreeUtil import com.jetbrains.python.psi.PyClass import com.jetbrains.python.psi.PyFile import com.jetbrains.python.psi.PyFunction -import com.jetbrains.python.sdk.PythonSdkType -import org.jetbrains.kotlin.idea.util.projectStructure.module -import org.jetbrains.kotlin.idea.util.projectStructure.sdk +import org.jetbrains.kotlin.idea.core.util.toPsiDirectory +import org.jetbrains.kotlin.idea.core.util.toPsiFile import org.utbot.framework.plugin.api.util.LockFile import org.utbot.intellij.plugin.language.agnostic.LanguageAssistant @@ -21,24 +26,26 @@ object PythonLanguageAssistant : LanguageAssistant() { val language: Language = Language.findLanguageByID(pythonID) ?: error("Language wasn't found") data class Targets( - val functions: Set, - val containingClass: PyClass?, + val pyClasses: Set, + val pyFunctions: Set, + val focusedClass: PyClass?, val focusedFunction: PyFunction?, - val file: PyFile, - val editor: Editor?, - ) + val editor: Editor? + ) { + override fun toString(): String { + return "Targets($pyClasses, $pyFunctions)" + } + } override fun actionPerformed(e: AnActionEvent) { val project = e.project ?: return - val (functions, containingClass, focusedFunction, file, editor) = getPsiTargets(e) ?: return + val targets = getPsiTargets(e) ?: return PythonDialogProcessor.createDialogAndGenerateTests( project, - functions, - containingClass, - focusedFunction, - file, - editor, + targets.pyClasses + targets.pyFunctions, + targets.focusedClass ?: targets.focusedFunction, + targets.editor, ) } @@ -47,45 +54,77 @@ object PythonLanguageAssistant : LanguageAssistant() { } private fun getPsiTargets(e: AnActionEvent): Targets? { + val project = e.project ?: return null val editor = e.getData(CommonDataKeys.EDITOR) - val file = e.getData(CommonDataKeys.PSI_FILE) as? PyFile ?: return null - if (file.module?.sdk?.sdkType !is PythonSdkType) - return null - - val element = if (editor != null) { - findPsiElement(file, editor) ?: return null + val resultFunctions = mutableSetOf() + val resultClasses = mutableSetOf() + val focusedFunction: PyFunction? + var focusedClass: PyClass? = null + + if (editor != null) { + val file = e.getData(CommonDataKeys.PSI_FILE) as? PyFile ?: return null + val element = findPsiElement(file, editor) ?: return null + + val allFunctions = file.topLevelFunctions.filter { fineFunction(it) } + val allClasses = file.topLevelClasses.filter { fineClass(it) } + + val containingClass = getContainingElement(element) { fineClass(it) } + val containingFunction: PyFunction? = + if (containingClass == null) + getContainingElement(element) { it.parent is PsiFile && fineFunction(it) } + else + getContainingElement(element) { func -> + val ancestors = getAncestors(func) + ancestors.dropLast(1).all { it !is PyFunction } && + ancestors.count { it is PyClass } == 1 && fineFunction(func) + } + + if (allClasses.isEmpty()) { + return if (allFunctions.isEmpty()) { + null + } else { + resultFunctions.addAll(allFunctions) + focusedFunction = containingFunction + Targets(resultClasses, resultFunctions, null, focusedFunction, editor) + } + } else { + if (containingClass == null) { + resultClasses.addAll(allClasses) + resultFunctions.addAll(allFunctions) + focusedFunction = containingFunction + } else { + resultFunctions.addAll(containingClass.methods.filter { fineFunction(it) }) + focusedClass = containingClass + focusedFunction = containingFunction + } + return Targets(resultClasses, resultFunctions, focusedClass, focusedFunction, editor) + } } else { - e.getData(CommonDataKeys.PSI_ELEMENT) ?: return null - } - - val containingClass = getContainingElement(element) { fineClass(it) } - val containingFunction: PyFunction? = - if (containingClass == null) - getContainingElement(element) { it.parent is PsiFile && fineFunction(it) } - else - getContainingElement(element) { func -> - val ancestors = getAncestors(func) - ancestors.dropLast(1).all { it !is PyFunction } && - ancestors.count { it is PyClass } == 1 && fineFunction(func) + val element = e.getData(CommonDataKeys.PSI_ELEMENT) + if (element is PsiFileSystemItem) { + e.getData(CommonDataKeys.VIRTUAL_FILE_ARRAY)?.let { + val (classes, functions) = getAllElements(project, it.toList()) + resultFunctions.addAll(functions) + resultClasses.addAll(classes) } - - if (containingClass == null) { - val functions = file.topLevelFunctions.filter { fineFunction(it) } - if (functions.isEmpty()) - return null - - val focusedFunction = if (functions.contains(containingFunction)) containingFunction else null - return Targets(functions.toSet(), null, focusedFunction, file, editor) + } else { + val someSelection = e.getData(PlatformDataKeys.PSI_ELEMENT_ARRAY)?: return null + someSelection.forEach { + when(it) { + is PsiFileSystemItem -> { + val (classes, functions) = getAllElements(project, listOf(it.virtualFile)) + resultFunctions += functions + resultClasses += classes + } + } + } + } + if (resultClasses.isNotEmpty() || resultFunctions.isNotEmpty()) { + return Targets(resultClasses, resultFunctions, null, null, null) + } } - - val functions = containingClass.methods.filter { fineFunction(it) } - if (functions.isEmpty()) - return null - - val focusedFunction = - if (functions.any { it.name == containingFunction?.name }) containingFunction else null - return Targets(functions.toSet(), containingClass, focusedFunction, file, editor) + return null } // this method is copy-paste from GenerateTestsActions.kt @@ -98,4 +137,47 @@ object PythonLanguageAssistant : LanguageAssistant() { return element } + + private fun getAllElements(project: Project, virtualFiles: Collection): Pair, Set> { + val psiFiles = virtualFiles.mapNotNull { it.toPsiFile(project) } + val psiDirectories = virtualFiles.mapNotNull { it.toPsiDirectory(project) } + + val classes = psiFiles.flatMap { getClassesFromFile(it) }.toMutableSet() + val functions = psiFiles.flatMap { getFunctionsFromFile(it) }.toMutableSet() + + psiDirectories.forEach { + classes.addAll(getAllClasses(it)) + functions.addAll(getAllFunctions(it)) + } + + return classes to functions + } + + private fun getAllFunctions(directory: PsiDirectory): Set { + val allFunctions = directory.files.flatMap { getFunctionsFromFile(it) }.toMutableSet() + directory.subdirectories.forEach { + allFunctions.addAll(getAllFunctions(it)) + } + return allFunctions + } + + private fun getAllClasses(directory: PsiDirectory): Set { + val allClasses = directory.files.flatMap { getClassesFromFile(it) }.toMutableSet() + directory.subdirectories.forEach { + allClasses.addAll(getAllClasses(it)) + } + return allClasses + } + + private fun getFunctionsFromFile(psiFile: PsiFile): List { + return PsiTreeUtil.getChildrenOfTypeAsList(psiFile, PyFunction::class.java) + .map { it as PyFunction } + .filter { fineFunction(it) } + } + + private fun getClassesFromFile(psiFile: PsiFile): List { + return PsiTreeUtil.getChildrenOfTypeAsList(psiFile, PyClass::class.java) + .map { it as PyClass } + .filter { fineClass(it) } + } } \ No newline at end of file diff --git a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonTestsModel.kt b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonTestsModel.kt index 1ea4f094f3..f0d4ff007f 100644 --- a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonTestsModel.kt +++ b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/PythonTestsModel.kt @@ -3,27 +3,25 @@ package org.utbot.intellij.plugin.language.python import com.intellij.openapi.module.Module import com.intellij.openapi.project.Project import com.jetbrains.python.psi.PyClass +import com.jetbrains.python.psi.PyElement import com.jetbrains.python.psi.PyFile -import com.jetbrains.python.psi.PyFunction import org.utbot.framework.codegen.domain.RuntimeExceptionTestsBehaviour import org.utbot.framework.codegen.domain.TestFramework import org.utbot.framework.codegen.services.language.CgLanguageAssistant +import org.utbot.intellij.plugin.language.python.table.UtPyTableItem import org.utbot.intellij.plugin.models.BaseTestsModel class PythonTestsModel( project: Project, srcModule: Module, potentialTestModules: List, - val functionsToDisplay: Set, - val containingClass: PyClass?, - val focusedMethod: Set?, - val file: PyFile, - val directoriesForSysPath: Set, - val currentPythonModule: String, + val elementsToDisplay: Set, + val focusedElements: Set?, var timeout: Long, var timeoutForRun: Long, val cgLanguageAssistant: CgLanguageAssistant, val pythonPath: String, + val names: Map, PyElement>, ) : BaseTestsModel( project, srcModule, @@ -31,6 +29,22 @@ class PythonTestsModel( ) { lateinit var testSourceRootPath: String lateinit var testFramework: TestFramework - lateinit var selectedFunctions: Set + var selectedElements: Set = emptySet() lateinit var runtimeExceptionTestsBehaviour: RuntimeExceptionTestsBehaviour } + +data class PythonTestLocalModel( + val project: Project, + val timeout: Long, + val timeoutForRun: Long, + val cgLanguageAssistant: CgLanguageAssistant, + val pythonPath: String, + val testSourceRootPath: String, + val testFramework: TestFramework, + val selectedElements: Set, + val runtimeExceptionTestsBehaviour: RuntimeExceptionTestsBehaviour, + val directoriesForSysPath: Set, + val currentPythonModule: String, + val file: PyFile, + val containingClass: PyClass?, +) \ No newline at end of file diff --git a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/Utils.kt b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/Utils.kt index 0b73c99e9a..e8b285afc6 100644 --- a/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/Utils.kt +++ b/utbot-intellij-python/src/main/kotlin/org/utbot/intellij/plugin/language/python/Utils.kt @@ -7,8 +7,11 @@ import com.intellij.openapi.vfs.VirtualFile import com.intellij.psi.PsiDirectory import com.intellij.psi.PsiElement import com.jetbrains.python.psi.PyClass -import com.jetbrains.python.psi.PyDecorator +import com.jetbrains.python.psi.PyElement import com.jetbrains.python.psi.PyFunction +import org.utbot.intellij.plugin.language.python.table.UtPyClassItem +import org.utbot.intellij.plugin.language.python.table.UtPyFunctionItem +import org.utbot.intellij.plugin.language.python.table.UtPyTableItem import org.utbot.python.utils.RequirementsUtils import kotlin.random.Random @@ -65,3 +68,13 @@ fun PsiDirectory.topParent(level: Int): PsiDirectory? { } return directory } + +fun PyElement.fileName(): String? = this.containingFile.virtualFile.canonicalPath + +fun PyElement.toUtPyTableItem(): UtPyTableItem? { + return when (this) { + is PyClass -> UtPyClassItem(this) + is PyFunction -> UtPyFunctionItem(this) + else -> null + } +} \ No newline at end of file diff --git a/utbot-ui-commons/src/main/kotlin/org/utbot/intellij/plugin/ui/components/TestSourceDirectoryChooser.kt b/utbot-ui-commons/src/main/kotlin/org/utbot/intellij/plugin/ui/components/TestSourceDirectoryChooser.kt index 12ced23966..8876815f6c 100644 --- a/utbot-ui-commons/src/main/kotlin/org/utbot/intellij/plugin/ui/components/TestSourceDirectoryChooser.kt +++ b/utbot-ui-commons/src/main/kotlin/org/utbot/intellij/plugin/ui/components/TestSourceDirectoryChooser.kt @@ -2,6 +2,7 @@ package org.utbot.intellij.plugin.ui.components import com.intellij.openapi.fileChooser.FileChooserDescriptor import com.intellij.openapi.project.Project +import com.intellij.openapi.project.guessProjectDir import com.intellij.openapi.roots.ProjectFileIndex import com.intellij.openapi.ui.TextBrowseFolderListener import com.intellij.openapi.ui.TextFieldWithBrowseButton @@ -13,9 +14,12 @@ import org.utbot.intellij.plugin.models.BaseTestsModel class TestSourceDirectoryChooser( val model: BaseTestsModel, - file: VirtualFile + file: VirtualFile? = null ) : TextFieldWithBrowseButton() { - private val projectRoot = getContentRoot(model.project, file) + private val projectRoot = file + ?.let { getContentRoot(model.project, file) } + ?: model.project.guessProjectDir() + ?: error("Source file lies outside of a module") init { val descriptor = FileChooserDescriptor(