Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,16 @@ struct CAFFE2_API PyObjectType : public Type {
: Type(TypeKind::PyObjectType) {}
};

enum class TypeVerbosity {
None,
Type,
TypeAndStride,
Full,
Default = Full,
};

CAFFE2_API TypeVerbosity type_verbosity();

CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
template <typename T>
CAFFE2_API std::ostream& operator<<(
Expand Down
30 changes: 20 additions & 10 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@

namespace c10 {

TypeVerbosity type_verbosity() {
static const char* c_verbosity = std::getenv("PYTORCH_JIT_TYPE_VERBOSITY");
static TypeVerbosity verbosity = c_verbosity ?
static_cast<TypeVerbosity>(c10::stoi(c_verbosity)) : TypeVerbosity::Default;
return verbosity;
}

std::ostream& operator<<(std::ostream & out, const Type & t) {
if (auto value = t.cast<TensorType>()) {
if (value->scalarType().has_value()) {
Expand All @@ -34,21 +41,24 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else {
out << "*";
}
if (has_valid_strides_info) {
if (has_valid_strides_info &&
type_verbosity() >= TypeVerbosity::TypeAndStride) {
out << ":" << *value->strides()[i];
}
}
if (value->requiresGrad()) {
if (i++ > 0) {
out << ", ";
if (type_verbosity() >= TypeVerbosity::Full) {
if (value->requiresGrad()) {
if (i++ > 0) {
out << ", ";
}
out << "requires_grad=" << *value->requiresGrad();
}
out << "requires_grad=" << *value->requiresGrad();
}
if (value->device()) {
if (i++ > 0) {
out << ", ";
if (value->device()) {
if (i++ > 0) {
out << ", ";
}
out << "device=" << *value->device();
}
out << "device=" << *value->device();
}
out << ")";
}
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/OVERVIEW.md
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,13 @@ one specifies a file(s) in `PYTORCH_JIT_LOG_LEVEL`.
`>>` and `>>>` are also valid and **currently** are equivalent to `GRAPH_DEBUG` as there is no logging level that is
higher than `GRAPH_DEBUG`.

By default, types in the graph are printed with maximum verbosity. The verbosity level can be controlled via the environment variable `PYTORCH_JIT_TYPE_VERBOSITY`. The available settings are:

* `0`: No type information
* `1`: Types and shapes only
* `2`: Also print strides
* `3`: Also print device type and whether gradient is required

## DifferentiableGraphOp ##

[runtime/graph_executor.cpp](runtime/graph_executor.cpp)
Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/jit/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ std::ostream& operator<<(
out << l.delim;
}
printValueRef(out, n);
out << " : ";
out << *n->type();
if (c10::type_verbosity() >= c10::TypeVerbosity::Type) {
out << " : ";
out << *n->type();
}
}
return out;
}
Expand Down