|
18 | 18 | #include "arrow/c/bridge.h" |
19 | 19 |
|
20 | 20 | #include <algorithm> |
| 21 | +#include <cerrno> |
21 | 22 | #include <cstring> |
22 | 23 | #include <string> |
23 | 24 | #include <utility> |
@@ -1501,4 +1502,197 @@ Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array, |
1501 | 1502 | return ImportRecordBatch(array, *maybe_schema); |
1502 | 1503 | } |
1503 | 1504 |
|
| 1505 | +////////////////////////////////////////////////////////////////////////// |
| 1506 | +// C stream export |
| 1507 | + |
| 1508 | +namespace { |
| 1509 | + |
| 1510 | +class ExportedArrayStream { |
| 1511 | + public: |
| 1512 | + struct PrivateData { |
| 1513 | + explicit PrivateData(std::shared_ptr<RecordBatchReader> reader) |
| 1514 | + : reader_(std::move(reader)) {} |
| 1515 | + |
| 1516 | + std::shared_ptr<RecordBatchReader> reader_; |
| 1517 | + std::string last_error_; |
| 1518 | + |
| 1519 | + PrivateData() = default; |
| 1520 | + ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData); |
| 1521 | + }; |
| 1522 | + |
| 1523 | + explicit ExportedArrayStream(struct ArrowArrayStream* stream) : stream_(stream) {} |
| 1524 | + |
| 1525 | + Status GetSchema(struct ArrowSchema* out_schema) { |
| 1526 | + return ExportSchema(*reader()->schema(), out_schema); |
| 1527 | + } |
| 1528 | + |
| 1529 | + Status GetNext(struct ArrowArray* out_array) { |
| 1530 | + std::shared_ptr<RecordBatch> batch; |
| 1531 | + RETURN_NOT_OK(reader()->ReadNext(&batch)); |
| 1532 | + if (batch == nullptr) { |
| 1533 | + // End of stream |
| 1534 | + ArrowArrayMarkReleased(out_array); |
| 1535 | + return Status::OK(); |
| 1536 | + } else { |
| 1537 | + return ExportRecordBatch(*batch, out_array); |
| 1538 | + } |
| 1539 | + } |
| 1540 | + |
| 1541 | + const char* GetLastError() { |
| 1542 | + const auto& last_error = private_data()->last_error_; |
| 1543 | + return last_error.empty() ? nullptr : last_error.c_str(); |
| 1544 | + } |
| 1545 | + |
| 1546 | + void Release() { |
| 1547 | + if (ArrowArrayStreamIsReleased(stream_)) { |
| 1548 | + return; |
| 1549 | + } |
| 1550 | + DCHECK_NE(private_data(), nullptr); |
| 1551 | + delete private_data(); |
| 1552 | + |
| 1553 | + ArrowArrayStreamMarkReleased(stream_); |
| 1554 | + } |
| 1555 | + |
| 1556 | + // C-compatible callbacks |
| 1557 | + |
| 1558 | + static int StaticGetSchema(struct ArrowArrayStream* stream, |
| 1559 | + struct ArrowSchema* out_schema) { |
| 1560 | + ExportedArrayStream self{stream}; |
| 1561 | + return self.ToCError(self.GetSchema(out_schema)); |
| 1562 | + } |
| 1563 | + |
| 1564 | + static int StaticGetNext(struct ArrowArrayStream* stream, |
| 1565 | + struct ArrowArray* out_array) { |
| 1566 | + ExportedArrayStream self{stream}; |
| 1567 | + return self.ToCError(self.GetNext(out_array)); |
| 1568 | + } |
| 1569 | + |
| 1570 | + static void StaticRelease(struct ArrowArrayStream* stream) { |
| 1571 | + ExportedArrayStream{stream}.Release(); |
| 1572 | + } |
| 1573 | + |
| 1574 | + static const char* StaticGetLastError(struct ArrowArrayStream* stream) { |
| 1575 | + return ExportedArrayStream{stream}.GetLastError(); |
| 1576 | + } |
| 1577 | + |
| 1578 | + private: |
| 1579 | + int ToCError(const Status& status) { |
| 1580 | + if (ARROW_PREDICT_TRUE(status.ok())) { |
| 1581 | + private_data()->last_error_.clear(); |
| 1582 | + return 0; |
| 1583 | + } |
| 1584 | + private_data()->last_error_ = status.ToString(); |
| 1585 | + switch (status.code()) { |
| 1586 | + case StatusCode::IOError: |
| 1587 | + return EIO; |
| 1588 | + case StatusCode::NotImplemented: |
| 1589 | + return ENOSYS; |
| 1590 | + case StatusCode::OutOfMemory: |
| 1591 | + return ENOMEM; |
| 1592 | + default: |
| 1593 | + return EINVAL; // Fallback for Invalid, TypeError, etc. |
| 1594 | + } |
| 1595 | + } |
| 1596 | + |
| 1597 | + PrivateData* private_data() { |
| 1598 | + return reinterpret_cast<PrivateData*>(stream_->private_data); |
| 1599 | + } |
| 1600 | + |
| 1601 | + const std::shared_ptr<RecordBatchReader>& reader() { return private_data()->reader_; } |
| 1602 | + |
| 1603 | + struct ArrowArrayStream* stream_; |
| 1604 | +}; |
| 1605 | + |
| 1606 | +} // namespace |
| 1607 | + |
| 1608 | +Status ExportRecordBatchReader(std::shared_ptr<RecordBatchReader> reader, |
| 1609 | + struct ArrowArrayStream* out) { |
| 1610 | + out->get_schema = ExportedArrayStream::StaticGetSchema; |
| 1611 | + out->get_next = ExportedArrayStream::StaticGetNext; |
| 1612 | + out->get_last_error = ExportedArrayStream::StaticGetLastError; |
| 1613 | + out->release = ExportedArrayStream::StaticRelease; |
| 1614 | + out->private_data = new ExportedArrayStream::PrivateData{std::move(reader)}; |
| 1615 | + return Status::OK(); |
| 1616 | +} |
| 1617 | + |
| 1618 | +////////////////////////////////////////////////////////////////////////// |
| 1619 | +// C stream import |
| 1620 | + |
| 1621 | +namespace { |
| 1622 | + |
| 1623 | +class ArrayStreamBatchReader : public RecordBatchReader { |
| 1624 | + public: |
| 1625 | + explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream) { |
| 1626 | + ArrowArrayStreamMove(stream, &stream_); |
| 1627 | + DCHECK(!ArrowArrayStreamIsReleased(&stream_)); |
| 1628 | + } |
| 1629 | + |
| 1630 | + ~ArrayStreamBatchReader() { |
| 1631 | + ArrowArrayStreamRelease(&stream_); |
| 1632 | + DCHECK(ArrowArrayStreamIsReleased(&stream_)); |
| 1633 | + } |
| 1634 | + |
| 1635 | + std::shared_ptr<Schema> schema() const override { return CacheSchema(); } |
| 1636 | + |
| 1637 | + Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { |
| 1638 | + struct ArrowArray c_array; |
| 1639 | + RETURN_NOT_OK(StatusFromCError(stream_.get_next(&stream_, &c_array))); |
| 1640 | + if (ArrowArrayIsReleased(&c_array)) { |
| 1641 | + // End of stream |
| 1642 | + batch->reset(); |
| 1643 | + return Status::OK(); |
| 1644 | + } else { |
| 1645 | + return ImportRecordBatch(&c_array, CacheSchema()).Value(batch); |
| 1646 | + } |
| 1647 | + } |
| 1648 | + |
| 1649 | + private: |
| 1650 | + std::shared_ptr<Schema> CacheSchema() const { |
| 1651 | + if (!schema_) { |
| 1652 | + struct ArrowSchema c_schema; |
| 1653 | + ARROW_CHECK_OK(StatusFromCError(stream_.get_schema(&stream_, &c_schema))); |
| 1654 | + schema_ = ImportSchema(&c_schema).ValueOrDie(); |
| 1655 | + } |
| 1656 | + return schema_; |
| 1657 | + } |
| 1658 | + |
| 1659 | + Status StatusFromCError(int errno_like) const { |
| 1660 | + if (ARROW_PREDICT_TRUE(errno_like == 0)) { |
| 1661 | + return Status::OK(); |
| 1662 | + } |
| 1663 | + StatusCode code; |
| 1664 | + switch (errno_like) { |
| 1665 | + case EDOM: |
| 1666 | + case EINVAL: |
| 1667 | + case ERANGE: |
| 1668 | + code = StatusCode::Invalid; |
| 1669 | + break; |
| 1670 | + case ENOMEM: |
| 1671 | + code = StatusCode::OutOfMemory; |
| 1672 | + break; |
| 1673 | + case ENOSYS: |
| 1674 | + code = StatusCode::NotImplemented; |
| 1675 | + default: |
| 1676 | + code = StatusCode::IOError; |
| 1677 | + break; |
| 1678 | + } |
| 1679 | + const char* last_error = stream_.get_last_error(&stream_); |
| 1680 | + return Status(code, last_error ? std::string(last_error) : ""); |
| 1681 | + } |
| 1682 | + |
| 1683 | + mutable struct ArrowArrayStream stream_; |
| 1684 | + mutable std::shared_ptr<Schema> schema_; |
| 1685 | +}; |
| 1686 | + |
| 1687 | +} // namespace |
| 1688 | + |
| 1689 | +Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader( |
| 1690 | + struct ArrowArrayStream* stream) { |
| 1691 | + if (ArrowArrayStreamIsReleased(stream)) { |
| 1692 | + return Status::Invalid("Cannot import released ArrowArrayStream"); |
| 1693 | + } |
| 1694 | + // XXX should we call get_schema() here to avoid crashing on error? |
| 1695 | + return std::make_shared<ArrayStreamBatchReader>(stream); |
| 1696 | +} |
| 1697 | + |
1504 | 1698 | } // namespace arrow |
0 commit comments